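"""FastAPI translation service backed by Meta's NLLB-200 (distilled, 1.3B) model.

Exposes /translate for single texts, /translate_batch for lists of texts, and
/health for liveness checks. Inputs longer than the model's token budget are
split into sentence groups before translation, and a background task
periodically pings the hosting Space.
"""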
import asyncio
import re
from typing import List

import httpx
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "facebook/nllb-200-distilled-1.3B"
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype="auto")
except Exception as e:
    raise RuntimeError(f"Failed to load model or tokenizer: {str(e)}")

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

app = FastAPI(title="Translation API")

class TranslationRequest(BaseModel):
    text: str
    src_lang: str = "eng_Latn"  # NLLB language code of the input text
    tgt_lang: str = "spa_Latn"  # NLLB language code to translate into

class TranslationBatchRequest(BaseModel):
    texts: List[str]
    src_lang: str = "eng_Latn"  # NLLB language code of the input texts
    tgt_lang: str = "spa_Latn"  # NLLB language code to translate into


MAX_LEN = 1024  # Token budget shared by model input and generated output

def split_long_text(text: str, max_len: int):
    """Split text into sentence groups whose tokenized length stays under max_len."""
    # Split on sentence-ending punctuation, keeping the delimiters via the capture group.
    parts = re.split(r'([.?!])', text)
    sentences = ["".join([parts[i], parts[i+1]]).strip() for i in range(0, len(parts)-1, 2)]
    if len(parts) % 2 != 0:
        # A trailing fragment without closing punctuation is kept as its own sentence.
        sentences.append(parts[-1].strip())

    # Greedily pack consecutive sentences into groups that fit the token budget.
    groups, current = [], ""
    for s in sentences:
        if not s:
            continue
        test = (current + " " + s).strip()
        if len(tokenizer(test)["input_ids"]) > max_len:
            if current:
                groups.append(current)
            current = s
        else:
            current = test
    if current:
        groups.append(current)
    return groups
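# Quick illustration (hypothetical input, not part of the API):
#   split_long_text("First sentence. Second one!", MAX_LEN)
#   -> ["First sentence. Second one!"]  (sentences stay grouped while under the budget)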

@app.post("/translate")
async def translate_text(req: TranslationRequest):
    try:
        supported_langs = tokenizer.additional_special_tokens
        if req.src_lang not in supported_langs or req.tgt_lang not in supported_langs:
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported language code. Supported languages: {supported_langs}"
            )

        tokenizer.src_lang = req.src_lang
        input_ids = tokenizer(req.text)["input_ids"]

        # If short enough, just translate once
        if len(input_ids) <= MAX_LEN:
            encoded = tokenizer(req.text, return_tensors="pt", truncation=True, max_length=MAX_LEN).to(device)
            generated = model.generate(
                **encoded,
                forced_bos_token_id=tokenizer.convert_tokens_to_ids(req.tgt_lang),
                max_length=MAX_LEN
            )
            return {"translation": tokenizer.batch_decode(generated, skip_special_tokens=True)[0]}

        # Else split into manageable groups
        groups = split_long_text(req.text, MAX_LEN)
        results = []
        for group in groups:
            encoded = tokenizer(group, return_tensors="pt", truncation=True, max_length=MAX_LEN).to(device)
            generated = model.generate(
                **encoded,
                forced_bos_token_id=tokenizer.convert_tokens_to_ids(req.tgt_lang),
                max_length=MAX_LEN
            )
            results.append(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])

        return {"translation": " ".join(results)}

    except HTTPException:
        # Let deliberate 4xx errors pass through instead of becoming 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")


@app.post("/translate_batch")
async def translate_batch_texts(req: TranslationBatchRequest):
    try:
        supported_langs = tokenizer.additional_special_tokens
        if req.src_lang not in supported_langs or req.tgt_lang not in supported_langs:
            raise HTTPException(status_code=400, detail=f"Unsupported language code. Supported languages: {supported_langs}")

        if not req.texts:
            raise HTTPException(status_code=400, detail="Text list cannot be empty")
        
        if len(req.texts) > 300:  # Limit batch size to prevent memory issues
            raise HTTPException(status_code=400, detail="Batch size cannot exceed 300 texts")

        tokenizer.src_lang = req.src_lang
        
        translations = []
        
        for text in req.texts:
            if not text.strip():  # Handle empty strings
                translations.append("")
                continue
                
            encoded_input = tokenizer(text, return_tensors="pt", truncation=True, max_length=MAX_LEN)
            encoded_input = {k: v.to(model.device) for k, v in encoded_input.items()}

            generated_tokens = model.generate(
                **encoded_input,
                forced_bos_token_id=tokenizer.convert_tokens_to_ids(req.tgt_lang),
                max_length=MAX_LEN
            )

            translated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
            translations.append(translated_text)

        return {"translations": translations}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Batch translation failed: {str(e)}")

@app.get("/health")
async def health_check():
    return {"status": "ok"}

async def periodic_health_check():
    # Periodically pings the Space page, presumably to keep it from going idle.
    url = "https://huggingface.co/spaces/saeidseyfi/hf_translator"
    async with httpx.AsyncClient() as client:
        while True:
            try:
                response = await client.get(url)
                if response.status_code != 200:
                    print(f"Health check failed with status {response.status_code} | {response.text}")
            except Exception as e:
                print(f"Health check error: {str(e)}")
            await asyncio.sleep(300)  # Ping every 5 minutes

@app.on_event("startup")
async def startup_event():
    asyncio.create_task(periodic_health_check())
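
# --- Usage sketch (assumes this file is saved as app.py; the port follows the
# --- Hugging Face Spaces convention and may differ locally) ---
#
#   uvicorn app:app --host 0.0.0.0 --port 7860
#
# Single translation:
#   curl -X POST http://localhost:7860/translate \
#     -H "Content-Type: application/json" \
#     -d '{"text": "Hello, world!", "src_lang": "eng_Latn", "tgt_lang": "spa_Latn"}'
#
# Batch translation:
#   curl -X POST http://localhost:7860/translate_batch \
#     -H "Content-Type: application/json" \
#     -d '{"texts": ["Good morning.", "How are you?"], "tgt_lang": "spa_Latn"}'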