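"""FastAPI service that predicts a vehicle's brand, model, and version from a
photo, using an OpenCLIP PE-Core-bigG image encoder and precomputed text
embeddings for the candidate labels."""
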
import os, io, traceback
from typing import Optional, List, Tuple

import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image, UnidentifiedImageError, ImageFile
from torchvision import transforms as T
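
# Let Pillow decode truncated/partially uploaded files instead of raising.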
ImageFile.LOAD_TRUNCATED_IMAGES = True
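
# Route every cache (XDG, Hugging Face, open_clip, torch hub) to one writable
# root. These variables must be set before open_clip / huggingface_hub are
# imported, which is why that import is deferred until below.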
CACHE_ROOT = os.environ.get("APP_CACHE", "/tmp/appcache")
os.environ["XDG_CACHE_HOME"] = CACHE_ROOT
os.environ["HF_HOME"] = os.path.join(CACHE_ROOT, "hf")
os.environ["HUGGINGFACE_HUB_CACHE"] = os.environ["HF_HOME"]
os.environ["TRANSFORMERS_CACHE"] = os.environ["HF_HOME"]
os.environ["OPENCLIP_CACHE_DIR"] = os.path.join(CACHE_ROOT, "open_clip")
os.environ["TORCH_HOME"] = os.path.join(CACHE_ROOT, "torch")
os.makedirs(os.environ["HF_HOME"], exist_ok=True)
os.makedirs(os.environ["OPENCLIP_CACHE_DIR"], exist_ok=True)
os.makedirs(os.environ["TORCH_HOME"], exist_ok=True)

import open_clip  # deferred so the cache env vars above take effect
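
# Keep CPU math single-threaded by default to avoid oversubscription across
# uvicorn workers. Note that the OMP/MKL variables generally only influence
# libraries initialized after this point; torch is already imported, so
# torch.set_num_threads() is the effective control here.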
NUM_THREADS = int(os.environ.get("NUM_THREADS", "1"))
torch.set_num_threads(NUM_THREADS)
os.environ["OMP_NUM_THREADS"] = str(NUM_THREADS)
os.environ["MKL_NUM_THREADS"] = str(NUM_THREADS)
try:
    torch.set_num_interop_threads(1)
except Exception:
    pass

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
if DEVICE == "cuda":
    torch.set_float32_matmul_precision("high")
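
# Precomputed text-embedding checkpoints: one for "<brand> <model>" labels,
# one for "<brand> <model> <version>" labels.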
MODEL_EMB_PATH = os.getenv("MODEL_EMB_PATH", "text_embeddings_modelos_bigg.pt")
VERS_EMB_PATH = os.getenv("VERS_EMB_PATH", "text_embeddings_bigg.pt")

MODEL_NAME = "hf-hub:timm/PE-Core-bigG-14-448"
PRETRAINED = None

app = FastAPI(title="OpenCLIP PE bigG Vehicle API")
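
# Recent open_clip releases return (model, train_preprocess, eval_preprocess);
# the fallback covers builds that return only (model, preprocess).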
_ret = open_clip.create_model_and_transforms(MODEL_NAME, pretrained=PRETRAINED)
if isinstance(_ret, tuple) and len(_ret) == 3:
    clip_model, _preprocess_train, preprocess = _ret
else:
    clip_model, preprocess = _ret

clip_model = clip_model.to(device=DEVICE, dtype=DTYPE).eval()
for p in clip_model.parameters():
    p.requires_grad = False
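
# Reuse the checkpoint's normalization stats and input size, but rely on the
# aspect-preserving letterbox below rather than the checkpoint's own resizing.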
normalize = next(t for t in getattr(preprocess, "transforms", []) if isinstance(t, T.Normalize))
SIZE = next((getattr(t, "size", None) for t in getattr(preprocess, "transforms", []) if hasattr(t, "size")), None)
if isinstance(SIZE, (tuple, list)):
    SIZE = max(SIZE)
if SIZE is None:
    SIZE = 448

transform = T.Compose([T.ToTensor(), T.Normalize(mean=normalize.mean, std=normalize.std)])


def resize_letterbox(img: Image.Image, size: int) -> Image.Image:
    """Fit the image into a size x size square and pad with black, preserving
    the original aspect ratio instead of distorting it."""
    if img.mode != "RGB":
        img = img.convert("RGB")
    w, h = img.size
    if w == 0 or h == 0:
        raise UnidentifiedImageError("invalid image")
    scale = size / max(w, h)
    nw, nh = max(1, int(w * scale)), max(1, int(h * scale))
    img_resized = img.resize((nw, nh), Image.BICUBIC)
    canvas = Image.new("RGB", (size, size), (0, 0, 0))
    canvas.paste(img_resized, ((size - nw) // 2, (size - nh) // 2))
    return canvas


def _ensure_label_list(x):
    if isinstance(x, (list, tuple)):
        return list(x)
    if hasattr(x, "tolist"):
        return [str(s) for s in x.tolist()]
    return [str(s) for s in x]


def _load_embeddings(path: str):
    # Each checkpoint is a dict with "labels" (a list of prompt strings) and
    # "embeddings" (an N x D tensor). Rows are L2-normalized so a dot product
    # with a normalized image feature is a cosine similarity.
    ckpt = torch.load(path, map_location="cpu")
    labels = _ensure_label_list(ckpt["labels"])
    embeds = ckpt["embeddings"].to("cpu")
    embeds = embeds / embeds.norm(dim=-1, keepdim=True)
    return labels, embeds


model_labels, model_embeddings = _load_embeddings(MODEL_EMB_PATH)
version_labels, version_embeddings = _load_embeddings(VERS_EMB_PATH)
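
# One dummy forward both warms the model up and yields the image-feature
# width, which must match the precomputed text embeddings.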
with torch.inference_mode():
    dummy = torch.zeros(1, 3, SIZE, SIZE, device=DEVICE, dtype=DTYPE)
    img_dim = clip_model.encode_image(dummy).shape[-1]
if model_embeddings.shape[1] != img_dim or version_embeddings.shape[1] != img_dim:
    raise RuntimeError(
        f"dimension mismatch: image={img_dim}, models={model_embeddings.shape[1]}, "
        f"versions={version_embeddings.shape[1]}. Re-generate the embeddings with {MODEL_NAME}."
    )
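
# A minimal sketch (not executed here) of how a compatible checkpoint could be
# produced; `prompts` and the output filename are illustrative placeholders:
#
#     tok = open_clip.get_tokenizer(MODEL_NAME)
#     with torch.inference_mode():
#         emb = clip_model.encode_text(tok(prompts).to(DEVICE))
#         emb = emb / emb.norm(dim=-1, keepdim=True)
#     torch.save({"labels": prompts, "embeddings": emb.cpu()}, "embeddings.pt")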

# Per-model cache of (labels, embeddings) slices so the linear scan over
# version_labels runs at most once per model string.
_versions_cache: dict[str, Tuple[List[str], Optional[torch.Tensor]]] = {}


def _get_versions_subset(modelo_full: str) -> Tuple[List[str], Optional[torch.Tensor]]:
    hit = _versions_cache.get(modelo_full)
    if hit is not None:
        return hit
    # Match "<model> " including the trailing space, so that e.g. "Toyota
    # Corolla" does not also capture "Toyota Corolla Cross" version labels.
    prefix = modelo_full + " "
    idxs = [i for i, lab in enumerate(version_labels) if lab.startswith(prefix)]
    if not idxs:
        _versions_cache[modelo_full] = ([], None)
        return _versions_cache[modelo_full]
    labels_sub = [version_labels[i] for i in idxs]
    embeds_sub = version_embeddings[idxs]
    _versions_cache[modelo_full] = (labels_sub, embeds_sub)
    return _versions_cache[modelo_full]


@torch.inference_mode()
def _encode_pil(img: Image.Image) -> torch.Tensor:
    img = resize_letterbox(img, SIZE)
    tensor = transform(img).unsqueeze(0).to(device=DEVICE)
    if DEVICE == "cuda":
        tensor = tensor.to(dtype=DTYPE)
    feats = clip_model.encode_image(tensor)
    return feats / feats.norm(dim=-1, keepdim=True)


def _topk_cosine(text_feats: torch.Tensor, text_labels: List[str], img_feat: torch.Tensor, k: int = 1):
    sim = (img_feat.float() @ text_feats.to(img_feat.device).float().T)[0]
    # Softmax over *all* candidates, scaled by the model's learned logit scale;
    # a softmax over only the top-k scores would always report 100% when k == 1,
    # leaving the downstream confidence threshold dead.
    probs = torch.softmax(sim * float(clip_model.logit_scale.exp()), dim=0)
    vals, idxs = torch.topk(probs, k=k)
    return [{"label": text_labels[int(i)], "confidence": round(float(v) * 100.0, 2)} for i, v in zip(idxs, vals)]


def process_image_bytes(front_bytes: bytes, back_bytes: Optional[bytes] = None):
    # Two-stage zero-shot classification: pick the best "<brand> <model>"
    # label first, then rank only that model's version labels. `back_bytes`
    # is accepted for API compatibility but is not used yet.
    if not front_bytes or len(front_bytes) < 128:
        raise UnidentifiedImageError("invalid image")

    img_front = Image.open(io.BytesIO(front_bytes))
    img_feat = _encode_pil(img_front)

    top_model = _topk_cosine(model_embeddings, model_labels, img_feat, k=1)[0]
    modelo_full = top_model["label"]

    # "<brand> <model>": the brand is the first token, the model is the rest.
    partes = modelo_full.split(" ", 1)
    marca = partes[0] if len(partes) >= 1 else ""
    modelo = partes[1] if len(partes) == 2 else ""

    labels_sub, embeds_sub = _get_versions_subset(modelo_full)
    if not labels_sub:
        return {"brand": marca.upper(), "model": modelo.title(), "version": ""}

    # Strip the "<brand> <model> " prefix from the winning version label, keep
    # only its first token, and drop low-confidence guesses entirely.
    top_ver = _topk_cosine(embeds_sub, labels_sub, img_feat, k=1)[0]
    raw = top_ver["label"]
    prefix = modelo_full + " "
    ver = raw[len(prefix):] if raw.startswith(prefix) else raw
    ver = ver.split(" ")[0]
    if top_ver["confidence"] < 30.0:
        ver = ""

    return {"brand": marca.upper(), "model": modelo.title(), "version": ver.title() if ver else ""}


@app.get("/")
def root():
    return {"status": "ok", "device": DEVICE, "model": MODEL_NAME, "img_dim": int(model_embeddings.shape[1]), "threads": NUM_THREADS}


@app.post("/predict/")
async def predict(front: Optional[UploadFile] = File(None), back: Optional[UploadFile] = File(None)):
    # Errors are reported through the application-level "code" field; the HTTP
    # status itself is always 200.
    try:
        if front is None:
            return JSONResponse(content={"code": 400, "error": "missing files: 'front' is required"}, status_code=200)
        front_bytes = await front.read()
        back_bytes = await back.read() if back is not None else None
        vehicle = process_image_bytes(front_bytes, back_bytes)
        return JSONResponse(content={"code": 200, "data": {"vehicle": vehicle}}, status_code=200)
    except (UnidentifiedImageError, OSError, RuntimeError, ValueError) as e:
        return JSONResponse(content={"code": 404, "data": {}, "error": str(e)}, status_code=200)
    except Exception:
        traceback.print_exc()
        return JSONResponse(content={"code": 404, "data": {}, "error": "internal"}, status_code=200)
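

# Example request (illustrative):
#
#     curl -F "front=@car.jpg" http://localhost:8000/predict/
#
# Local development entry point -- a minimal sketch assuming `uvicorn` is
# installed; in production the app is typically served as `uvicorn <module>:app`.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8000")))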