|
from transformers import AutoFeatureExtractor, AutoModelForImageClassification |
|
from PIL import Image |
|
import torch |
|
import gradio as gr |
|
|
|
|
|
# First-stage classifier: predicts the coarse food category of an image.
# The preprocessor and the model are loaded from the same Hugging Face repo id.
_CATEGORY_REPO = "Kaludi/food-category-classification-v2.0"
cat_ex = AutoFeatureExtractor.from_pretrained(_CATEGORY_REPO)
cat_model = AutoModelForImageClassification.from_pretrained(_CATEGORY_REPO)
|
|
|
|
|
# Second-stage classifier: fine-grained fruit species. classify() only runs it
# when the category model's top prediction is "fruit".
_FRUIT_REPO = "walzsil1/vit-base-fruits-360"
fruit_ex = AutoFeatureExtractor.from_pretrained(_FRUIT_REPO)
fruit_model = AutoModelForImageClassification.from_pretrained(_FRUIT_REPO)
|
|
|
def _top_prediction(extractor, model, img):
    """Run one image classifier on *img* and return its top-1 ``(label, probability)``.

    The forward pass runs under ``torch.no_grad()`` because this is pure
    inference; no autograd graph needs to be recorded.
    """
    inputs = extractor(images=img, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    # A single image is submitted, so row 0 holds the only prediction.
    probs = torch.softmax(logits, dim=1)[0]
    idx = int(probs.argmax())
    return model.config.id2label[idx], float(probs[idx])


def classify(img):
    """Predict the food category of *img* and, for fruit, the fruit species.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without an image.

    Returns:
        ``None`` when no image was given. Otherwise a dict for ``gr.Label``
        mapping ``"category: <label>"`` to the category model's top-1
        probability. When that label is "fruit" (case-insensitive), the dict
        also maps ``"fruit: <label>"`` to the fruit model's top-1 probability.
    """
    if img is None:
        # Nothing to classify. Return None rather than an empty dict so the
        # Label output is simply cleared.
        return None

    # Uploads can be RGBA/palette/grayscale (e.g. PNGs). The image processors
    # presumably expect 3-channel RGB input, so normalize the mode first.
    img = img.convert("RGB")

    category, category_prob = _top_prediction(cat_ex, cat_model, img)
    out = {f"category: {category}": category_prob}

    if category.lower() == "fruit":
        species, species_prob = _top_prediction(fruit_ex, fruit_model, img)
        out[f"fruit: {species}"] = species_prob

    return out
|
|
|
# Web UI: one uploaded image in, up to three labelled confidences out.
_image_input = gr.Image(type="pil")
_label_output = gr.Label(num_top_classes=3)

demo = gr.Interface(
    fn=classify,
    inputs=_image_input,
    outputs=_label_output,
    title="Grocery Category + Fruit Species Recognizer",
    description="Classifies food category, then if it's fruit, fine-grained species.",
)

demo.launch()
|
|