soupstick committed
Commit 6e36e0d · 1 Parent(s): 7ead058

fix: route Fireworks directly w/ OpenAI client; add HF Router fallback; pydantic init

Files changed (2)
  1. llm_provider.py +132 -45
  2. requirements.txt +2 -1
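For context, the fallback cascade this commit introduces can be exercised from application code roughly as follows. A minimal sketch, assuming the module is imported as-is; the calling script and variable names are illustrative, not part of this commit:

    from llm_provider import build_chat_llm, SUMMARY_NOTICE
    from langchain.schema import HumanMessage

    llm = build_chat_llm()  # tries Fireworks direct, then HF Router, else None
    if llm is None:
        print(SUMMARY_NOTICE)  # no reachable inference point
    else:
        reply = llm.invoke([HumanMessage(content="Summarize the flagged transactions.")])
        print(reply.content)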
llm_provider.py CHANGED
@@ -1,100 +1,187 @@
  from __future__ import annotations
  import os, logging
  from dotenv import load_dotenv
- from huggingface_hub import InferenceClient
  from langchain_core.language_models.chat_models import BaseChatModel
  from langchain.schema import HumanMessage, SystemMessage, AIMessage
  from langchain_core.outputs import ChatGeneration, ChatResult

  load_dotenv()
  log = logging.getLogger("fraud-analyst")
  logging.basicConfig(level=logging.INFO)

- FIREWORKS_API_KEY = os.getenv("fireworks_api_huggingface") or os.getenv("HF_TOKEN")
- FW_PRIMARY_MODEL = os.getenv("FW_PRIMARY_MODEL", "openai/gpt-oss-20b")
- FW_SECONDARY_MODEL = os.getenv("FW_SECONDARY_MODEL", "Qwen/Qwen3-Coder-30B-A3B-Instruct")
-
  SUMMARY_NOTICE = "🔌 Please connect to an inference point to generate summary."

- class FireworksHFChat(BaseChatModel):
      model: str
      api_key: str | None = None
      temperature: float = 0.2
      max_new_tokens: int = 256
-     timeout: int = 60

-     def __init__(self, model: str, api_key: str | None):
-         super().__init__()
-         self.model = model
-         self.api_key = api_key
-         self._client = InferenceClient(provider="fireworks-ai", api_key=self.api_key)

      @property
      def _llm_type(self) -> str:
-         return "fireworks_hf_chat"

      def _convert(self, messages):
          out=[]
          for m in messages:
-             if isinstance(m, SystemMessage):
-                 out.append({"role":"system","content":m.content})
-             elif isinstance(m, HumanMessage):
-                 out.append({"role":"user","content":m.content})
-             elif isinstance(m, AIMessage):
-                 out.append({"role":"assistant","content":m.content})
-             else:
-                 out.append({"role":"user","content":str(getattr(m,"content",m))})
          return out

      def _generate(self, messages, stop=None, run_manager=None, **kwargs) -> ChatResult:
          if not self.api_key:
              gen = ChatGeneration(message=AIMessage(content=""))
-             return ChatResult(generations=[gen], llm_output={"error": "no_api_key"})
          try:
              resp = self._client.chat.completions.create(
-                 model=self.model,
                  messages=self._convert(messages),
                  stream=False,
-                 max_tokens=kwargs.get("max_tokens", 256),
-                 temperature=kwargs.get("temperature", 0.2),
              )
              text = ""
              if hasattr(resp, "choices") and resp.choices:
                  ch = resp.choices[0]
-                 if hasattr(ch, "message") and ch.message and getattr(ch.message, "content", None):
                      text = ch.message.content
-                 elif hasattr(ch, "text") and ch.text:
-                     text = ch.text
              gen = ChatGeneration(message=AIMessage(content=text or ""))
              return ChatResult(generations=[gen], llm_output={"model": self.model})
          except Exception as e:
-             log.warning(f"Fireworks call failed for {self.model}: {type(e).__name__}: {str(e)[:200]}")
              gen = ChatGeneration(message=AIMessage(content=""))
              return ChatResult(generations=[gen], llm_output={"error": str(e)})

- def _heartbeat(model_id: str) -> bool:
      if not FIREWORKS_API_KEY: return False
      try:
-         client = InferenceClient(provider="fireworks-ai", api_key=FIREWORKS_API_KEY)
-         _ = client.chat.completions.create(
-             model=model_id,
-             messages=[{"role":"user","content":"ping"}],
-             stream=False,
-             max_tokens=1,
-         )
          return True
      except Exception as e:
-         log.warning(f"Heartbeat failed for {model_id}: {type(e).__name__}: {str(e)[:160]}")
          return False

  def build_chat_llm():
-     log.info(f"Fireworks key present: {bool(FIREWORKS_API_KEY)} len={len(FIREWORKS_API_KEY) if FIREWORKS_API_KEY else 0}")
-     if FIREWORKS_API_KEY and _heartbeat(FW_PRIMARY_MODEL):
-         log.info(f"Using chat model: {FW_PRIMARY_MODEL}")
-         return FireworksHFChat(FW_PRIMARY_MODEL, FIREWORKS_API_KEY)
-     if FIREWORKS_API_KEY and _heartbeat(FW_SECONDARY_MODEL):
-         log.info(f"Using fallback chat model: {FW_SECONDARY_MODEL}")
-         return FireworksHFChat(FW_SECONDARY_MODEL, FIREWORKS_API_KEY)
      log.warning("No working chat model; notice will be shown.")
      return None
 
  from __future__ import annotations
  import os, logging
  from dotenv import load_dotenv
+
  from langchain_core.language_models.chat_models import BaseChatModel
  from langchain.schema import HumanMessage, SystemMessage, AIMessage
  from langchain_core.outputs import ChatGeneration, ChatResult

+ # Providers
+ from openai import OpenAI                    # Fireworks OpenAI-compatible endpoint
+ from huggingface_hub import InferenceClient  # HF Router (provider routing)
+
  load_dotenv()
  log = logging.getLogger("fraud-analyst")
  logging.basicConfig(level=logging.INFO)

  SUMMARY_NOTICE = "🔌 Please connect to an inference point to generate summary."

+ def _first_env(*names):
+     for n in names:
+         v = os.getenv(n)
+         if v:
+             return v
+     return None
+
+ # Secrets (your repo secret name included)
+ FIREWORKS_API_KEY = _first_env(
+     "fireworks_api_huggingface",   # your HF repo secret (Fireworks key)
+     "FIREWORKS_API_HUGGINGFACE",
+     "FIREWORKS_API_KEY",
+     "OPENAI_API_KEY",              # also works if you export the FW key here
+ )
+ HF_TOKEN = _first_env("HF_TOKEN", "HUGGINGFACE_TOKEN")
+
+ # Model IDs for each route.
+ # Fireworks (direct, OpenAI-compatible): use fully-qualified IDs.
+ FW_PRIMARY_MODEL = os.getenv("FW_PRIMARY_MODEL", "accounts/openai/models/gpt-oss-20b")
+ FW_SECONDARY_MODEL = os.getenv("FW_SECONDARY_MODEL", "accounts/fireworks/models/qwen3-coder-30b-a3b-instruct")
+ # HF Router route (must use HF_TOKEN). With the OpenAI SDK on the HF Router you'd use
+ # `...:fireworks-ai`, but with huggingface_hub.InferenceClient + provider we pass the
+ # plain HF model id.
+ HF_PRIMARY_MODEL = os.getenv("HF_PRIMARY_MODEL", "openai/gpt-oss-20b")
+ HF_SECONDARY_MODEL = os.getenv("HF_SECONDARY_MODEL", "Qwen/Qwen3-Coder-30B-A3B-Instruct")
+
+ # ---------- Fireworks (OpenAI-compatible) driver ----------
+ class FireworksOpenAIChat(BaseChatModel):
      model: str
      api_key: str | None = None
      temperature: float = 0.2
      max_new_tokens: int = 256

+     def __init__(self, **data):
+         super().__init__(**data)
+         # Fireworks OpenAI-compatible endpoint
+         self._client = OpenAI(
+             base_url=os.getenv("OPENAI_API_BASE", "https://api.fireworks.ai/inference/v1"),
+             api_key=self.api_key,
+         )

      @property
      def _llm_type(self) -> str:
+         return "fireworks_openai_chat"

      def _convert(self, messages):
          out=[]
          for m in messages:
+             if isinstance(m, SystemMessage): out.append({"role":"system","content":m.content})
+             elif isinstance(m, HumanMessage): out.append({"role":"user","content":m.content})
+             elif isinstance(m, AIMessage): out.append({"role":"assistant","content":m.content})
+             else: out.append({"role":"user","content":str(getattr(m,"content",m))})
          return out

      def _generate(self, messages, stop=None, run_manager=None, **kwargs) -> ChatResult:
          if not self.api_key:
              gen = ChatGeneration(message=AIMessage(content=""))
+             return ChatResult(generations=[gen], llm_output={"error":"no_api_key"})
          try:
              resp = self._client.chat.completions.create(
+                 model=self.model,  # e.g., accounts/openai/models/gpt-oss-20b
                  messages=self._convert(messages),
+                 temperature=kwargs.get("temperature", self.temperature),
+                 max_tokens=kwargs.get("max_tokens", self.max_new_tokens),
                  stream=False,
              )
              text = ""
              if hasattr(resp, "choices") and resp.choices:
                  ch = resp.choices[0]
+                 # OpenAI SDK v1 returns .message
+                 if getattr(ch, "message", None) and getattr(ch.message, "content", None):
                      text = ch.message.content
              gen = ChatGeneration(message=AIMessage(content=text or ""))
              return ChatResult(generations=[gen], llm_output={"model": self.model})
          except Exception as e:
+             log.warning(f"Fireworks(OpenAI) call failed for {self.model}: {type(e).__name__}: {str(e)[:200]}")
              gen = ChatGeneration(message=AIMessage(content=""))
              return ChatResult(generations=[gen], llm_output={"error": str(e)})

+ def _heartbeat_fireworks(model_id: str) -> bool:
      if not FIREWORKS_API_KEY: return False
      try:
+         cli = OpenAI(base_url="https://api.fireworks.ai/inference/v1", api_key=FIREWORKS_API_KEY)
+         _ = cli.chat.completions.create(model=model_id, messages=[{"role":"user","content":"ping"}], max_tokens=1)
+         return True
+     except Exception as e:
+         log.warning(f"FW heartbeat failed for {model_id}: {type(e).__name__}: {str(e)[:200]}")
+         return False
+
+ # ---------- HF Router (provider routing) driver ----------
+ class HFRouterChat(BaseChatModel):
+     model: str
+     hf_token: str | None = None
+     temperature: float = 0.2
+     max_new_tokens: int = 256
+
+     def __init__(self, **data):
+         super().__init__(**data)
+         self._client = InferenceClient(provider="fireworks-ai", api_key=self.hf_token)
+
+     @property
+     def _llm_type(self) -> str:
+         return "hf_router_fireworks"
+
+     def _convert(self, messages):
+         out=[]
+         for m in messages:
+             if isinstance(m, SystemMessage): out.append({"role":"system","content":m.content})
+             elif isinstance(m, HumanMessage): out.append({"role":"user","content":m.content})
+             elif isinstance(m, AIMessage): out.append({"role":"assistant","content":m.content})
+             else: out.append({"role":"user","content":str(getattr(m,"content",m))})
+         return out
+
+     def _generate(self, messages, stop=None, run_manager=None, **kwargs) -> ChatResult:
+         if not self.hf_token:
+             gen = ChatGeneration(message=AIMessage(content=""))
+             return ChatResult(generations=[gen], llm_output={"error":"no_hf_token"})
+         try:
+             resp = self._client.chat.completions.create(
+                 model=self.model,  # e.g., "openai/gpt-oss-20b"
+                 messages=self._convert(messages),
+                 stream=False,
+                 max_tokens=kwargs.get("max_tokens", self.max_new_tokens),
+                 temperature=kwargs.get("temperature", self.temperature),
+             )
+             text = ""
+             if hasattr(resp, "choices") and resp.choices:
+                 ch = resp.choices[0]
+                 if getattr(ch, "message", None) and getattr(ch.message, "content", None):
+                     text = ch.message.content
+                 elif getattr(ch, "text", None):
+                     text = ch.text
+             gen = ChatGeneration(message=AIMessage(content=text or ""))
+             return ChatResult(generations=[gen], llm_output={"model": self.model})
+         except Exception as e:
+             log.warning(f"HF Router call failed for {self.model}: {type(e).__name__}: {str(e)[:200]}")
+             gen = ChatGeneration(message=AIMessage(content=""))
+             return ChatResult(generations=[gen], llm_output={"error": str(e)})
+
+ def _heartbeat_hf_router(model_id: str) -> bool:
+     if not HF_TOKEN: return False
+     try:
+         cli = InferenceClient(provider="fireworks-ai", api_key=HF_TOKEN)
+         _ = cli.chat.completions.create(model=model_id, messages=[{"role":"user","content":"ping"}], stream=False, max_tokens=1)
          return True
      except Exception as e:
+         log.warning(f"HF Router heartbeat failed for {model_id}: {type(e).__name__}: {str(e)[:200]}")
          return False

+ # ---------- LLM selection ----------
  def build_chat_llm():
+     # Prefer direct Fireworks when a FW key is present
+     if FIREWORKS_API_KEY and _heartbeat_fireworks(FW_PRIMARY_MODEL):
+         log.info(f"Using Fireworks chat model: {FW_PRIMARY_MODEL}")
+         return FireworksOpenAIChat(model=FW_PRIMARY_MODEL, api_key=FIREWORKS_API_KEY)
+     if FIREWORKS_API_KEY and _heartbeat_fireworks(FW_SECONDARY_MODEL):
+         log.info(f"Using Fireworks fallback chat model: {FW_SECONDARY_MODEL}")
+         return FireworksOpenAIChat(model=FW_SECONDARY_MODEL, api_key=FIREWORKS_API_KEY)
+
+     # Else try the HF Router (requires HF_TOKEN)
+     if HF_TOKEN and _heartbeat_hf_router(HF_PRIMARY_MODEL):
+         log.info(f"Using HF Router chat model: {HF_PRIMARY_MODEL}")
+         return HFRouterChat(model=HF_PRIMARY_MODEL, hf_token=HF_TOKEN)
+     if HF_TOKEN and _heartbeat_hf_router(HF_SECONDARY_MODEL):
+         log.info(f"Using HF Router fallback chat model: {HF_SECONDARY_MODEL}")
+         return HFRouterChat(model=HF_SECONDARY_MODEL, hf_token=HF_TOKEN)
+
      log.warning("No working chat model; notice will be shown.")
      return None
 
requirements.txt CHANGED
@@ -7,4 +7,5 @@ langchain>=0.2
  langchain-community>=0.2
  langchain-huggingface>=0.0.3
  pydantic>=2
- python-dotenv
+ python-dotenv
+ openai>=1.43.0
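After installing, a quick version check confirms the new pin resolved; a hypothetical snippet, not part of this commit:

    import openai, huggingface_hub
    print(openai.__version__)           # expect >= 1.43.0 per requirements.txt
    print(huggingface_hub.__version__)  # still required for the HF Router fallback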