# Source: Hugging Face Space "asd324r53", file app.py
# Last change: "Update app.py" by SAD43W (commit 2f25a59, verified)
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
from PIL import Image
import torch
import gradio as gr
# Hugging Face Hub checkpoints for the two-stage pipeline.
CATEGORY_MODEL_ID = "Kaludi/food-category-classification-v2.0"
FRUIT_MODEL_ID = "walzsil1/vit-base-fruits-360"


def _load_classifier(model_id):
    """Return the ``(feature extractor, image-classification model)`` pair for *model_id*."""
    extractor = AutoFeatureExtractor.from_pretrained(model_id)
    model = AutoModelForImageClassification.from_pretrained(model_id)
    return extractor, model


# Stage 1: coarse food-category classifier (its labels include a "Fruit" class).
cat_ex, cat_model = _load_classifier(CATEGORY_MODEL_ID)
# Stage 2: fine-grained fruit-species classifier.
fruit_ex, fruit_model = _load_classifier(FRUIT_MODEL_ID)
def _top_prediction(extractor, model, img):
    """Run one image-classification model on *img* and return its best class.

    Returns:
        A ``(label, probability)`` tuple: ``label`` is looked up in
        ``model.config.id2label`` and ``probability`` is the softmax score of
        that class as a Python float.
    """
    inputs = extractor(images=img, return_tensors="pt")
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits
        # A single image is processed, so row 0 holds its class scores.
        probs = torch.softmax(logits, dim=1)[0]
    idx = int(probs.argmax())
    return model.config.id2label[idx], float(probs[idx])


def classify(img):
    """Classify a food image's category, adding a fruit-species guess for fruit.

    Args:
        img: The input image (a PIL image when called from the Gradio UI).
            ``None`` (what Gradio passes when no image is supplied) yields an
            empty dict instead of raising.

    Returns:
        A dict of label -> confidence for ``gr.Label``:
        ``{"category: <label>": prob}`` from the category model, plus
        ``{"fruit: <species>": prob}`` from the fruit model when the predicted
        category is "fruit" (case-insensitive).
    """
    if img is None:
        return {}
    # NOTE(review): the processors presumably expect 3-channel input; forcing
    # RGB guards against RGBA / greyscale images passed in directly.
    if isinstance(img, Image.Image) and img.mode != "RGB":
        img = img.convert("RGB")

    category, category_prob = _top_prediction(cat_ex, cat_model, img)
    out = {f"category: {category}": category_prob}
    if category.lower() == "fruit":
        species, species_prob = _top_prediction(fruit_ex, fruit_model, img)
        out[f"fruit: {species}"] = species_prob
    return out
# Web UI: one image in, a label/confidence panel out. classify() returns at
# most two entries, so num_top_classes=3 always shows all of them.
demo = gr.Interface(
    fn=classify,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=3),
    title="Grocery Category + Fruit Species Recognizer",
    description="Classifies food category, then if it's fruit, fine-grained species.",
)

# Start the (blocking) server only when run as a script, e.g. `python app.py`,
# so importing this module does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch()