"""UI-TARS grounding server: a FastAPI endpoint plus a Gradio test UI."""

import base64
import io
import json
import re

import gradio as gr
import torch
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from transformers import AutoProcessor, AutoModel

# Hugging Face model ID for the UI-TARS grounding model
model_name = "ByteDance-Seed/UI-TARS-1.5-7B"


def load_model():
    """Load the UI-TARS model and processor, falling back to processor-only."""
    try:
        print("🚀 Loading UI-TARS model...")

        processor = AutoProcessor.from_pretrained(model_name)
        print("✅ Processor loaded successfully!")

        model = AutoModel.from_pretrained(model_name)
        print("✅ UI-TARS model loaded successfully!")

        return model, processor
    except Exception as e:
        print(f"❌ Error loading UI-TARS: {e}")
        print("Falling back to alternative approach...")

        # Fallback: load only the processor so the API can still serve requests
        try:
            processor = AutoProcessor.from_pretrained(model_name)
            print("✅ UI-TARS processor loaded with fallback configuration!")
            return None, processor
        except Exception as e2:
            print(f"❌ Alternative approach failed: {e2}")
            return None, None
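

# Hypothetical inference helper (not yet wired into process_grounding below):
# a minimal sketch of how the loaded model/processor pair could produce a
# grounding answer, assuming the common transformers vision-language API
# (processor(images=..., text=...) plus model.generate()). The exact call
# signature for this checkpoint may differ.
def run_inference(model, processor, image, prompt):
    """Generate a raw model response for a PIL image and a text prompt."""
    if model is None or processor is None:
        return None
    inputs = processor(images=image, text=prompt, return_tensors="pt")
    output_ids = model.generate(**inputs, max_new_tokens=256)
    return processor.batch_decode(output_ids, skip_special_tokens=True)[0]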


def fix_base64_string(base64_str):
    """Repair common defects in base64 image strings (data-URL prefix, padding)."""
    try:
        base64_str = base64_str.strip()

        # Strip a data-URL prefix such as "data:image/png;base64,"
        if base64_str.startswith('data:image/'):
            base64_str = base64_str.split(',', 1)[1]

        # Base64 length must be a multiple of 4; restore any stripped padding
        missing_padding = len(base64_str) % 4
        if missing_padding:
            base64_str += '=' * (4 - missing_padding)

        # Validate by decoding; on failure, salvage the longest valid-looking run
        try:
            base64.b64decode(base64_str)
            return base64_str
        except Exception:
            match = re.search(r'[A-Za-z0-9+/]+={0,2}', base64_str)
            if match:
                fixed_str = match.group(0)
                missing_padding = len(fixed_str) % 4
                if missing_padding:
                    fixed_str += '=' * (4 - missing_padding)
                return fixed_str

        return base64_str
    except Exception as e:
        print(f"Error fixing base64: {e}")
        return base64_str
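

# Example of the padding repair: "aGk" (base64 for b"hi" with its "=" stripped)
# has length 3, so 4 - (3 % 4) = 1 padding character is appended, restoring
# "aGk=" before decoding:
#
#   fix_base64_string("aGk")  # -> "aGk="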


def process_grounding(image_data, prompt):
    """Decode the screenshot and return grounding results for the prompt."""
    try:
        print("Processing image with UI-TARS model...")

        # Repair the base64 payload before decoding
        if isinstance(image_data, str):
            image_data = fix_base64_string(image_data)

        try:
            if image_data.startswith('data:image/'):
                image_data = image_data.split(',', 1)[1]

            image_bytes = base64.b64decode(image_data)
            image = Image.open(io.BytesIO(image_bytes))
            print(f"✅ Image loaded successfully: {image.size}")
        except Exception as e:
            print(f"❌ Error decoding base64: {e}")
            return {
                "error": f"Failed to decode image: {str(e)}",
                "status": "failed"
            }

        # Placeholder result: real model inference is not wired up yet, so a
        # fixed example element is returned to exercise the pipeline end to end.
        return {
            "status": "success",
            "elements": [
                {
                    "type": "button",
                    "text": "calculator button",
                    "bbox": [100, 100, 200, 150],
                    "confidence": 0.95
                }
            ],
            "message": f"Processed image with prompt: {prompt}"
        }

    except Exception as e:
        print(f"❌ Error in process_grounding: {e}")
        return {
            "error": f"Error processing image: {str(e)}",
            "status": "failed"
        }


# Load the model and processor once at startup
model, processor = load_model()

app = FastAPI(title="UI-TARS Grounding Model API")

# Allow cross-origin requests so external agents (e.g. Agent-S) can call the API
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/v1/ground/chat/completions")
async def chat_completions(request: Request):
    """OpenAI-style chat completions endpoint that Agent-S expects."""
    try:
        print("=" * 60)
        print("🔍 DEBUG: New request received")
        print("=" * 60)

        body = await request.body()
        print(f"📥 RAW REQUEST BODY (bytes): {len(body)} bytes")
        print(f"📥 RAW REQUEST BODY (string): {body.decode('utf-8')[:500]}...")

        try:
            data = json.loads(body)
            print("✅ PARSED JSON SUCCESSFULLY")
            print(f"🔑 JSON KEYS: {list(data.keys())}")
        except json.JSONDecodeError as e:
            print(f"❌ JSON PARSE ERROR: {e}")
            return {"error": "Invalid JSON", "status": "failed"}

        messages = data.get("messages", [])
        print(f"💬 MESSAGES COUNT: {len(messages)}")

        # Extract the image and text prompt from the user message. Content may
        # be a list of typed items (OpenAI vision format) or a plain string.
        image_data = None
        prompt = None

        for i, msg in enumerate(messages):
            print(f"📨 Message {i}: role='{msg.get('role')}', content type={type(msg.get('content'))}")

            if msg.get("role") == "user":
                content = msg.get("content", [])
                if isinstance(content, list):
                    for item in content:
                        if isinstance(item, dict):
                            if item.get("type") == "image_url":
                                image_data = item.get("image_url", {}).get("url", "")
                                print(f"🖼️ Found image_url: {image_data[:100]}...")
                            elif item.get("type") == "text":
                                prompt = item.get("text", "")
                                print(f"📝 Found text: {prompt[:100]}...")
                elif isinstance(content, str):
                    prompt = content
                    print(f"📝 Found string content: {prompt[:100]}...")

        if not image_data:
            print("❌ No image data found in request")
            return {
                "error": "No image data provided",
                "status": "failed"
            }

        if not prompt:
            prompt = "Analyze this image and identify UI elements"
            print(f"⚠️ No prompt found, using default: {prompt}")

        print(f"📝 USER MESSAGE EXTRACTED: {prompt[:100]}...")

        result = process_grounding(image_data, prompt)
        print(f"📊 GROUNDING RESULT: {result}")

        # Wrap the grounding result in an OpenAI chat-completion envelope so
        # OpenAI-compatible clients can parse it
        response = {
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1677652288,
            "model": "ui-tars-1.5-7b",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": json.dumps(result) if isinstance(result, dict) else str(result)
                    },
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 20,
                "total_tokens": 30
            }
        }

        print(f"📤 SENDING RESPONSE: {json.dumps(response, indent=2)}")
        return response

    except Exception as e:
        print(f"❌ ERROR in chat_completions: {e}")
        return {
            "error": f"Internal server error: {str(e)}",
            "status": "failed"
        }
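

# Example request against this endpoint (assumes the server is running
# locally on port 7860; the base64 payload is elided):
#
#   curl -X POST http://localhost:7860/v1/ground/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": [
#           {"type": "text", "text": "Click on the calculator button"},
#           {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
#         ]}]}'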


def gradio_interface(image, prompt):
    """Gradio handler for manual testing."""
    if image is None:
        return {"error": "No image provided", "status": "failed"}

    # Encode the PIL image as base64 PNG, matching the API's input format
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    img_str = base64.b64encode(buffer.getvalue()).decode()

    return process_grounding(img_str, prompt)


iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Image(label="Upload Screenshot", type="pil"),
        gr.Textbox(label="Prompt/Goal", placeholder="Describe what you want to do...")
    ],
    outputs=gr.JSON(label="Grounding Results"),
    title="UI-TARS Grounding Model",
    description="Upload a screenshot and describe your goal to get UI element coordinates",
    examples=[
        ["https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png", "Click on the calculator button"]
    ]
)

# Serve the Gradio test UI at "/" on the same FastAPI app
app = gr.mount_gradio_app(app, iface, path="/")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
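
# To run: `python app.py` (assuming this file is saved as app.py). This serves
# both the Gradio UI at "/" and the grounding API at
# "/v1/ground/chat/completions" on port 7860.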