soupstick committed
Commit 3c66c0b · 1 Parent(s): 8db5e34
llm: switch to single-model Qwen (Fireworks), drop heartbeats; throttle to avoid 429
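The 429 mitigation in this change is a process-wide pacer plus a small bounded-retry wrapper instead of heartbeat probes. Below is a minimal sketch of that pattern using the names the module already defines (_pace, _with_retries, MIN_INTERVAL_S, MAX_RETRIES); the sleep and backoff arithmetic is an illustrative assumption, not copied from the file.

import random
import threading
import time

MIN_INTERVAL_S = 1.0   # default the module reads from LLM_MIN_INTERVAL_S
MAX_RETRIES = 2        # default the module reads from LLM_MAX_RETRIES

_last_call_ts = 0.0
_ts_lock = threading.Lock()

def _pace():
    # Sleep just enough so successive provider calls are at least MIN_INTERVAL_S apart.
    global _last_call_ts
    with _ts_lock:
        wait = MIN_INTERVAL_S - (time.monotonic() - _last_call_ts)
        if wait > 0:
            time.sleep(wait)
        _last_call_ts = time.monotonic()

def _with_retries(fn):
    # Run fn; on failure back off with jitter, giving up after MAX_RETRIES retries.
    for attempt in range(MAX_RETRIES + 1):
        try:
            return fn()
        except Exception:
            if attempt == MAX_RETRIES:
                raise
            time.sleep(2 ** attempt + random.random())

Every provider call in the diff goes through this pair (behind a shared _CALL_LOCK), which is what the commit relies on to stay under the provider's request rate limits when several sessions ask for summaries at once.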
llm_provider.py  CHANGED  (+30 -160)
@@ -1,16 +1,12 @@
 from __future__ import annotations
 """
 llm_provider.py
-
-
-
+- Single provider: Fireworks (OpenAI-compatible) with Qwen3-Coder-30B-A3B-Instruct
+- Optional fallback: Hugging Face Inference Router (provider="fireworks-ai") if HF_TOKEN present
+- No heartbeats (avoid rate hits); conservative throttling & retries
 """
 
-import os
-import time
-import random
-import threading
-import logging
+import os, time, random, threading, logging
 from typing import List
 
 from dotenv import load_dotenv
@@ -21,7 +17,7 @@ from langchain_core.outputs import ChatGeneration, ChatResult
 
 # Fireworks (OpenAI-compatible)
 from openai import OpenAI
-from openai import RateLimitError
+from openai import RateLimitError
 
 # HF Router (provider routing)
 from huggingface_hub import InferenceClient
@@ -30,10 +26,8 @@ load_dotenv()
 log = logging.getLogger("fraud-analyst")
 logging.basicConfig(level=logging.INFO)
 
-# --------- Public constant used by app.py when LLM is not available ----------
 SUMMARY_NOTICE = "🔌 Please connect to an inference point to generate summary."
 
-# --------- Read secrets (support your HF repo secret name) ----------
 def _first_env(*names: List[str]):
     for n in names:
         v = os.getenv(n)
@@ -41,31 +35,25 @@ def _first_env(*names: List[str]):
             return v
     return None
 
+# Secrets
 FIREWORKS_API_KEY = _first_env(
-    "fireworks_api_huggingface",
-    "
-    "FIREWORKS_API_KEY",
-    "OPENAI_API_KEY",  # allow reuse
+    "fireworks_api_huggingface", "FIREWORKS_API_HUGGINGFACE",
+    "FIREWORKS_API_KEY", "OPENAI_API_KEY"
 )
 HF_TOKEN = _first_env("HF_TOKEN", "HUGGINGFACE_TOKEN")
 
-#
-
-
-FW_SECONDARY_MODEL = os.getenv("FW_SECONDARY_MODEL", "accounts/fireworks/models/qwen3-coder-30b-a3b-instruct")
+# Models (Qwen only)
+FW_PRIMARY_MODEL = os.getenv("FW_PRIMARY_MODEL", "accounts/fireworks/models/qwen3-coder-30b-a3b-instruct")
+HF_PRIMARY_MODEL = os.getenv("HF_PRIMARY_MODEL", "Qwen/Qwen3-Coder-30B-A3B-Instruct")
 
-#
-
-
-#
-MAX_NEW_TOKENS = int(os.getenv("LLM_MAX_NEW_TOKENS", "96"))
-TEMP = float(os.getenv("LLM_TEMPERATURE", "0.2"))
-MAX_RETRIES = int(os.getenv("LLM_MAX_RETRIES", "2"))
-MIN_INTERVAL_S = float(os.getenv("LLM_MIN_INTERVAL_S", "1.0"))
-MAX_CONCURRENCY = int(os.getenv("LLM_MAX_CONCURRENCY", "1"))
-
-# Ensure OpenAI SDK itself doesn't add extra retries
+# Throttle / Retry (conservative for demo)
+MAX_NEW_TOKENS = int(os.getenv("LLM_MAX_NEW_TOKENS", "96"))
+TEMP = float(os.getenv("LLM_TEMPERATURE", "0.2"))
+MAX_RETRIES = int(os.getenv("LLM_MAX_RETRIES", "2"))
+MIN_INTERVAL_S = float(os.getenv("LLM_MIN_INTERVAL_S", "1.0"))
+MAX_CONCURRENCY = int(os.getenv("LLM_MAX_CONCURRENCY", "1"))
+
+# Ensure OpenAI SDK itself doesn't also retry
 os.environ.setdefault("OPENAI_MAX_RETRIES", "0")
 
 # Global throttle across all instances
@@ -74,7 +62,7 @@ _last_call_ts = 0.0
 _ts_lock = threading.Lock()
 
 def _pace():
-    """Global pacing to avoid hitting
+    """Global pacing to avoid hitting 429s."""
     global _last_call_ts
     with _ts_lock:
         now = time.monotonic()
@@ -99,11 +87,8 @@ def _with_retries(fn):
 # ========================== Fireworks (OpenAI-compatible) ==========================
 FW_BASE = os.getenv("OPENAI_API_BASE", "https://api.fireworks.ai/inference/v1")
 
-class _ChatIncompatible(Exception):
-    """Raise to switch a model to /completions pathway (e.g., gpt-oss-20b)."""
-
 class FireworksOpenAIChat(BaseChatModel):
-    """
+    """Qwen on Fireworks via /chat/completions."""
     model: str
     api_key: str | None = None
     temperature: float = TEMP
@@ -149,102 +134,15 @@ class FireworksOpenAIChat(BaseChatModel):
                 text = ch.message.content
             return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text or ""))],
                               llm_output={"model": self.model, "endpoint": "chat"})
-        except BadRequestError as e:
-            # Fireworks returns this when a model is not chat-friendly.
-            if "Failed to format non-streaming choice" in str(e) or "invalid_request_error" in str(e):
-                raise _ChatIncompatible()
-            log.warning(f"FW chat BadRequest for {self.model}: {str(e)[:200]}")
-            return ChatResult(generations=[ChatGeneration(message=AIMessage(content=""))],
-                              llm_output={"error": str(e)})
-        except Exception as e:
-            log.warning(f"FW chat failed for {self.model}: {type(e).__name__}: {str(e)[:200]}")
-            return ChatResult(generations=[ChatGeneration(message=AIMessage(content=""))],
-                              llm_output={"error": str(e)})
-
-class FireworksOpenAICompletionChat(BaseChatModel):
-    """Wrap /completions as a chat model (robust for gpt-oss-20b)."""
-    model: str
-    api_key: str | None = None
-    temperature: float = TEMP
-    max_new_tokens: int = MAX_NEW_TOKENS
-
-    def __init__(self, **data):
-        super().__init__(**data)
-        self._client = OpenAI(base_url=FW_BASE, api_key=self.api_key, max_retries=0)
-
-    @property
-    def _llm_type(self) -> str:
-        return "fireworks_openai_completion_chat"
-
-    def _to_prompt(self, messages) -> str:
-        parts=[]
-        for m in messages:
-            if isinstance(m, SystemMessage):
-                parts.append(f"[System] {m.content}")
-            elif isinstance(m, HumanMessage):
-                parts.append(f"[User] {m.content}")
-            elif isinstance(m, AIMessage):
-                parts.append(f"[Assistant] {m.content}")
-            else:
-                parts.append(f"[User] {str(getattr(m,'content',m))}")
-        parts.append("[Assistant]")
-        return "\n".join(parts)
-
-    def _generate(self, messages, stop=None, run_manager=None, **kwargs) -> ChatResult:
-        if not self.api_key:
-            return ChatResult(generations=[ChatGeneration(message=AIMessage(content=""))],
-                              llm_output={"error": "no_api_key"})
-        prompt = self._to_prompt(messages)
-        def _call():
-            with _CALL_LOCK:
-                _pace()
-                return self._client.completions.create(
-                    model=self.model,
-                    prompt=prompt,
-                    temperature=kwargs.get("temperature", self.temperature),
-                    max_tokens=kwargs.get("max_tokens", self.max_new_tokens),
-                )
-        try:
-            resp = _with_retries(_call)
-            text = ""
-            if getattr(resp, "choices", None):
-                ch = resp.choices[0]
-                if getattr(ch, "text", None):
-                    text = ch.text
-            return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text or ""))],
-                              llm_output={"model": self.model, "endpoint": "completions"})
         except Exception as e:
-
+            # Return empty output; UI will show notice if needed
+            logging.warning(f"Fireworks(Qwen) failed: {type(e).__name__}: {str(e)[:200]}")
             return ChatResult(generations=[ChatGeneration(message=AIMessage(content=""))],
                               llm_output={"error": str(e)})
 
-def _heartbeat_fw_chat(model_id: str) -> bool:
-    if not FIREWORKS_API_KEY:
-        return False
-    try:
-        cli = OpenAI(base_url=FW_BASE, api_key=FIREWORKS_API_KEY, max_retries=0)
-        _ = cli.chat.completions.create(model=model_id, messages=[{"role":"user","content":"ping"}], max_tokens=1)
-        return True
-    except BadRequestError as e:
-        # chat-incompatible → let caller try completions heartbeat
-        if "Failed to format non-streaming choice" in str(e) or "invalid_request_error" in str(e):
-            return False
-        return False
-    except Exception:
-        return False
-
-def _heartbeat_fw_completion(model_id: str) -> bool:
-    if not FIREWORKS_API_KEY:
-        return False
-    try:
-        cli = OpenAI(base_url=FW_BASE, api_key=FIREWORKS_API_KEY, max_retries=0)
-        _ = cli.completions.create(model=model_id, prompt="ping", max_tokens=1)
-        return True
-    except Exception:
-        return False
-
 # ========================== HF Router (provider="fireworks-ai") ==========================
 class HFRouterChat(BaseChatModel):
+    """Fallback only if FIREWORKS_API_KEY is absent but HF_TOKEN is set."""
     model: str
     hf_token: str | None = None
    temperature: float = TEMP
@@ -252,7 +150,6 @@ class HFRouterChat(BaseChatModel):
 
     def __init__(self, **data):
         super().__init__(**data)
-        # This reaches Fireworks through HF Router. Needs HF_TOKEN.
         self._client = InferenceClient(provider="fireworks-ai", api_key=self.hf_token)
 
     @property
@@ -276,7 +173,7 @@ class HFRouterChat(BaseChatModel):
             with _CALL_LOCK:
                 _pace()
                 return self._client.chat.completions.create(
-                    model=self.model,  #
+                    model=self.model,  # "Qwen/Qwen3-Coder-30B-A3B-Instruct"
                     messages=self._convert(messages),
                     stream=False,
                     max_tokens=kwargs.get("max_tokens", self.max_new_tokens),
@@ -294,50 +191,23 @@ class HFRouterChat(BaseChatModel):
             return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text or ""))],
                               llm_output={"model": self.model})
         except Exception as e:
-
+            logging.warning(f"HF Router(Qwen) failed: {type(e).__name__}: {str(e)[:200]}")
             return ChatResult(generations=[ChatGeneration(message=AIMessage(content=""))],
                               llm_output={"error": str(e)})
 
-def _heartbeat_hf_router(model_id: str) -> bool:
-    if not HF_TOKEN:
-        return False
-    try:
-        cli = InferenceClient(provider="fireworks-ai", api_key=HF_TOKEN)
-        _ = cli.chat.completions.create(model=model_id, messages=[{"role":"user","content":"ping"}], stream=False, max_tokens=1)
-        return True
-    except Exception:
-        return False
-
 # =============================== Selection ===============================
 def build_chat_llm():
-    # Prefer Fireworks direct
+    # Prefer Fireworks direct (Qwen)
     if FIREWORKS_API_KEY:
-
-
-        log.info(f"Using Fireworks chat model: {FW_PRIMARY_MODEL}")
-        return FireworksOpenAIChat(model=FW_PRIMARY_MODEL, api_key=FIREWORKS_API_KEY)
-    elif _heartbeat_fw_completion(FW_PRIMARY_MODEL):
-        log.info(f"Using Fireworks COMPLETION-wrapped model: {FW_PRIMARY_MODEL}")
-        return FireworksOpenAICompletionChat(model=FW_PRIMARY_MODEL, api_key=FIREWORKS_API_KEY)
-
-    # Secondary chain
-    if _heartbeat_fw_chat(FW_SECONDARY_MODEL):
-        log.info(f"Using Fireworks chat model (fallback): {FW_SECONDARY_MODEL}")
-        return FireworksOpenAIChat(model=FW_SECONDARY_MODEL, api_key=FIREWORKS_API_KEY)
-    elif _heartbeat_fw_completion(FW_SECONDARY_MODEL):
-        log.info(f"Using Fireworks COMPLETION-wrapped model (fallback): {FW_SECONDARY_MODEL}")
-        return FireworksOpenAICompletionChat(model=FW_SECONDARY_MODEL, api_key=FIREWORKS_API_KEY)
+        log.info(f"Using Fireworks chat model: {FW_PRIMARY_MODEL}")
+        return FireworksOpenAIChat(model=FW_PRIMARY_MODEL, api_key=FIREWORKS_API_KEY)
 
-    #
-    if HF_TOKEN
+    # Fallback to HF Router (if provided)
+    if HF_TOKEN:
        log.info(f"Using HF Router chat model: {HF_PRIMARY_MODEL}")
        return HFRouterChat(model=HF_PRIMARY_MODEL, hf_token=HF_TOKEN)
-    if HF_TOKEN and _heartbeat_hf_router(HF_SECONDARY_MODEL):
-        log.info(f"Using HF Router chat model (fallback): {HF_SECONDARY_MODEL}")
-        return HFRouterChat(model=HF_SECONDARY_MODEL, hf_token=HF_TOKEN)
 
     log.warning("No working chat model; notice will be shown.")
     return None
 
-# Singleton used by app.py and agent.py
 CHAT_LLM = build_chat_llm()
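For reference, the module's public surface after this change is just CHAT_LLM (a LangChain chat model, or None) and SUMMARY_NOTICE. A hypothetical caller along the lines of app.py might use them as sketched here; summarize_case and the prompt text are illustrative assumptions, not code from the repository.

from langchain_core.messages import HumanMessage, SystemMessage

from llm_provider import CHAT_LLM, SUMMARY_NOTICE

def summarize_case(case_text: str) -> str:
    # build_chat_llm() returns None when no key is configured; show the notice instead.
    if CHAT_LLM is None:
        return SUMMARY_NOTICE
    result = CHAT_LLM.invoke([
        SystemMessage(content="Summarize this fraud case briefly."),
        HumanMessage(content=case_text),
    ])
    # On provider errors the wrappers return an empty message, so fall back to the notice.
    return result.content or SUMMARY_NOTICE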