import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import asyncio
import httpx
from typing import List
import re
model_name = "facebook/nllb-200-distilled-1.3B"
try:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, dtype="auto")
except Exception as e:
raise RuntimeError(f"Failed to load model or tokenizer: {str(e)}")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
app = FastAPI(title="Translation API")
class TranslationRequest(BaseModel):
text: str
src_lang: str = "eng_Latn"
tgt_lang: str = "spa_Latn"
class TranslationBatchRequest(BaseModel):
texts: List[str]
src_lang: str = "eng_Latn"
tgt_lang: str = "spa_Latn"
MAX_LEN = 1024  # Per-chunk token budget; NLLB-200's maximum sequence length is 1024 tokens
def split_long_text(text: str, max_len: int):
    """Split text into sentence groups whose token counts stay within max_len."""
    # Split on sentence-ending punctuation (. ? !), keeping the delimiters.
    parts = re.split(r'([.?!])', text)
    sentences = ["".join([parts[i], parts[i + 1]]).strip() for i in range(0, len(parts) - 1, 2)]
    if len(parts) % 2 != 0:
        sentences.append(parts[-1].strip())
groups, current = [], ""
for s in sentences:
if not s:
continue
test = (current + " " + s).strip()
if len(tokenizer(test)["input_ids"]) > max_len:
if current:
groups.append(current)
current = s
else:
current = test
if current:
groups.append(current)
return groups
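
# Illustrative behavior (a sketch, not executed anywhere in this file): sentences
# that fit within max_len are packed greedily into a single group, while longer
# inputs are spread across several groups, e.g.
#   split_long_text("One. Two! Three?", MAX_LEN) -> ["One. Two! Three?"]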
@app.post("/translate")
async def translate_text(req: TranslationRequest):
    """Translate a single text, splitting it into chunks when it exceeds MAX_LEN tokens."""
    try:
supported_langs = tokenizer.additional_special_tokens
if req.src_lang not in supported_langs or req.tgt_lang not in supported_langs:
raise HTTPException(
status_code=400,
detail=f"Unsupported language code. Supported languages: {supported_langs}"
)
tokenizer.src_lang = req.src_lang
input_ids = tokenizer(req.text)["input_ids"]
# If short enough, just translate once
if len(input_ids) <= MAX_LEN:
            encoded = tokenizer(req.text, return_tensors="pt", truncation=True, max_length=MAX_LEN).to(device)
generated = model.generate(
**encoded,
forced_bos_token_id=tokenizer.convert_tokens_to_ids(req.tgt_lang),
max_length=MAX_LEN
)
return {"translation": tokenizer.batch_decode(generated, skip_special_tokens=True)[0]}
# Else split into manageable groups
groups = split_long_text(req.text, MAX_LEN)
results = []
for group in groups:
            encoded = tokenizer(group, return_tensors="pt", truncation=True, max_length=MAX_LEN).to(device)
generated = model.generate(
**encoded,
forced_bos_token_id=tokenizer.convert_tokens_to_ids(req.tgt_lang),
max_length=MAX_LEN
)
results.append(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])
return {"translation": " ".join(results)}
    except HTTPException:
        # Re-raise deliberate 4xx errors instead of converting them to 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")
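
# Example request, assuming the server runs on localhost:8000 (illustrative only):
#   curl -X POST http://localhost:8000/translate \
#     -H "Content-Type: application/json" \
#     -d '{"text": "Hello, world!", "src_lang": "eng_Latn", "tgt_lang": "spa_Latn"}'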
@app.post("/translate_batch")
async def translate_batch_texts(req: TranslationBatchRequest):
    """Translate a list of texts one at a time, preserving input order."""
    try:
supported_langs = tokenizer.additional_special_tokens
if req.src_lang not in supported_langs or req.tgt_lang not in supported_langs:
raise HTTPException(status_code=400, detail=f"Unsupported language code. Supported languages: {supported_langs}")
if not req.texts:
raise HTTPException(status_code=400, detail="Text list cannot be empty")
        if len(req.texts) > 300:  # Limit batch size to prevent memory issues
            raise HTTPException(status_code=400, detail="Batch size cannot exceed 300 texts")
tokenizer.src_lang = req.src_lang
translations = []
for text in req.texts:
if not text.strip(): # Handle empty strings
translations.append("")
continue
            encoded_input = tokenizer(text, return_tensors="pt", truncation=True, max_length=MAX_LEN)
            encoded_input = {k: v.to(model.device) for k, v in encoded_input.items()}
            generated_tokens = model.generate(
                **encoded_input,
                forced_bos_token_id=tokenizer.convert_tokens_to_ids(req.tgt_lang),
                max_length=MAX_LEN  # Match the single-text endpoint's output budget
            )
translated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
translations.append(translated_text)
return {"translations": translations}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Batch translation failed: {str(e)}")
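
# Example batch request, assuming the server runs on localhost:8000 (illustrative only):
#   curl -X POST http://localhost:8000/translate_batch \
#     -H "Content-Type: application/json" \
#     -d '{"texts": ["Good morning.", "See you soon."], "src_lang": "eng_Latn", "tgt_lang": "spa_Latn"}'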
@app.get("/health")
async def health_check():
return {"status": "ok"}
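
# The loop below pings this Space's public URL on an interval, presumably to keep
# the Hugging Face Space from being put to sleep after a period of inactivity.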
async def periodic_health_check():
url = "https://huggingface.co/spaces/saeidseyfi/hf_translator"
async with httpx.AsyncClient() as client:
while True:
try:
response = await client.get(url)
if response.status_code != 200:
print(f"Health check failed with status {response.status_code} | {response.text}")
except Exception as e:
print(f"Health check error: {str(e)}")
            await asyncio.sleep(300)  # Check every 5 minutes
@app.on_event("startup")
async def startup_event():
asyncio.create_task(periodic_health_check())
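
# To run the server locally (assuming this file is saved as app.py):
#   uvicorn app:app --host 0.0.0.0 --port 8000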