# UI-TARS / app.py — Hugging Face Space file by ISSAAYMAN4
# Last commit: "Update app.py" (eb3c23e, verified)
# app.py — UI-TARS demo (OSS disabled)
import base64
import json
import ast
import os
import re
import io
import math
from datetime import datetime
import gradio as gr
from PIL import ImageDraw
# =========================
# OpenAI client (optional)
# =========================
# An OpenAI-compatible chat client is created only when OPENAI_API_KEY is set;
# run_ui uses it to ask the model for the click coordinates. ENDPOINT_URL, when
# present, replaces the client's default base URL (advanced use).
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
ENDPOINT_URL = os.getenv("ENDPOINT_URL")  # optional
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")  # safe default instead of "tgi"

client = None
if not OPENAI_API_KEY:
    print("ℹ️ OPENAI_API_KEY not set. Running without OpenAI parsing.")
else:
    # Best-effort: a missing `openai` package or a constructor error leaves
    # `client` as None and the app keeps running without model calls.
    try:
        from openai import OpenAI

        _client_kwargs = {"api_key": OPENAI_API_KEY}
        if ENDPOINT_URL:
            _client_kwargs["base_url"] = ENDPOINT_URL
        client = OpenAI(**_client_kwargs)
        print("✅ OpenAI client initialized.")
    except Exception as e:
        print(f"⚠️ OpenAI client not available: {e}")
# =========================
# UI-TARS prompt
# =========================
DESCRIPTION = "[UI-TARS](https://github.com/bytedance/UI-TARS)"

# Instruction prefix placed directly in front of the user's query. Each fragment
# carries its own trailing space, so joining them yields consecutive sentences
# and leaves a separating space before the appended query.
_PROMPT_PARTS = (
    "Output only the coordinate of one box in your response. ",
    "Return a tuple like (x,y) with values in 0..1000 for x and y. ",
    "Do not include any extra text. ",
)
prompt = "".join(_PROMPT_PARTS)
# =========================
# OSS (Aliyun) — DISABLED
# =========================
# The upstream demo pushed images and metadata to Aliyun OSS through `oss2`.
# Storage is switched off here, so no BUCKET / ENDPOINT environment variables
# are needed: `bucket` stays None and upload_images returns early on that.
bucket = None
_OSS_DISABLED_NOTICE = "⚠️ OSS integration disabled: skipping Aliyun storage."
print(_OSS_DISABLED_NOTICE)
def draw_point_area(image, point):
    """Mark a predicted click on ``image`` and return the same image object.

    ``point`` is an (x, y) pair on a 0..1000 grid relative to the image size.
    A red ring whose radius is 1/15 of the shorter side, plus a small filled red
    dot, are painted at that location. Drawing happens in place. A falsy
    ``point`` (None, empty tuple) leaves the image untouched.
    """
    if point:
        w, h = image.width, image.height
        r = min(w, h) // 15
        cx = round(point[0] / 1000 * w)
        cy = round(point[1] / 1000 * h)
        canvas = ImageDraw.Draw(image)
        canvas.ellipse((cx - r, cy - r, cx + r, cy + r), outline="red", width=2)
        canvas.ellipse((cx - 2, cy - 2, cx + 2, cy + 2), fill="red")
    return image
def resize_image(image):
    """Rescale a screenshot to a fixed pixel budget before sending it to the model.

    Every image is resized, up or down, so that width * height lands close to one
    of two budgets (written as multiples of 28 * 28 px):

    * area above 6000 * 28 * 28  -> scaled to about 2700 * 28 * 28 pixels
    * otherwise                  -> scaled to about 1340 * 28 * 28 pixels
      (small inputs are therefore upscaled)

    The aspect ratio is preserved. Each side is kept at least 1 px so very
    elongated images cannot collapse to a zero-sized result. An image with zero
    area is returned unchanged because there is nothing to scale.

    Args:
        image: PIL-style image exposing ``width``, ``height`` and ``resize``.

    Returns:
        The resized image (or ``image`` itself when its area is zero).
    """
    unit = 28 * 28
    area = image.width * image.height
    if area == 0:
        # Avoid ZeroDivisionError in the scale computation below.
        return image
    target = (2700 if area > 6000 * unit else 1340) * unit
    scale = math.sqrt(target / area)
    # int() truncates; clamp so a side never becomes 0 (resize rejects that).
    width = max(1, int(image.width * scale))
    height = max(1, int(image.height * scale))
    return image.resize((width, height))
