# Hugging Face Spaces app — hardware: ZeroGPU ("Running on Zero").
import base64
import functools
import io
import os
import sys
import traceback
import warnings
from typing import Optional

# Keep `spaces` ahead of torch: ZeroGPU expects it to be imported before CUDA is initialised.
import spaces
import gradio as gr
import numpy as np
import torch
from huggingface_hub import snapshot_download
from PIL import Image
# Silence noisy third-party warnings: every FutureWarning, plus any warning whose
# message mentions `_supports_sdpa` (the transformers issue simple_patch_transformers
# works around). Filters are installed in the same order as two direct calls would.
for _warning_filter in (
    {"category": FutureWarning},
    {"message": ".*_supports_sdpa.*"},
):
    warnings.filterwarnings("ignore", **_warning_filter)
del _warning_filter
# Simple monkey patch for transformers - avoid recursion.
# Makes the attention-implementation check tolerant of models that have no
# usable `_supports_sdpa` attribute.

# Attribute set on the wrapper so repeated patching is detected.
_SDPA_PATCH_MARKER = "_omniparser_sdpa_patched"


def simple_patch_transformers(target_cls=None):
    """Wrap ``_check_and_adjust_attn_implementation`` to survive ``_supports_sdpa`` errors.

    The wrapper first tries to give the instance ``_supports_sdpa = False`` when
    the attribute is missing, then calls the original method. If the original
    raises an ``AttributeError`` whose message mentions ``_supports_sdpa``, the
    wrapper returns ``"eager"`` instead; any other ``AttributeError`` propagates.

    Args:
        target_cls: Class whose method is patched. Defaults to
            ``transformers.modeling_utils.PreTrainedModel``.

    Returns:
        True if the class is patched (by this call or an earlier one), False if
        patching failed. Failures are printed, never raised, so start-up continues.
    """
    try:
        if target_cls is None:
            import transformers.modeling_utils as modeling_utils
            target_cls = modeling_utils.PreTrainedModel
        original_check = target_cls._check_and_adjust_attn_implementation
        # Calling this function twice must not stack a second wrapper on the first.
        if getattr(original_check, _SDPA_PATCH_MARKER, False):
            return True

        @functools.wraps(original_check)
        def patched_check(self, *args, **kwargs):
            if not hasattr(self, '_supports_sdpa'):
                try:
                    # object.__setattr__ sidesteps any custom __setattr__ on the model class.
                    object.__setattr__(self, '_supports_sdpa', False)
                except AttributeError:
                    # e.g. `_supports_sdpa` is a read-only property whose getter
                    # failed; setting it is impossible, so rely on the fallback below.
                    pass
            try:
                return original_check(self, *args, **kwargs)
            except AttributeError as e:
                if '_supports_sdpa' in str(e):
                    # Return default attention implementation
                    return "eager"
                raise

        setattr(patched_check, _SDPA_PATCH_MARKER, True)
        target_cls._check_and_adjust_attn_implementation = patched_check
        print("Applied simple transformers patch")
        return True
    except Exception as e:
        print(f"Warning: Could not patch transformers: {e}")
        return False


# Apply the patch BEFORE importing utils
simple_patch_transformers()
# Now import the utils | |
from util.utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img | |
# Download repository: fetch the OmniParser v2 weights from the Hugging Face Hub.
repo_id = "microsoft/OmniParser-v2.0"
local_dir = "weights"
# Paths the model loaders in this file read. Checking for these, rather than only
# for `local_dir`, means a previously interrupted download (which leaves the
# directory behind) is retried instead of being silently skipped.
_required_weight_paths = (
    os.path.join(local_dir, "icon_detect", "model.pt"),
    os.path.join(local_dir, "icon_caption"),
)
if all(os.path.exists(path) for path in _required_weight_paths):
    print(f"Weights already exist at: {local_dir}")
else:
    snapshot_download(repo_id=repo_id, local_dir=local_dir)
    print(f"Repository downloaded to: {local_dir}")
# Custom function to load caption model


def _ensure_supports_sdpa_attr(model):
    """Give `model` a `_supports_sdpa = False` attribute if it has none (best effort)."""
    if not hasattr(model, '_supports_sdpa'):
        try:
            # object.__setattr__ sidesteps any custom __setattr__ on the model class.
            object.__setattr__(model, '_supports_sdpa', False)
        except AttributeError:
            # e.g. a read-only property; nothing more can be done here.
            pass


def _load_caption_model_auto(model_name_or_path, device):
    """Strategy 2: load via transformers Auto* classes with eager attention."""
    from transformers import AutoProcessor, AutoModelForCausalLM

    print(f"Loading caption model from {model_name_or_path}...")
    processor = AutoProcessor.from_pretrained(
        model_name_or_path,
        trust_remote_code=True
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        trust_remote_code=True,
        attn_implementation="eager",  # Use eager attention
        low_cpu_mem_usage=True
    )
    _ensure_supports_sdpa_attr(model)
    if device.type == 'cuda':
        model = model.to(device)
    print("Model loaded successfully with alternative method")
    return {'model': model, 'processor': processor}


