# Hugging Face Spaces status: Sleeping
import os

# Point the Hugging Face cache at a writable location. These variables are set
# before `transformers` is imported so the library picks them up when it
# resolves its cache paths.
os.environ["HF_HOME"] = "/tmp/hf_cache"
# Legacy cache variable; newer transformers releases prefer HF_HOME. It is
# kept so older versions resolve to the same directory.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"

import io

import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse, HTMLResponse
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image

# Hugging Face Hub id of the image-classification checkpoint served below.
MODEL_ID = "prithivMLmods/Realistic-Gender-Classification"

# Load model and processor once at import time; every request reuses them.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)

# FastAPI app
app = FastAPI()
@app.get("/", response_class=HTMLResponse)
async def home():
    """Serve a minimal HTML page with an image-upload form.

    The form POSTs the chosen file as ``multipart/form-data`` to ``/predict``
    under the form field name ``file``.

    Returns:
        The page markup as a string. ``response_class=HTMLResponse`` makes
        FastAPI send it as ``text/html`` rather than JSON-encoding it.
    """
    return '''
<html>
<body>
<h2>Upload an Image for Gender Detection</h2>
<form action="/predict" enctype="multipart/form-data" method="post">
<input name="file" type="file" accept="image/*">
<input type="submit" value="Upload">
</form>
</body>
</html>
'''
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return per-label probabilities.

    Args:
        file: Multipart upload (form field ``file``) holding the image bytes.

    Returns:
        On success, a ``JSONResponse`` that maps each label name from
        ``model.config.id2label`` to its softmax probability as a float.
        If the upload cannot be decoded as an image, a 400 ``JSONResponse``
        with an ``"error"`` message instead of an unhandled server error.
    """
    data = await file.read()
    try:
        # PIL raises UnidentifiedImageError (an OSError subclass) for empty or
        # non-image uploads, and OSError for truncated image data.
        image = Image.open(io.BytesIO(data)).convert("RGB")
    except OSError:
        return JSONResponse(
            status_code=400,
            content={"error": "Uploaded file is not a valid image."},
        )

    inputs = processor(images=image, return_tensors="pt")
    # NOTE(review): model inference is synchronous inside an async handler, so
    # it blocks the event loop while it runs; consider offloading it to a
    # thread pool if concurrent requests matter.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Softmax over the last (label) axis; a single image was submitted, so
    # row 0 holds its scores.
    probs = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()[0]
    labels = model.config.id2label
    # Pair each score with the label at the same logit index.
    result = {labels[i]: float(p) for i, p in enumerate(probs)}
    return JSONResponse(content=result)