dondoesstuff committed
Commit c8ae67b · verified · 1 Parent(s): 1c6b552

Update app.py

Files changed (1)
  1. app.py +251 -222
app.py CHANGED
@@ -1,222 +1,251 @@
- import os
- import time
- import uuid
- from typing import List, Optional, Dict, Any
-
- import torch
- from fastapi import FastAPI, HTTPException
- from fastapi.responses import RedirectResponse
- from pydantic import BaseModel, Field
- from transformers import AutoModelForCausalLM, AutoTokenizer
-
- MODEL_ID = os.getenv("MODEL_ID", "LiquidAI/LFM2-1.2B")
- DEFAULT_MAX_TOKENS = int(os.getenv("MAX_TOKENS", "256"))
-
- app = FastAPI(title="OpenAI-compatible API for LiquidAI/LFM2-1.2B")
-
- tokenizer = None
- model = None
-
-
- def get_dtype() -> torch.dtype:
-     if torch.cuda.is_available():
-         # Prefer bfloat16 if supported; else float16
-         if torch.cuda.is_bf16_supported():
-             return torch.bfloat16
-         return torch.float16
-     # CPU
-     return torch.float32
-
-
- @app.on_event("startup")
- def load_model():
-     global tokenizer, model
-     dtype = get_dtype()
-     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
-     model = AutoModelForCausalLM.from_pretrained(
-         MODEL_ID,
-         torch_dtype=dtype,
-         device_map="auto",
-         trust_remote_code=True,
-     )
-     # Ensure eos/bos tokens exist
-     if tokenizer.eos_token is None:
-         tokenizer.eos_token = tokenizer.sep_token or tokenizer.pad_token or "</s>"
-     if tokenizer.pad_token is None:
-         tokenizer.pad_token = tokenizer.eos_token
-
-
- class ChatMessage(BaseModel):
-     role: str
-     content: str
-
-
- class ChatCompletionRequest(BaseModel):
-     model: Optional[str] = Field(default=MODEL_ID)
-     messages: List[ChatMessage]
-     temperature: Optional[float] = 0.7
-     top_p: Optional[float] = 0.95
-     max_tokens: Optional[int] = None
-     stop: Optional[List[str] | str] = None
-     n: Optional[int] = 1
-
-
- class CompletionRequest(BaseModel):
-     model: Optional[str] = Field(default=MODEL_ID)
-     prompt: str | List[str]
-     temperature: Optional[float] = 0.7
-     top_p: Optional[float] = 0.95
-     max_tokens: Optional[int] = None
-     stop: Optional[List[str] | str] = None
-     n: Optional[int] = 1
-
-
- class Usage(BaseModel):
-     prompt_tokens: int
-     completion_tokens: int
-     total_tokens: int
-
-
- # Simple chat prompt formatter
-
- def build_chat_prompt(messages: List[ChatMessage]) -> str:
-     system_prefix = "You are a helpful assistant."
-     system_msgs = [m.content for m in messages if m.role == "system"]
-     if system_msgs:
-         system_prefix = system_msgs[-1]
-
-     conv: List[str] = [f"System: {system_prefix}"]
-     for m in messages:
-         if m.role == "system":
-             continue
-         role = "User" if m.role == "user" else ("Assistant" if m.role == "assistant" else m.role.capitalize())
-         conv.append(f"{role}: {m.content}")
-     conv.append("Assistant:")
-     return "\n".join(conv)
-
-
- def apply_stop_sequences(text: str, stop: Optional[List[str] | str]) -> str:
-     if stop is None:
-         return text
-     stops = stop if isinstance(stop, list) else [stop]
-     cut = len(text)
-     for s in stops:
-         if not s:
-             continue
-         idx = text.find(s)
-         if idx != -1:
-             cut = min(cut, idx)
-     return text[:cut]
-
-
- def generate_once(prompt: str, temperature: float, top_p: float, max_new_tokens: int) -> Dict[str, Any]:
-     assert tokenizer is not None and model is not None, "Model not loaded"
-
-     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-     gen_ids = model.generate(
-         **inputs,
-         max_new_tokens=max_new_tokens,
-         do_sample=True if temperature and temperature > 0 else False,
-         temperature=max(0.0, float(temperature or 0.0)),
-         top_p=max(0.0, float(top_p or 1.0)),
-         pad_token_id=tokenizer.pad_token_id,
-         eos_token_id=tokenizer.eos_token_id,
-     )
-     out = tokenizer.decode(gen_ids[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
-     return {
-         "text": out,
-         "prompt_tokens": inputs["input_ids"].numel(),
-         "completion_tokens": gen_ids[0].shape[0] - inputs["input_ids"].shape[-1],
-     }
-
-
- @app.get("/")
- def root():
-     return RedirectResponse(url="/docs")
-
-
- @app.get("/health")
- def health():
-     return {"status": "ok", "model": MODEL_ID}
-
-
- @app.post("/v1/chat/completions")
- def chat_completions(req: ChatCompletionRequest):
-     if req.n and req.n > 1:
-         raise HTTPException(status_code=400, detail="Only n=1 is supported in this simple server.")
-     max_new = req.max_tokens or DEFAULT_MAX_TOKENS
-
-     prompt = build_chat_prompt(req.messages)
-     g = generate_once(prompt, req.temperature or 0.7, req.top_p or 0.95, max_new)
-     text = apply_stop_sequences(g["text"], req.stop)
-
-     created = int(time.time())
-     comp_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
-
-     usage = Usage(
-         prompt_tokens=g["prompt_tokens"],
-         completion_tokens=g["completion_tokens"],
-         total_tokens=g["prompt_tokens"] + g["completion_tokens"],
-     )
-
-     return {
-         "id": comp_id,
-         "object": "chat.completion",
-         "created": created,
-         "model": req.model or MODEL_ID,
-         "choices": [
-             {
-                 "index": 0,
-                 "message": {"role": "assistant", "content": text},
-                 "finish_reason": "stop",
-             }
-         ],
-         "usage": usage.dict(),
-     }
-
-
- @app.post("/v1/completions")
- def completions(req: CompletionRequest):
-     if req.n and req.n > 1:
-         raise HTTPException(status_code=400, detail="Only n=1 is supported in this simple server.")
-
-     prompts = req.prompt if isinstance(req.prompt, list) else [req.prompt]
-     if len(prompts) != 1:
-         raise HTTPException(status_code=400, detail="Only a single prompt is supported in this simple server.")
-
-     max_new = req.max_tokens or DEFAULT_MAX_TOKENS
-
-     g = generate_once(prompts[0], req.temperature or 0.7, req.top_p or 0.95, max_new)
-     text = apply_stop_sequences(g["text"], req.stop)
-
-     created = int(time.time())
-     comp_id = f"cmpl-{uuid.uuid4().hex[:24]}"
-
-     usage = Usage(
-         prompt_tokens=g["prompt_tokens"],
-         completion_tokens=g["completion_tokens"],
-         total_tokens=g["prompt_tokens"] + g["completion_tokens"],
-     )
-
-     return {
-         "id": comp_id,
-         "object": "text_completion",
-         "created": created,
-         "model": req.model or MODEL_ID,
-         "choices": [
-             {
-                 "index": 0,
-                 "text": text,
-                 "finish_reason": "stop",
-                 "logprobs": None,
-             }
-         ],
-         "usage": usage.dict(),
-     }
-
-
- if __name__ == "__main__":
-     import uvicorn
-
-     port = int(os.getenv("PORT", "7860"))
-     uvicorn.run("app:app", host="0.0.0.0", port=port, reload=False)
+ """
+ Minimal OpenAI-compatible local server that serves LiquidAI/LFM2-1.2B via Hugging Face
+ Transformers on CPU and exposes a subset of the OpenAI REST API (chat/completions, models).
+
+ This file is the app module (app.py). Run with:
+     pip install -r requirements.txt
+     python app.py
+
+ Or run with uvicorn directly (recommended for production/dev):
+     uvicorn app:app --host 0.0.0.0 --port 7860
+
+ Requirements (requirements.txt):
+     fastapi
+     uvicorn[standard]
+     transformers
+     torch
+
+ Notes:
+ - CPU-only: model loads on CPU (may be slow for a 1.2B model depending on your machine).
+ - Model repo id used: "LiquidAI/LFM2-1.2B" — adjust if you have a different path or local copy.
+ - This provides a simplified compatibility layer. It is NOT feature-complete with OpenAI's API
+   but implements common fields: messages, max_tokens, temperature, top_p, n, stop, stream (basic).
+ """
+
+ from fastapi import FastAPI, Request, HTTPException
+ from fastapi.responses import JSONResponse, StreamingResponse, PlainTextResponse
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel
+ from typing import List, Optional, Any, Dict
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import time
+ import json
+ import uuid
+
+ # -----------------------------
+ # Configuration
+ # -----------------------------
+ MODEL_ID = "LiquidAI/LFM2-1.2B"  # change to your model location or HF repo
+ HOST = "0.0.0.0"
+ PORT = 7860
+ DEVICE = torch.device("cpu")  # CPU-only as requested
+ DEFAULT_MAX_TOKENS = 256
+
+ # -----------------------------
+ # Load model & tokenizer
+ # -----------------------------
+ print(f"Loading tokenizer and model '{MODEL_ID}' on device {DEVICE} (CPU-only)... this may take a while")
+ try:
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
+     model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float32)
+     model.to(DEVICE)
+     model.eval()
+ except Exception as e:
+     raise RuntimeError(f"Failed to load model/tokenizer for '{MODEL_ID}': {e}")
+
+ # If tokenizer has no pad/eos, try to set sensible defaults
+ if tokenizer.pad_token_id is None:
+     if tokenizer.eos_token_id is not None:
+         tokenizer.pad_token_id = tokenizer.eos_token_id
+
+ # -----------------------------
+ # FastAPI app
+ # -----------------------------
+ app = FastAPI(title="Local OpenAI-compatible server (transformers)", version="0.1")
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # -----------------------------
+ # Pydantic models (request bodies)
+ # -----------------------------
+ class Message(BaseModel):
+     role: str
+     content: str
+
+ class ChatCompletionRequest(BaseModel):
+     model: Optional[str] = MODEL_ID
+     messages: List[Message]
+     max_tokens: Optional[int] = DEFAULT_MAX_TOKENS
+     temperature: Optional[float] = 0.0
+     top_p: Optional[float] = 1.0
+     n: Optional[int] = 1
+     stop: Optional[List[str]] = None
+     stream: Optional[bool] = False
+
+ # -----------------------------
+ # Helpers
+ # -----------------------------
+ def build_prompt_from_messages(messages: List[Dict[str, Any]]) -> str:
+     # Simple conversational prompt formatting. Adjust to suit model's expected format.
+     parts = []
+     for m in messages:
+         role = m.get("role", "user")
+         content = m.get("content", "")
+         if role == "system":
+             parts.append(f"<|system|> {content}\n")
+         elif role == "user":
+             parts.append(f"User: {content}\n")
+         elif role == "assistant":
+             parts.append(f"Assistant: {content}\n")
+         else:
+             parts.append(f"{role}: {content}\n")
+     parts.append("Assistant: ")
+     return "".join(parts)
+
+
+ def apply_stop_sequences(text: str, stops: Optional[List[str]]) -> str:
+     if not stops:
+         return text
+     idx = None
+     for s in stops:
+         if s == "":
+             continue
+         pos = text.find(s)
+         if pos != -1:
+             if idx is None or pos < idx:
+                 idx = pos
+     if idx is not None:
+         return text[:idx]
+     return text
+
+ # -----------------------------
+ # Endpoints
+ # -----------------------------
+ @app.get("/", response_class=PlainTextResponse)
+ async def root():
+     return "Local OpenAI-compatible server running. Use /v1/chat/completions or /v1/models"
+
+ @app.get("/v1/models")
+ async def list_models():
+     return {"data": [{"id": MODEL_ID, "object": "model"}], "object": "list"}
+
+ @app.post("/v1/chat/completions")
+ async def chat_completions(request: Request, body: ChatCompletionRequest):
+     # Basic validation
+     if body.model is None or body.model != MODEL_ID:
+         # Reject requests for any model other than the one loaded on this server
+         raise HTTPException(status_code=400, detail={"error": "invalid_model", "message": f"Only model {MODEL_ID} is available on this server."})
+
+     prompt = build_prompt_from_messages([m.dict() for m in body.messages])
+
+     # Tokenize
+     inputs = tokenizer(prompt, return_tensors="pt")
+     input_ids = inputs["input_ids"].to(DEVICE)
+     input_len = input_ids.shape[-1]
+
+     # Generation settings
+     gen_kwargs = {
+         "max_new_tokens": body.max_tokens,
+         "do_sample": bool(body.temperature and body.temperature > 0.0),
+         "temperature": float(body.temperature or 0.0),
+         "top_p": float(body.top_p or 1.0),
+         "num_return_sequences": int(body.n or 1),
+         "pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
+         # note: on CPU large models may be slow
+     }
+
+     # Synchronous generation
+     with torch.no_grad():
+         outputs = model.generate(input_ids, **gen_kwargs)
+
+     choices = []
+     for i, out_ids in enumerate(outputs):
+         full_text = tokenizer.decode(out_ids, skip_special_tokens=True)
+         # Attempt to strip the prompt prefix to return only the generated reply
+         # (best-effort: exact prompt match first, token-count slice as fallback)
+         stripped = full_text
+         try:
+             # prefer exact match; fallback to trimming by token count
+             if prompt.strip() and prompt in full_text:
+                 stripped = full_text.split(prompt, 1)[1]
+             else:
+                 # fallback: drop the first input_len tokens and decode only the generated part
+                 gen_only_ids = out_ids[input_len:]
+                 # (slicing by token count avoids relying on the prompt text surviving decoding exactly)
+                 stripped = tokenizer.decode(gen_only_ids, skip_special_tokens=True)
+         except Exception:
+             stripped = full_text
+
+         # apply stop sequences
+         stripped = apply_stop_sequences(stripped, body.stop)
+
+         # build choice structure similar to OpenAI
+         choice = {
+             "index": i,
+             "message": {"role": "assistant", "content": stripped},
+             "finish_reason": "stop" if body.stop else "length",
+         }
+         choices.append(choice)
+
+     # approximate token usage
+     completion_tokens = max(0, (outputs.shape[-1] - input_len) if outputs is not None else 0)
+     usage = {"prompt_tokens": int(input_len), "completion_tokens": int(completion_tokens), "total_tokens": int(input_len + completion_tokens)}
+
+     response = {
+         "id": str(uuid.uuid4()),
+         "object": "chat.completion",
+         "created": int(time.time()),
+         "model": body.model,
+         "choices": choices,
+         "usage": usage,
+     }
+
+     # Streaming: rudimentary implementation that streams chunks of the final text as SSE
+     if body.stream:
+         # Only support streaming a single response (n > 1 will still stream the first)
+         text_to_stream = choices[0]["message"]["content"]
+         def event_stream():
+             # send a few small chunks
+             chunk_size = 128
+             for start in range(0, len(text_to_stream), chunk_size):
+                 chunk = text_to_stream[start:start+chunk_size]
+                 payload = {"id": response["id"], "object": "chat.completion.chunk", "choices": [{"delta": {"content": chunk}, "index": 0}]}
+                 yield f"data: {json.dumps(payload)}\n\n"
+             # final done message
+             done_payload = {"id": response["id"], "object": "chat.completion.chunk", "choices": [{"delta": {}, "index": 0}], "done": True}
+             yield f"data: {json.dumps(done_payload)}\n\n"
+         return StreamingResponse(event_stream(), media_type="text/event-stream")
+
+     return JSONResponse(response)
+
+ # A convenience POST /v1/completions that accepts 'prompt' (legacy completions API)
+ class CompletionRequest(BaseModel):
+     model: Optional[str] = MODEL_ID
+     prompt: Optional[str] = ""
+     max_tokens: Optional[int] = DEFAULT_MAX_TOKENS
+     temperature: Optional[float] = 0.0
+     top_p: Optional[float] = 1.0
+     n: Optional[int] = 1
+     stop: Optional[List[str]] = None
+     stream: Optional[bool] = False
+
+ @app.post("/v1/completions")
+ async def completions(req: CompletionRequest):
+     # wrap prompt into the chat format for our generator
+     messages = [Message(role="user", content=req.prompt)]
+     chat_req = ChatCompletionRequest(model=req.model, messages=messages, max_tokens=req.max_tokens, temperature=req.temperature, top_p=req.top_p, n=req.n, stop=req.stop, stream=req.stream)
+     # call the chat_completions handler directly (a minimal valid ASGI scope is needed to construct a Request)
+     return await chat_completions(Request(scope={"type": "http"}), chat_req)
+
+ # -----------------------------
+ # If executed directly, run uvicorn
+ # -----------------------------
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run("app:app", host=HOST, port=PORT, log_level="info")
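For reference, a request against the updated server might look like the sketch below. It is a minimal, illustrative example and not part of the commit: it assumes the server from this app.py is running locally on the configured host/port (0.0.0.0:7860) with the default MODEL_ID, it uses only the Python standard library, and the prompt text and sampling parameters are placeholders.

    import json
    import urllib.request

    # Assumed local endpoint; matches HOST/PORT configured in app.py above.
    url = "http://localhost:7860/v1/chat/completions"
    payload = {
        "model": "LiquidAI/LFM2-1.2B",  # must match MODEL_ID or the server returns 400
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "max_tokens": 64,
        "temperature": 0.7,
    }
    req = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        body = json.loads(resp.read().decode("utf-8"))

    # The response mirrors the OpenAI-style shape built in chat_completions above.
    print(body["choices"][0]["message"]["content"])
    print(body["usage"])

When "stream": true is sent instead, the same endpoint responds with text/event-stream data in the chat.completion.chunk shape produced by event_stream, rather than a single JSON body.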