def _load_caption_model_manual(model_name_or_path, device):
    """Strategy 3: build Florence2ForConditionalGeneration from local files.

    Raises instead of returning an unusable model when the model code, the model
    class, or the weight file is missing.
    """
    from transformers import AutoProcessor, AutoConfig
    import importlib.util

    print("Attempting manual model loading...")
    processor = AutoProcessor.from_pretrained(
        model_name_or_path,
        trust_remote_code=True
    )
    config = AutoConfig.from_pretrained(
        model_name_or_path,
        trust_remote_code=True
    )

    model_file = os.path.join(model_name_or_path, "modeling_florence2.py")
    if not os.path.exists(model_file):
        raise FileNotFoundError(f"Model code not found: {model_file}")
    spec = importlib.util.spec_from_file_location("modeling_florence2_custom", model_file)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    model_class = getattr(module, 'Florence2ForConditionalGeneration', None)
    if model_class is None:
        raise AttributeError(f"{model_file} does not define Florence2ForConditionalGeneration")

    model = model_class(config)
    # Set the attribute before loading weights
    _ensure_supports_sdpa_attr(model)

    weight_file = os.path.join(model_name_or_path, "model.safetensors")
    if not os.path.exists(weight_file):
        # Without weights the freshly built model would keep its random initialisation.
        raise FileNotFoundError(f"Model weights not found: {weight_file}")
    from safetensors.torch import load_file
    load_result = model.load_state_dict(load_file(weight_file), strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"Warning: {len(load_result.missing_keys)} missing and "
              f"{len(load_result.unexpected_keys)} unexpected keys in {weight_file}")

    # A directly constructed nn.Module starts in training mode (dropout active);
    # from_pretrained() would have switched it to eval mode for us.
    model.eval()
    if device.type == 'cuda':
        model = model.to(device)
        model = model.half()  # Use half precision
    print("Model loaded successfully with manual method")
    return {'model': model, 'processor': processor}


