import gradio as gr
from ultralytics import YOLO
import numpy as np
import cv2
from PIL import Image
import random
from transformers import pipeline

# ---------------------------
# Load Models
# ---------------------------
# Text model (tiny GPT-2 test checkpoint from the Hub)
text_gen = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2")

# YOLOv8 segmentation (nano version for speed)
yolo_model = YOLO("yolov8n-seg.pt") # change to yolov8s-seg.pt for more accuracy
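# Note: both checkpoints are fetched on first run (ultralytics downloads the
# .pt weights automatically; transformers pulls the text model from the
# Hugging Face Hub), so the first launch needs network access.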

# ---------------------------
# Image Segmentation
# ---------------------------
def segment_image(image: Image.Image):
    image = image.convert("RGB")  # normalize RGBA/grayscale uploads
    frame = np.array(image)
    results = yolo_model.predict(frame)[0]
    overlay = frame.copy()
    annotations = []
    if results.masks is not None:
        # results.masks.xy holds one polygon (Nx2 array, image coordinates)
        # per detected instance
        for mask, cls in zip(results.masks.xy, results.boxes.cls):
            pts = np.array(mask, dtype=np.int32)
            color = [random.randint(0, 255) for _ in range(3)]
            cv2.fillPoly(overlay, [pts], color)
            annotations.append((mask.tolist(), yolo_model.names[int(cls)]))
    overlay_img = Image.fromarray(overlay)
    return overlay_img, annotations
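
# The loop above paints each instance with an opaque fill, which hides the
# pixels underneath. A minimal alternative sketch (not wired into the UI;
# the helper name is illustrative): alpha-blend the painted layer back onto
# the original frame with cv2.addWeighted so the scene stays visible.
def blend_overlay(base: np.ndarray, painted: np.ndarray, alpha: float = 0.4) -> np.ndarray:
    # result = (1 - alpha) * base + alpha * painted, per pixel
    return cv2.addWeighted(base, 1.0 - alpha, painted, alpha, 0.0)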

# ---------------------------
# Video Segmentation
# ---------------------------
def segment_video(video):
    cap = cv2.VideoCapture(video)
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    fps = cap.get(cv2.CAP_PROP_FPS) or 30  # some containers report 0 FPS
    out_path = "output.mp4"
    out = cv2.VideoWriter(out_path, fourcc, fps,
                          (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                           int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))))
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        results = yolo_model.predict(frame)[0]
        overlay = frame.copy()
        if results.masks is not None:
            for mask, cls in zip(results.masks.xy, results.boxes.cls):
                pts = np.array(mask, dtype=np.int32)
                color = [random.randint(0, 255) for _ in range(3)]
                cv2.fillPoly(overlay, [pts], color)
        out.write(overlay)
    cap.release()
    out.release()
    return out_path
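
# Caveat: the "mp4v" FourCC is easy to write but not always playable in
# browsers; if the result does not play back in the gr.Video component,
# re-encoding to H.264 (e.g. with ffmpeg) is a common workaround.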

# ---------------------------
# Text Generation
# ---------------------------
def generate_text(prompt):
    # max_new_tokens bounds the continuation length, independent of prompt size
    result = text_gen(prompt, max_new_tokens=100, num_return_sequences=1)
    return result[0]["generated_text"]
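
# The pipeline decodes greedily by default; passing sampling flags, e.g.
# text_gen(prompt, max_new_tokens=100, do_sample=True, temperature=0.9),
# gives more varied continuations.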

# ---------------------------
# Gradio UI
# ---------------------------
with gr.Blocks() as demo:
    gr.Markdown("# 🔥 Multi-Modal Playground\nTry out **Text + Image + Video Segmentation** in one app!")

    with gr.Tab("💬 Text Generation"):
        inp_text = gr.Textbox(label="Enter your prompt")
        out_text = gr.Textbox(label="Generated text")
        btn_text = gr.Button("Generate")
        btn_text.click(generate_text, inputs=inp_text, outputs=out_text)

    with gr.Tab("🖼️ Image Segmentation"):
        inp_img = gr.Image(type="pil", label="Upload Image")
        out_img = gr.Image(type="pil", label="Segmented Image")
        out_ann = gr.JSON(label="Annotations")
        btn_img = gr.Button("Run Segmentation")
        btn_img.click(segment_image, inputs=inp_img, outputs=[out_img, out_ann])

    with gr.Tab("🎥 Video Segmentation"):
        inp_vid = gr.Video(label="Upload Video")
        out_vid = gr.Video(label="Segmented Video")
        btn_vid = gr.Button("Run Segmentation")
        btn_vid.click(segment_video, inputs=inp_vid, outputs=out_vid)
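
# Video jobs can run for a while; enabling Gradio's request queue with
# demo.queue() before launch helps avoid request timeouts on long inputs.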
demo.launch(server_name="0.0.0.0", server_port=7860, share=False, ssr_mode=False)