frammartina committed
Commit ec548e9 · verified · 1 Parent(s): a1bf370

Create app.py

Files changed (1):
  1. app.py +40 -0
app.py ADDED
@@ -0,0 +1,40 @@
import os
from fastapi import FastAPI
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

MODEL_PATH = "LSTM__0.9170.pt.pt"
MODEL_URL = "https://drive.google.com/uc?id=133F-sRp_mCGOo73t1ieSnbk5fSxPFENT"

# Fetch the fine-tuned weights from Google Drive on first run.
if not os.path.exists(MODEL_PATH):
    import gdown
    print("Downloading weights from Google Drive...")
    gdown.download(MODEL_URL, MODEL_PATH, quiet=False)

# BioBERT backbone with a binary classification head; the downloaded
# state dict is loaded on top of it.
tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-base-cased-v1.1", use_fast=False)
model = AutoModelForSequenceClassification.from_pretrained(
    "dmis-lab/biobert-base-cased-v1.1",
    num_labels=2,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.to(device)
model.eval()

app = FastAPI()

class Query(BaseModel):
    question: str
    context: str
    long_answer: str

@app.post("/chat")
def get_response(query: Query):
    # Concatenate the three fields into one sequence for the classifier.
    text = query.question + " " + query.context + " " + query.long_answer
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model(**inputs)
    # Class 1 maps to "Yes", class 0 to "No".
    answer = torch.argmax(outputs.logits, dim=-1).item()
    result = "Yes" if answer == 1 else "No"
    return {"answer": result}
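
For reference, a minimal client sketch (not part of the commit): it assumes the app is served locally with uvicorn on port 7860, the usual Hugging Face Spaces port, and that the requests package is installed; the payload fields are purely illustrative.

import requests

# Hypothetical usage -- start the server first, e.g.:
#   uvicorn app:app --host 0.0.0.0 --port 7860
payload = {
    "question": "Does aspirin reduce the risk of stroke?",
    "context": "Several randomized trials have evaluated low-dose aspirin.",
    "long_answer": "Aspirin was associated with a lower risk of ischemic stroke.",
}
resp = requests.post("http://localhost:7860/chat", json=payload)
print(resp.json())  # e.g. {"answer": "Yes"}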