def load_caption_model_safe(model_name="florence2", model_name_or_path="weights/icon_caption"):
    """Safely load the icon-caption model, trying three strategies in order.

    1. ``get_caption_model_processor`` from ``util.utils``.
    2. transformers ``AutoProcessor`` / ``AutoModelForCausalLM`` with eager attention.
    3. Importing ``modeling_florence2.py`` from the model directory and loading
       ``model.safetensors`` into a freshly built model.

    Args:
        model_name: Model family name forwarded to ``get_caption_model_processor``.
        model_name_or_path: Model location; strategy 3 requires a local directory.

    Returns:
        Whatever strategy 1 returns, otherwise ``{'model': model, 'processor': processor}``.

    Raises:
        RuntimeError: if all three strategies fail (chained to the last error).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Method 1: Try original function
    try:
        return get_caption_model_processor(model_name, model_name_or_path)
    except Exception as e:
        print(f"Original loading failed: {e}, trying alternative...")

    # Method 2: Load with specific configs
    try:
        return _load_caption_model_auto(model_name_or_path, device)
    except Exception as e:
        print(f"Alternative loading also failed: {e}")

    # Method 3: Manual loading as last resort
    try:
        return _load_caption_model_manual(model_name_or_path, device)
    except Exception as e:
        print(f"Manual loading failed: {e}")
        raise RuntimeError(f"Could not load model with any method: {e}") from e
# Load models once at start-up. On any failure both globals are reset to None so
# process() and the UI can report the problem instead of the app crashing.
try:
    print("Loading YOLO model...")
    yolo_model = get_yolo_model(model_path='weights/icon_detect/model.pt')
    print("YOLO model loaded successfully")
    print("Loading caption model...")
    caption_model_processor = load_caption_model_safe()
    # A loader that hands back nothing must not be reported as a success.
    if caption_model_processor is None:
        raise RuntimeError("Caption model loader returned None")
    print("Caption model loaded successfully")
except Exception as e:
    print(f"Critical error loading models: {e}")
    print(traceback.format_exc())
    caption_model_processor = None
    yolo_model = None
# UI Configuration
# Header rendered at the top of the page: a Markdown title followed by an HTML
# info box. The string begins and ends with a newline.
MARKDOWN = "\n".join([
    "",
    "# OmniParser V2 Pro🔥",
    '<div style="background-color: #f0f8ff; padding: 15px; border-radius: 10px; margin-bottom: 20px;">',
    '<p style="margin: 0;">🎯 <strong>AI-powered screen understanding tool</strong> that detects UI elements and extracts text with high accuracy.</p>',
    '<p style="margin: 5px 0 0 0;">📝 Supports both PaddleOCR and EasyOCR for flexible text extraction.</p>',
    "</div>",
    "",
])
# Torch device for this process: CUDA when torch reports it available, else CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")
# Page-level CSS handed to gr.Blocks: page background, container width/font,
# heading colour, button hover lift, and borders for the image widgets.
custom_css = "\n".join([
    "",
    "body { background-color: #f0f2f5; }",
    ".gradio-container { font-family: 'Segoe UI', sans-serif; max-width: 1400px; margin: auto; }",
    "h1, h2, h3, h4 { color: #283E51; }",
    "button { border-radius: 6px; transition: all 0.3s ease; }",
    "button:hover { transform: translateY(-2px); box-shadow: 0 4px 12px rgba(0,0,0,0.15); }",
    ".output-image { border: 2px solid #e1e4e8; border-radius: 8px; }",
    "#input_image { border: 2px dashed #4a90e2; border-radius: 8px; }",
    "#input_image:hover { border-color: #2c5aa0; }",
    "",
])
def _draw_bbox_config(image_width):
    """Return box/label drawing settings scaled to `image_width`.

    The scale factor is ``image_width / 3200`` clamped to [0.5, 2.0]; the
    integer sizes are floored at 1.
    """
    ratio = max(0.5, min(2.0, image_width / 3200))
    return {
        'text_scale': 0.8 * ratio,
        'text_thickness': max(int(2 * ratio), 1),
        'text_padding': max(int(3 * ratio), 1),
        'thickness': max(int(3 * ratio), 1),
    }


def _run_ocr(image_input, use_paddleocr):
    """Run check_ocr_box on the image and return ``(texts, boxes)``.

    Any OCR failure is logged and yields two empty lists, so detection can still
    run without text boxes.
    """
    try:
        ocr_bbox_rslt, _is_goal_filtered = check_ocr_box(
            image_input,
            display_img=False,
            output_bb_format='xyxy',
            goal_filtering=None,
            easyocr_args={'paragraph': False, 'text_threshold': 0.9},
            use_paddleocr=use_paddleocr
        )
        if ocr_bbox_rslt is None:
            text, ocr_bbox = [], []
        else:
            text, ocr_bbox = ocr_bbox_rslt
        text = text if text is not None else []
        ocr_bbox = ocr_bbox if ocr_bbox is not None else []
        print(f"OCR found {len(text)} text regions")
        return text, ocr_bbox
    except Exception as e:
        print(f"OCR error: {e}")
        return [], []


def _format_parsed_content(parsed_content_list):
    """Render detected elements as Markdown; empty entries are skipped but keep their index."""
    if not parsed_content_list:
        return "ℹ️ No UI elements detected. Try adjusting the thresholds."
    element_lines = [
        f"**Element {i}:** {v}\n" for i, v in enumerate(parsed_content_list) if v
    ]
    return "🎯 **Detected Elements:**\n\n" + "".join(element_lines)


# On ZeroGPU hardware a GPU is attached only while a @spaces.GPU-decorated
# function runs; per the `spaces` documentation the decorator is effect-free elsewhere.
@spaces.GPU
def process(
    image_input,
    box_threshold,
    iou_threshold,
    use_paddleocr,
    imgsz
) -> tuple:
    """Detect and caption UI elements in a screenshot.

    Args:
        image_input: PIL image from the upload widget, or None.
        box_threshold: Forwarded to get_som_labeled_img as ``BOX_TRESHOLD``.
        iou_threshold: Overlap threshold forwarded to get_som_labeled_img.
        use_paddleocr: Use PaddleOCR (True) or EasyOCR (False) in check_ocr_box.
        imgsz: Size forwarded to get_som_labeled_img (coerced to int).

    Returns:
        ``(image, markdown)``. On success ``image`` is the annotated PIL image. It
        is None for missing input/models or unexpected errors, and the original
        image when detection or decoding fails. Errors are never raised.
    """
    if image_input is None:
        return None, "⚠️ Please upload an image for processing."
    if caption_model_processor is None or yolo_model is None:
        return None, "⚠️ Models not loaded properly. Please restart the application."
    try:
        print(f"Processing: box_threshold={box_threshold}, iou_threshold={iou_threshold}, "
              f"use_paddleocr={use_paddleocr}, imgsz={imgsz}")
        draw_bbox_config = _draw_bbox_config(image_input.size[0])
        text, ocr_bbox = _run_ocr(image_input, use_paddleocr)

        # Object detection and captioning
        try:
            # Ensure model has _supports_sdpa attribute
            if isinstance(caption_model_processor, dict) and 'model' in caption_model_processor:
                model = caption_model_processor['model']
                if not hasattr(model, '_supports_sdpa'):
                    object.__setattr__(model, '_supports_sdpa', False)
            dino_labled_img, _label_coordinates, parsed_content_list = get_som_labeled_img(
                image_input,
                yolo_model,
                BOX_TRESHOLD=box_threshold,  # (sic) keyword name used by get_som_labeled_img
                output_coord_in_ratio=True,
                ocr_bbox=ocr_bbox,
                draw_bbox_config=draw_bbox_config,
                caption_model_processor=caption_model_processor,
                ocr_text=text,
                iou_threshold=iou_threshold,
                # Slider values may arrive as floats; pass a plain int size.
                imgsz=int(imgsz)
            )
            if dino_labled_img is None:
                raise ValueError("Failed to generate labeled image")
        except Exception as e:
            print(f"Detection error: {e}")
            return image_input, f"⚠️ Error during detection: {str(e)}"

        # The labelled image comes back base64-encoded; decode it to a PIL image.
        try:
            image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
        except Exception as e:
            print(f"Image decode error: {e}")
            return image_input, f"⚠️ Error decoding image: {str(e)}"

        parsed_text = _format_parsed_content(parsed_content_list)
        # parsed_content_list may be None/empty; len() on None used to turn a
        # successful run into a generic error.
        element_count = len(parsed_content_list) if parsed_content_list else 0
        print(f'Processing complete. Found {element_count} elements.')
        return image, parsed_text
    except Exception as e:
        print(f"Processing error: {e}")
        print(traceback.format_exc())
        return None, f"⚠️ Error: {str(e)}"
# Build UI
# Layout: a settings column (upload + detection controls) beside a wider results
# column with two tabs. The button runs `process` and fills both result widgets.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.Markdown(MARKDOWN)
    # Show a start-up model-loading failure directly on the page.
    if caption_model_processor is None or yolo_model is None:
        gr.Markdown("### ⚠️ Warning: Models failed to load. Please check logs.")
    with gr.Row():
        # Left column: image upload and detection settings.
        with gr.Column(scale=1):
            with gr.Accordion("📤 Upload & Settings", open=True):
                # elem_id matches the `#input_image` rules in custom_css.
                image_input_component = gr.Image(
                    type='pil',
                    label='Upload Screenshot',
                    elem_id="input_image"
                )
                gr.Markdown("### 🎛️ Detection Settings")
                # Becomes process()'s box_threshold argument.
                box_threshold_component = gr.Slider(
                    label='Box Threshold',
                    minimum=0.01,
                    maximum=1.0,
                    step=0.01,
                    value=0.05,
                    info="Lower = more detections"
                )
                # Becomes process()'s iou_threshold argument.
                iou_threshold_component = gr.Slider(
                    label='IOU Threshold',
                    minimum=0.01,
                    maximum=1.0,
                    step=0.01,
                    value=0.1,
                    info="Overlap filtering"
                )
                # True selects PaddleOCR, False EasyOCR (see process()).
                use_paddleocr_component = gr.Checkbox(
                    label='Use PaddleOCR',
                    value=True
                )
                # Becomes process()'s imgsz argument; multiples of 32 from 640 to 1920.
                imgsz_component = gr.Slider(
                    label='Image Size',
                    minimum=640,
                    maximum=1920,
                    step=32,
                    value=640
                )
            # NOTE(review): nesting of the button relative to the accordion was
            # reconstructed from a paste without indentation — confirm placement.
            submit_button_component = gr.Button(
                value='🚀 Process',
                variant='primary'
            )
        # Right column: annotated image and the parsed element list in tabs.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.Tab("🖼️ Result"):
                    # NOTE(review): custom_css styles `.output-image`, but no
                    # component here sets elem_classes, so that rule looks unused.
                    image_output_component = gr.Image(
                        type='pil',
                        label='Annotated Image'
                    )
                with gr.Tab("📝 Elements"):
                    text_output_component = gr.Markdown(
                        value="*Results will appear here...*"
                    )
    # Input order must match process()'s positional parameters.
    submit_button_component.click(
        fn=process,
        inputs=[
            image_input_component,
            box_threshold_component,
            iou_threshold_component,
            use_paddleocr_component,
            imgsz_component
        ],
        outputs=[image_output_component, text_output_component],
        show_progress=True
    )
# Launch
if __name__ == "__main__":
    try:
        # Cap the number of queued events at 10.
        demo.queue(max_size=10)
        demo.launch(
            share=False,
            show_error=True,
            server_name="0.0.0.0",
            server_port=7860
        )
    except Exception as e:
        print(f"Launch failed: {e}")
        print(traceback.format_exc())
        # Re-raise so the process exits with an error instead of returning
        # normally and hiding the failure.
        raise