from fastapi import FastAPI
from pydantic import BaseModel
from model_loader import load_model
import torch
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Global variables for model and tokenizer
tokenizer = None
model = None

# Note: on_event is deprecated in newer FastAPI versions in favor of lifespan handlers,
# but it still works and is kept here for simplicity.
@app.on_event("startup")
async def startup_event():
    global tokenizer, model
    logger.info("Loading model and tokenizer...")
    try:
        tokenizer, model = load_model()
        model.eval()
        logger.info("Model and tokenizer loaded successfully!")
        logger.info("FastAPI application is ready to serve requests")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise e

class PromptRequest(BaseModel):
    prompt: str

@app.get("/")
async def root():
    return {"message": "Qwen Finetuned Model API is running!"}

@app.get("/health")
async def health_check():
    if model is None or tokenizer is None:
        return {"status": "unhealthy", "message": "Model not loaded"}
    return {"status": "healthy", "message": "Model is ready"}

@app.post("/generate")
async def generate_text(request: PromptRequest):
    if model is None or tokenizer is None:
        return {"error": "Model not loaded yet"}
    
    prompt = request.prompt

    if not prompt.strip():
        return {"error": "Prompt is empty"}

    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=400,
                do_sample=True,  # sampling must be enabled for temperature to take effect
                temperature=0.3,
                repetition_penalty=1.1,
            )

        full_response = tokenizer.decode(output[0], skip_special_tokens=True)

        # The decoded output echoes the prompt, so slice it off to return only the new text
        generated_text = full_response[len(prompt):].strip()

        return {"response": generated_text}

    except Exception as e:
        logger.error(f"Error during text generation: {e}")
        return {"error": f"Generation failed: {str(e)}"}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
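
# --- Hypothetical sketch of model_loader.py (not part of this file) ---
# The real model_loader module is not included here; the sketch below only
# illustrates the (tokenizer, model) contract the startup handler relies on,
# assuming a Hugging Face Transformers checkpoint. The model id is a placeholder,
# not the actual finetuned Qwen checkpoint served by this app.
#
# from transformers import AutoModelForCausalLM, AutoTokenizer
# import torch
#
# def load_model(model_id: str = "Qwen/Qwen2-1.5B-Instruct"):
#     tokenizer = AutoTokenizer.from_pretrained(model_id)
#     model = AutoModelForCausalLM.from_pretrained(
#         model_id,
#         torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
#         device_map="auto",
#     )
#     return tokenizer, model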