def upload_images(session_id, image, result_image, query):
    """Persist the model input, the annotated result and a JSON metadata record.

    With OSS disabled (module-level ``bucket`` is None), this only logs that the
    upload was skipped and returns, which keeps the call site unchanged. When a
    bucket object is configured, three objects are written via ``bucket.put_object``:

    * ``<session_id>.png``       -- ``image`` encoded as PNG
    * ``<session_id>-draw.png``  -- ``result_image`` encoded as PNG
    * ``<session_id>.json``      -- UTF-8 JSON with the query and the two object names

    Args:
        session_id: Prefix used for every object name.
        image: PIL image that was sent to the model.
        result_image: PIL image with the predicted click drawn on it.
        query: The user's instruction text.
    """
    if bucket is None:
        print("↪️ Skipped OSS upload (no bucket configured).")
        return

    def _png_bytes(img):
        # Serialize an image to PNG bytes in memory (no temporary files).
        with io.BytesIO() as buf:
            img.save(buf, format="png")
            return buf.getvalue()

    img_path = f"{session_id}.png"
    result_img_path = f"{session_id}-draw.png"
    metadata = dict(
        query=query,
        resize_image=img_path,
        result_image=result_img_path,
        session_id=session_id,
    )
    bucket.put_object(img_path, _png_bytes(image))
    bucket.put_object(result_img_path, _png_bytes(result_image))
    bucket.put_object(f"{session_id}.json", json.dumps(metadata).encode("utf-8"))
    # Reached only after all three writes returned, so report a real upload
    # (the previous "(would) upload ... skipped" text was misleading here).
    print(f"✅ Uploaded {img_path}, {result_img_path} and {session_id}.json to OSS.")
# First "(x, y)" integer pair in the model reply; whitespace around either number is allowed.
_CLICK_XY_PATTERN = re.compile(r"\(\s*(\d+)\s*,\s*(\d+)\s*\)")


def _parse_click_xy(output_text):
    """Extract the first ``(x, y)`` pair from ``output_text`` on the 0..1000 grid.

    Returns a tuple of two ints clamped to [0, 1000], or None when no pair is found.
    """
    match = _CLICK_XY_PATTERN.search(output_text or "")
    if match is None:
        return None
    # int() (unlike ast.literal_eval) accepts zero-padded numbers such as "0500";
    # clamping keeps a slightly out-of-range answer on the image.
    return tuple(min(max(int(value), 0), 1000) for value in match.groups())


def run_ui(image, query, session_id, is_example_image):
    """Predict a click location for ``query`` on ``image`` and render it.

    Steps:
    1. Rescale the screenshot with ``resize_image``.
    2. Send it as a base64 PNG, together with ``prompt + query``, to the
       chat-completions client (only when one is configured).
    3. Take the first "(x, y)" pair from the reply on the 0..1000 grid.
    4. Draw it on the resized image.
    5. Convert it to pixel coordinates of the ORIGINAL image.

    If there is no client, the request fails, or no pair can be parsed, the
    centre (500, 500) is used.

    Args:
        image: PIL image from the UI.
        query: Instruction text; None is treated as "".
        session_id: Identifier forwarded to ``upload_images``.
        is_example_image: "True"/"False" (or a bool). Uploads happen only when
            its string form is "False".

    Returns:
        Tuple of (list containing the annotated image, string form of the
        absolute ``(x, y)`` pixel coordinates).
    """
    # Size of the image as uploaded; the prediction is mapped back onto it.
    width, height = image.width, image.height

    # Rescale for the model, then encode as a base64 PNG data URL.
    image = resize_image(image)
    buf = io.BytesIO()
    image.save(buf, format="png")
    base64_image = base64.standard_b64encode(buf.getvalue()).decode("utf-8")

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64_image}"}},
                {"type": "text", "text": prompt + (query or "")},
            ],
        }
    ]

    output_text = ""
    if client is not None:
        try:
            resp = client.chat.completions.create(
                model=MODEL_NAME,
                messages=messages,
                temperature=1.0,
                top_p=0.7,
                max_tokens=128,
                frequency_penalty=1,
                stream=False,
            )
            output_text = resp.choices[0].message.content or ""
        except Exception as e:
            # Best-effort at the network boundary: fall back to the default point.
            print(f"⚠️ OpenAI call failed: {e}")

    click_xy = _parse_click_xy(output_text)
    if click_xy is None:
        click_xy = (500, 500)

    # Draw on a copy of the resized image; the 0..1000 grid is size-independent,
    # so the same point is converted to pixels of the original upload.
    result_image = draw_point_area(image.copy(), click_xy)
    abs_xy = (round(click_xy[0] / 1000 * width), round(click_xy[1] / 1000 * height))

    # Upload artifacts only for real (non-example) inputs.
    if str(is_example_image) == "False":
        upload_images(session_id, image, result_image, query)

    return [result_image], str(abs_xy)
