iia / server1.py
addgbf's picture
Update server1.py
58ccff7 verified
# app.py
import os, io, traceback
from typing import Optional, List, Tuple
import torch
from fastapi import FastAPI, File, UploadFile, Request
from fastapi.responses import JSONResponse
from PIL import Image, UnidentifiedImageError, ImageFile
from torchvision import transforms as T
from functools import lru_cache
# Pillow flag: load truncated image files instead of raising on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Root directory for every download cache; override with APP_CACHE.
CACHE_ROOT = os.environ.get("APP_CACHE", "/tmp/appcache")

# Point the cache-related environment variables below CACHE_ROOT. This runs
# before open_clip is imported (see the import that follows this section).
os.environ.update(
    {
        "XDG_CACHE_HOME": CACHE_ROOT,
        "HF_HOME": os.path.join(CACHE_ROOT, "hf"),
        "HUGGINGFACE_HUB_CACHE": os.path.join(CACHE_ROOT, "hf"),
        "TRANSFORMERS_CACHE": os.path.join(CACHE_ROOT, "hf"),
        "OPENCLIP_CACHE_DIR": os.path.join(CACHE_ROOT, "open_clip"),
        "TORCH_HOME": os.path.join(CACHE_ROOT, "torch"),
    }
)

# Create the directories up front so they exist before anything is downloaded.
for _cache_var in ("HF_HOME", "OPENCLIP_CACHE_DIR", "TORCH_HOME"):
    os.makedirs(os.environ[_cache_var], exist_ok=True)
import open_clip # importar despues de ajustar caches
# ===== limites basicos =====
NUM_THREADS = int(os.environ.get("NUM_THREADS", "1"))
torch.set_num_threads(NUM_THREADS)
os.environ["OMP_NUM_THREADS"] = str(NUM_THREADS)
os.environ["MKL_NUM_THREADS"] = str(NUM_THREADS)
try:
torch.set_num_interop_threads(1)
except Exception:
pass
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
if DEVICE == "cuda":
torch.set_float32_matmul_precision("high")
# ===== embedding paths =====
# Text-embedding checkpoints; each path can be overridden through the
# environment variable of the same name.
MODEL_EMB_PATH = os.environ.get("MODEL_EMB_PATH", "text_embeddings_modelos_bigg.pt")
VERS_EMB_PATH = os.environ.get("VERS_EMB_PATH", "text_embeddings_bigg.pt")

# ===== PE bigG model =====
MODEL_NAME = "hf-hub:timm/PE-Core-bigG-14-448"
PRETRAINED = None

app = FastAPI(title="OpenCLIP PE bigG Vehicle API")
# ===== model / preprocess =====
_ret = open_clip.create_model_and_transforms(MODEL_NAME, pretrained=PRETRAINED)
# Some open_clip versions return (model, preprocess_train, preprocess_val);
# anything else is unpacked as a (model, preprocess) pair.
if not (isinstance(_ret, tuple) and len(_ret) == 3):
    clip_model, preprocess = _ret
else:
    clip_model, _preprocess_train, preprocess = _ret

# Inference only: move to the target device/dtype, switch to eval mode and
# freeze every parameter.
clip_model = clip_model.to(device=DEVICE, dtype=DTYPE).eval()
for _param in clip_model.parameters():
    _param.requires_grad = False

# Reuse the normalisation statistics and input size of the stock pipeline.
_steps = getattr(preprocess, "transforms", [])
normalize = next(t for t in _steps if isinstance(t, T.Normalize))
SIZE = next((t.size for t in _steps if hasattr(t, "size")), None)
if isinstance(SIZE, (tuple, list)):
    SIZE = max(SIZE)
elif SIZE is None:
    SIZE = 448  # PE bigG is 448; fallback

