Spaces:
Sleeping
Sleeping
File size: 2,025 Bytes
b90bddd 228728b 4dc891c 228728b b90bddd 228728b 4dc891c 228728b b90bddd 4dc891c b90bddd 4dc891c b90bddd 228728b b90bddd 228728b 4dc891c 6fc5c0f 228728b ed823ac af953a5 228728b 4dc891c 228728b 4dc891c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
from fastapi import FastAPI, Request
from pydantic import BaseModel
from model_loader import load_model
import torch
import logging
# Configure module-level logging; INFO so model-load progress is visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Global model/tokenizer handles, populated once by the startup hook below.
# They stay None until loading completes; handlers check for None to report
# "not ready" instead of crashing.
tokenizer = None
model = None
# NOTE(review): @app.on_event is deprecated in recent FastAPI releases in
# favor of lifespan handlers; kept here to preserve the existing interface.
@app.on_event("startup")
async def startup_event():
    """Load the model and tokenizer once when the server starts.

    Populates the module-level ``tokenizer`` and ``model`` globals used by
    the request handlers. Re-raises on failure so the application does not
    come up half-initialized.
    """
    global tokenizer, model
    logger.info("Loading model and tokenizer...")
    try:
        tokenizer, model = load_model()
        # Inference-only server: switch off dropout/batch-norm updates.
        model.eval()
        logger.info("Model and tokenizer loaded successfully!")
        logger.info("FastAPI application is ready to serve requests")
    except Exception:
        # logger.exception records the full traceback; bare `raise`
        # re-raises the active exception without rewriting its chain
        # (unlike the previous `raise e`).
        logger.exception("Failed to load model")
        raise
class PromptRequest(BaseModel):
    """JSON request body for the ``/generate`` endpoint."""

    # The text prompt to feed to the model; must be non-empty.
    prompt: str
@app.get("/")
async def root():
    """Liveness endpoint: confirms the API process is running."""
    greeting = {"message": "Qwen Finetuned Model API is running!"}
    return greeting
@app.get("/health")
async def health_check():
    """Readiness endpoint: reports whether the model has finished loading."""
    ready = model is not None and tokenizer is not None
    if ready:
        return {"status": "healthy", "message": "Model is ready"}
    return {"status": "unhealthy", "message": "Model not loaded"}
@app.post("/generate")
async def generate_text(request: PromptRequest):
    """Generate a completion for the supplied prompt.

    Returns ``{"response": <text>}`` on success, or ``{"error": ...}`` when
    the model is not yet loaded, the prompt is empty, or generation fails.
    """
    if model is None or tokenizer is None:
        return {"error": "Model not loaded yet"}
    prompt = request.prompt
    if not prompt:
        return {"error": "Prompt is missing"}
    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=400,
                repetition_penalty=1.1,
                temperature=0.3,
                # temperature is ignored by HF generate unless sampling is
                # enabled; without this flag decoding was silently greedy.
                do_sample=True,
            )
        # Strip the prompt by token count rather than by decoded string
        # length: decode() is not guaranteed to round-trip the prompt
        # character-for-character, so `full_response[len(prompt):]` could
        # truncate the completion or leak prompt fragments.
        prompt_token_count = inputs["input_ids"].shape[-1]
        generated_text = tokenizer.decode(
            output[0][prompt_token_count:], skip_special_tokens=True
        ).strip()
        return {"response": generated_text}
    except Exception as e:
        # Record the full traceback, not just the message.
        logger.exception("Error during text generation")
        return {"error": f"Generation failed: {str(e)}"}
if __name__ == "__main__":
    # Allow running this module directly for local development.
    # Port 7860 — presumably the Hugging Face Spaces default; confirm
    # against the deployment config.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)