def update_vote(vote_type, image, click_image, prompt_text, is_example):
    """Feedback hook for the vote buttons; returns a status message.

    Nothing is stored or uploaded. ``image``, ``click_image`` and ``prompt_text``
    are accepted only to match the button wiring and are not used.
    An "upvote" always yields "Everything good". Any other vote returns
    "Do nothing for example" when ``is_example`` is the string "True", and a
    thank-you message otherwise.
    """
    if vote_type != "upvote":
        if is_example == "True":
            return "Do nothing for example"
        return "Thank you for your feedback!"
    return "Everything good"
# Demo examples as (screenshot path, instruction) pairs. Every entry is expanded
# to [path, instruction, True]: build_demo shows the first two fields in
# gr.Examples and uses the trailing flag to mark a typed query as an example.
_EXAMPLE_SPECS = (
    ("./examples/solitaire.png", "Play the solitaire collection"),
    ("./examples/weather_ui.png", "Open map"),
    ("./examples/football_live.png", "click team 1 win"),
    ("./examples/windows_panel.png", "switch to documents"),
    ("./examples/paint_3d.png", "rotate left"),
    ("./examples/finder.png", "view files from airdrop"),
    ("./examples/amazon.jpg", "Search bar at the top of the page"),
    ("./examples/semantic.jpg", "Home"),
    ("./examples/accweather.jpg", "Select May"),
    ("./examples/arxiv.jpg", "Home"),
    ("./examples/health.jpg", "text labeled by 2023/11/26"),
    ("./examples/ios_setting.png", "Turn off Do not disturb."),
)
examples = [[path, instruction, True] for path, instruction in _EXAMPLE_SPECS]
# Page header rendered at the top of the demo: title plus links to the model,
# code, paper, Midscene and Discord.
title_markdown = """
# UI-TARS Pioneering Automated GUI Interaction with Native Agents
[[🤗Model](https://huggingface.co/bytedance-research/UI-TARS-7B-SFT)] [[⌨️Code](https://github.com/bytedance/UI-TARS)] [[📑Paper](https://github.com/bytedance/UI-TARS/blob/main/UI_TARS_paper.pdf)] [🏄[Midscene (Browser Automation)](https://github.com/web-infra-dev/Midscene)] [🫨[Discord](https://discord.gg/txAE43ps)]
"""
# Terms-of-use footer (English text followed by a Chinese translation).
tos_markdown = """
### Terms of use
This demo is governed by the original license of UI-TARS. We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, including hate speech, violence, pornography, deception, etc. (注:本演示受UI-TARS的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)
"""
# License footer.
learn_more_markdown = """
### License
Apache License 2.0
"""
# Attribution footer: the app code is adapted from the ShowUI Space.
code_adapt_markdown = """
### Acknowledgments
The app code is modified from [ShowUI](https://huggingface.co/spaces/showlab/ShowUI)
"""
# Page CSS passed to gr.Blocks(css=...): a minimum button width under #buttons and
# size limits for images under #chatbot.
# NOTE(review): build_demo only sets elem_id="action-buttons"; no component uses
# "buttons" or "chatbot", so these rules appear to match nothing — confirm.
block_css = """
#buttons button { min-width: min(120px,100%); }
#chatbot img {
max-width: 80%;
max-height: 80vh;
width: auto;
height: auto;
object-fit: contain;
}
"""
def build_demo():
    """Assemble and return the Gradio Blocks UI for the UI-TARS click demo.

    Layout is one row of three columns:
    * input: screenshot, instruction, Submit
    * output: annotated gallery, final coordinates, image size, feedback buttons
    * examples: bundled example screenshots and instructions

    Event wiring:
    * Submit calls ``run_ui``.
    * Clear resets inputs and outputs.
    * The vote buttons call ``update_vote``.
    * Editing the instruction updates a hidden "is example" flag; ``run_ui``
      skips uploads when that flag is "True".

    Returns:
        The configured ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks(title="UI-TARS Demo", theme=gr.themes.Default(), css=block_css) as demo:
        state_session_id = gr.State(value=None)
        gr.Markdown(title_markdown)

        with gr.Row():
            with gr.Column(scale=3):
                imagebox = gr.Image(type="pil", label="Input Screenshot")
                textbox = gr.Textbox(
                    show_label=True,
                    placeholder="Enter an instruction and press Submit",
                    label="Instruction",
                )
                submit_btn = gr.Button(value="Submit", variant="primary")

            with gr.Column(scale=6):
                output_gallery = gr.Gallery(label="Output with click", object_fit="contain", preview=True)
                gr.HTML(
                    "\n<p><strong>Notice:</strong> The <span style=\"color: red;\">red point</span> "
                    "with a circle on the output image represents the predicted coordinates for a click.</p>\n"
                )
                with gr.Row():
                    output_coords = gr.Textbox(label="Final Coordinates")
                    image_size = gr.Textbox(label="Image Size")
                gr.HTML("<p><strong>Expected result or not? help us improve! ⬇️</strong></p>")
                with gr.Row(elem_id="action-buttons", equal_height=True):
                    upvote_btn = gr.Button(value="👍 Looks good!", variant="secondary")
                    downvote_btn = gr.Button(value="👎 Wrong coordinates!", variant="secondary")
                    clear_btn = gr.Button(value="🗑️ Clear", interactive=True)

            with gr.Column(scale=3):
                gr.Examples(
                    examples=[[e[0], e[1]] for e in examples],
                    inputs=[imagebox, textbox],
                    outputs=[textbox],
                    examples_per_page=3,
                )
                # Hidden flag holding "True"/"False"; consumed by run_ui to skip uploads.
                is_example_dropdown = gr.Dropdown(
                    choices=["True", "False"], value="False", visible=False, label="Is Example Image",
                )

        def set_is_example(query):
            """Return "True"/"False" depending on whether ``query`` matches an example instruction."""
            # The textbox value can be None (e.g. after Clear); treat it as empty.
            text = (query or "").strip()
            for _, example_query, is_example in examples:
                if text == example_query.strip():
                    return str(is_example)
            return "False"

        textbox.change(set_is_example, inputs=[textbox], outputs=[is_example_dropdown])

        def on_submit(image, query, is_example_image):
            """Run inference for the Submit button and fill the output widgets."""
            if image is None:
                # gr.Error is shown to the user with this message; a plain ValueError
                # only surfaces as a generic "Error" in the UI.
                raise gr.Error("No image provided. Please upload an image before submitting.")
            # Microseconds keep IDs (used as upload object names) distinct for
            # submissions within the same second.
            session_id = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
            images_during_iterations, click_coords = run_ui(image, query or "", session_id, is_example_image)
            return images_during_iterations, click_coords, session_id, f"{image.width}x{image.height}"

        submit_btn.click(
            on_submit,
            [imagebox, textbox, is_example_dropdown],
            [output_gallery, output_coords, state_session_id, image_size],
        )
        clear_btn.click(
            lambda: (None, None, None, None, None, None),
            inputs=None,
            outputs=[imagebox, textbox, output_gallery, output_coords, state_session_id, image_size],
            queue=False,
        )
        upvote_btn.click(
            lambda image, click_image, prompt_text, is_example:
                update_vote("upvote", image, click_image, prompt_text, is_example),
            inputs=[imagebox, output_gallery, textbox, is_example_dropdown],
            outputs=[],
            queue=False,
        )
        downvote_btn.click(
            lambda image, click_image, prompt_text, is_example:
                update_vote("downvote", image, click_image, prompt_text, is_example),
            inputs=[imagebox, output_gallery, textbox, is_example_dropdown],
            outputs=[],
            queue=False,
        )

        gr.Markdown(tos_markdown)
        gr.Markdown(learn_more_markdown)
        gr.Markdown(code_adapt_markdown)

    return demo
if __name__ == "__main__":
    # Serve on all interfaces at port 7860. api_open=False and debug=True are
    # passed straight through to Gradio's queue() / launch().
    launch_options = dict(server_name="0.0.0.0", server_port=7860, debug=True)
    demo = build_demo()
    demo.queue(api_open=False).launch(**launch_options)