# Resizing is done by resize_letterbox, so only tensor conversion and
# normalisation are kept here.
transform = T.Compose(
    [
        T.ToTensor(),
        T.Normalize(mean=normalize.mean, std=normalize.std),
    ]
)
# ===== utils imagen (sin cambios: letterbox + BICUBIC) =====
def resize_letterbox(img: Image.Image, size: int) -> Image.Image:
    """Fit ``img`` inside a ``size`` x ``size`` black square, keeping its aspect ratio.

    The image is converted to RGB if needed, scaled with bicubic resampling so
    that its longer side is at most ``size`` pixels, and pasted centred on a
    black canvas.

    Raises:
        UnidentifiedImageError: if the image has a zero width or height.
    """
    rgb = img if img.mode == "RGB" else img.convert("RGB")
    width, height = rgb.size
    if 0 in (width, height):
        raise UnidentifiedImageError("imagen invalida")

    factor = size / max(width, height)
    # Truncate towards zero but never let a side collapse to 0 pixels.
    new_w = max(1, int(width * factor))
    new_h = max(1, int(height * factor))
    scaled = rgb.resize((new_w, new_h), Image.BICUBIC)

    canvas = Image.new("RGB", (size, size), (0, 0, 0))
    offset = ((size - new_w) // 2, (size - new_h) // 2)
    canvas.paste(scaled, offset)
    return canvas
# ===== cargar embeddings (sin cambios) =====
def _ensure_label_list(x):
    """Return ``x`` as a plain ``list`` of ``str`` labels.

    Accepts a single string, a list/tuple, any object exposing ``tolist()``
    (array-like containers), or a generic iterable. Every element is passed
    through ``str`` so callers can rely on string methods such as
    ``startswith`` regardless of how the labels were stored.
    """
    # A lone string is one label, not a sequence of characters.
    if isinstance(x, str):
        return [x]
    if not isinstance(x, (list, tuple)) and hasattr(x, "tolist"):
        x = x.tolist()
        # tolist() may hand back a scalar rather than a list.
        if not isinstance(x, (list, tuple)):
            return [str(x)]
    return [str(item) for item in x]
def _load_embeddings(path: str):
    """Load a text-embedding checkpoint and return ``(labels, embeddings)``.

    The checkpoint is read as a mapping with a ``"labels"`` entry and an
    ``"embeddings"`` tensor. Labels are normalised to a list via
    ``_ensure_label_list``; embeddings are moved to the CPU and scaled to unit
    L2 norm along the last dimension.
    """
    # NOTE(review): torch.load unpickles the file, so `path` must point to a
    # trusted checkpoint. Consider weights_only=True once the stored label
    # type is confirmed to be compatible with it.
    ckpt = torch.load(path, map_location="cpu")
    labels = _ensure_label_list(ckpt["labels"])
    embeds = ckpt["embeddings"].to("cpu")
    norms = embeds.norm(dim=-1, keepdim=True)
    # Guard against all-zero rows: dividing by a zero norm would produce NaN
    # and corrupt every later similarity ranking.
    norms = norms.clamp_min(torch.finfo(norms.dtype).tiny)
    return labels, embeds / norms
model_labels, model_embeddings = _load_embeddings(MODEL_EMB_PATH)
version_labels, version_embeddings = _load_embeddings(VERS_EMB_PATH)

# Check the embedding dimension by encoding an all-zero dummy image
# (original note: PE bigG keeps 1280).
with torch.inference_mode():
    dummy = torch.zeros(1, 3, SIZE, SIZE, device=DEVICE, dtype=DTYPE)
    img_dim = clip_model.encode_image(dummy).shape[-1]

# Fail at start-up if the stored text embeddings do not match the image
# encoder's output width.
if not (model_embeddings.shape[1] == img_dim and version_embeddings.shape[1] == img_dim):
    raise RuntimeError(
        f"dimension mismatch: image={img_dim}, modelos={model_embeddings.shape[1]}, "
        f"versiones={version_embeddings.shape[1]}. Recalcula embeddings con {MODEL_NAME}."
    )
# Memo of version subsets keyed by the full model label. Entries are
# (labels, embeddings); embeddings is None when no version matches.
_versions_cache: dict[str, Tuple[List[str], Optional[torch.Tensor]]] = {}


def _get_versions_subset(modelo_full: str) -> Tuple[List[str], Optional[torch.Tensor]]:
    """Return the version labels and embeddings that belong to ``modelo_full``.

    A version label belongs to the model when it equals ``modelo_full`` or
    starts with ``modelo_full`` followed by a space. The space boundary keeps
    models that merely share a textual prefix apart (e.g. "Audi A1" must not
    pick up "Audi A10 ..." entries).

    Results are memoised in ``_versions_cache``. When nothing matches, the
    result is ``([], None)``.
    """
    hit = _versions_cache.get(modelo_full)
    if hit is not None:
        return hit

    prefix = modelo_full + " "
    idxs = [
        i
        for i, lab in enumerate(version_labels)
        if lab == modelo_full or lab.startswith(prefix)
    ]
    if idxs:
        # Indexing with a list of row indices copies those rows.
        entry = ([version_labels[i] for i in idxs], version_embeddings[idxs])
    else:
        entry = ([], None)
    _versions_cache[modelo_full] = entry
    return entry
# ===== inferencia (sin cambios de logica/precision) =====
@torch.inference_mode()
def _encode_pil(img: Image.Image) -> torch.Tensor:
    """Encode a PIL image into a unit-norm feature tensor.

    The image is letterboxed to ``SIZE``, converted and normalised with
    ``transform``, given a leading batch dimension, and passed through
    ``clip_model.encode_image``. The result is divided by its L2 norm along
    the last dimension.
    """
    boxed = resize_letterbox(img, SIZE)
    batch = transform(boxed).unsqueeze(0).to(device=DEVICE)
    if DEVICE == "cuda":
        # The model runs in DTYPE on GPU, so the input is cast to match.
        batch = batch.to(dtype=DTYPE)
    embedding = clip_model.encode_image(batch)
    length = embedding.norm(dim=-1, keepdim=True)
    return embedding / length
def _topk_cosine(text_feats: torch.Tensor, text_labels: List[str], img_feat: torch.Tensor, k: int = 1):
    """Rank text labels by similarity to the first row of ``img_feat``.

    Args:
        text_feats: one embedding row per label.
        text_labels: labels aligned with the rows of ``text_feats``.
        img_feat: image features; only the first row is scored.
        k: number of results wanted. When labels exist but fewer than ``k``
            are available, all of them are returned instead of failing.

    Returns:
        A list of ``{"label", "confidence"}`` dicts, best match first.
        ``confidence`` is a percentage rounded to two decimals.

    NOTE(review): the softmax is taken over the selected top-k similarities
    only, so the confidences always sum to ~100 and ``k=1`` always yields
    100.0.
    """
    sim = (img_feat.float() @ text_feats.to(img_feat.device).float().T)[0]
    available = sim.shape[0]
    # Clamp only when there is something to return; with no labels at all,
    # torch.topk still raises as before.
    if 0 < available < k:
        k = available
    vals, idxs = torch.topk(sim, k=k)
    conf = torch.softmax(vals, dim=0)
    return [
        {"label": text_labels[i], "confidence": round(c * 100.0, 2)}
        for i, c in zip(idxs.tolist(), conf.tolist())
    ]
def process_image_bytes(front_bytes: bytes, back_bytes: Optional[bytes] = None):
    """Classify a vehicle photo and return its brand, model and version.

    Args:
        front_bytes: encoded image of the vehicle; payloads that are empty or
            shorter than 128 bytes are rejected.
        back_bytes: accepted for interface compatibility; not used here.

    Returns:
        ``{"brand", "model", "version"}`` where brand is upper-cased, model
        and version are title-cased, and version is ``""`` when no version
        label matches the model or the confidence is below 30.

    Raises:
        UnidentifiedImageError: if ``front_bytes`` is empty or too short.
    """
    if not front_bytes or len(front_bytes) < 128:
        raise UnidentifiedImageError("imagen invalida")
    img_feat = _encode_pil(Image.open(io.BytesIO(front_bytes)))

    # Step 1: best model label; its first word is the brand, the rest the model.
    modelo_full = _topk_cosine(model_embeddings, model_labels, img_feat, k=1)[0]["label"]
    marca, _, modelo = modelo_full.partition(" ")
    result = {"brand": marca.upper(), "model": modelo.title(), "version": ""}

    # Step 2: candidate versions for that model (memoised lookup).
    labels_sub, embeds_sub = _get_versions_subset(modelo_full)
    if not labels_sub:
        return result

    # Step 3: best version among the candidates.
    best_version = _topk_cosine(embeds_sub, labels_sub, img_feat, k=1)[0]
    # NOTE(review): with k=1 the confidence is a softmax over a single value,
    # i.e. always 100.0, so this threshold looks unreachable - confirm intent.
    if best_version["confidence"] < 30.0:
        return result

    # Drop the "<brand model> " prefix and keep the first remaining word.
    ver = best_version["label"].removeprefix(modelo_full + " ").split(" ")[0]
    result["version"] = ver.title()
    return result
# ===== endpoints =====
@app.get("/")
def root():
    """Health endpoint: report status, device, model name, embedding width and thread count."""
    info = {
        "status": "ok",
        "device": DEVICE,
        "model": MODEL_NAME,
        "img_dim": int(model_embeddings.shape[1]),
        "threads": NUM_THREADS,
    }
    return info
@app.post("/predict/")
async def predict(front: UploadFile = File(None), back: Optional[UploadFile] = File(None), request: Request = None):
    """Classify the uploaded vehicle photo.

    The HTTP status is always 200; the outcome is carried by the ``code``
    field of the JSON body: 200 on success, 400 when ``front`` is missing,
    404 on any failure. ``back`` is read and forwarded, although
    ``process_image_bytes`` does not use it; ``request`` is unused.
    """

    def _reply(payload):
        # Every response, including errors, goes out with HTTP 200.
        return JSONResponse(content=payload, status_code=200)

    try:
        if front is None:
            return _reply({"code": 400, "error": "faltan archivos: 'front' es obligatorio"})
        front_bytes = await front.read()
        back_bytes = None if back is None else await back.read()
        vehicle = process_image_bytes(front_bytes, back_bytes)
        return _reply({"code": 200, "data": {"vehicle": vehicle}})
    except (UnidentifiedImageError, OSError, RuntimeError, ValueError) as e:
        # Expected failures: bad image data or model/runtime errors.
        return _reply({"code": 404, "data": {}, "error": str(e)})
    except Exception:
        # Anything else is logged to stderr and reported generically.
        traceback.print_exc()
        return _reply({"code": 404, "data": {}, "error": "internal"})