from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
import torch, os, base64, io, logging, time
from typing import Any, Dict, List, Tuple
from PIL import Image

MODEL_ID = "osunlp/UGround-V1-72B"
CACHE_DIR = (
    os.environ.get("HF_HUB_CACHE")
    or os.environ.get("HF_HOME")
    or "/data/huggingface"
)

# PyTorch performance settings
# 1) Ensure CUDA kernel cache directory is writable/persistent to avoid recompilation stalls
KERNEL_CACHE_DIR = os.environ.get("PYTORCH_KERNEL_CACHE_PATH", "/tmp/torch_kernels")
os.environ["PYTORCH_KERNEL_CACHE_PATH"] = KERNEL_CACHE_DIR
try:
    os.makedirs(KERNEL_CACHE_DIR, exist_ok=True)
except Exception:
    pass

# 2) Enable TF32 for faster matmul on Ampere+ GPUs (minimal quality impact)
try:
    torch.backends.cuda.matmul.allow_tf32 = True  # type: ignore[attr-defined]
    torch.backends.cudnn.allow_tf32 = True  # type: ignore[attr-defined]
    torch.set_float32_matmul_precision("high")  # type: ignore[attr-defined]
except Exception:
    pass

processor = AutoProcessor.from_pretrained(
    MODEL_ID, trust_remote_code=True, cache_dir=CACHE_DIR, use_fast=False
)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
    cache_dir=CACHE_DIR,
)
model.eval()
try:
    torch.set_grad_enabled(False)
except Exception:
    pass
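
# FastAPI app exposing a minimal OpenAI-style chat completions API around the UGround model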
app = FastAPI()

# Configure basic logging for debugging
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s"
)
logger = logging.getLogger(__name__)

@app.get("/")  # simple health-check endpoint (route path assumed)
async def root():
    return {"status": "ok"}
class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Dict[str, Any]]
    max_tokens: int = 256
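
# Upper bounds for incoming images; anything larger is downscaled (aspect ratio preserved)
# before it reaches the vision encoder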
MAX_IMAGE_WIDTH = 512
MAX_IMAGE_HEIGHT = 388
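
# Decode an image from either a full data URL ("data:image/png;base64,...") or a raw base64 string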
def _decode_base64_image(data_url: str) -> Image.Image:
    try:
        is_data_url = data_url.startswith("data:")
        if is_data_url:
            header, b64data = data_url.split(",", 1)
            logger.debug("Decoding image from data URL; header prefix=%r", header[:50])
        else:
            b64data = data_url
            logger.debug("Decoding image from raw base64 string; length=%d", len(b64data))
        img_bytes = base64.b64decode(b64data)
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        orig_w, orig_h = img.width, img.height
        # Downscale if larger than bounds, preserving aspect ratio
        if orig_w > MAX_IMAGE_WIDTH or orig_h > MAX_IMAGE_HEIGHT:
            target = (MAX_IMAGE_WIDTH, MAX_IMAGE_HEIGHT)
            img = img.copy()
            img.thumbnail(target, Image.LANCZOS)
            logger.debug(
                "Resized image from %sx%s to %sx%s (bounds %sx%s)",
                orig_w,
                orig_h,
                img.width,
                img.height,
                MAX_IMAGE_WIDTH,
                MAX_IMAGE_HEIGHT,
            )
        try:
            logger.debug("Decoded image: size=%sx%s mode=%s", img.width, img.height, img.mode)
        except Exception:
            logger.debug("Decoded image but could not log image metadata")
        return img
    except Exception:
        logger.exception("Failed to decode base64 image")
        raise
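
# Convert OpenAI-style messages into the Qwen2-VL chat format and collect image inputs
# (PIL images for base64 payloads, raw URL strings otherwise) in processor order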
def _to_qwen_messages_and_images(messages: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], List[Any]]:
    qwen_msgs: List[Dict[str, Any]] = []
    images: List[Any] = []
    logger.debug("Begin parsing messages: count=%d", len(messages) if messages else 0)
    for idx, msg in enumerate(messages):
        role = msg.get("role", "user")
        content = msg.get("content")
        logger.debug("Processing message #%d role=%s content_type=%s", idx, role, type(content).__name__)
        q_content: List[Dict[str, Any]] = []
        if isinstance(content, str):
            logger.debug("Message #%d text length=%d", idx, len(content))
            q_content.append({"type": "text", "text": content})
        elif isinstance(content, list):
            logger.debug("Message #%d has %d content parts", idx, len(content))
            for pidx, part in enumerate(content):
                ptype = part.get("type")
                logger.debug("Part #%d type=%s", pidx, ptype)
                if ptype == "text":
                    text_val = part.get("text") or part.get("content") or ""
                    logger.debug("Part #%d text length=%d", pidx, len(text_val))
                    q_content.append({"type": "text", "text": text_val})
                elif ptype in ("image", "image_url"):
                    # OpenAI style: {type:"image_url", image_url:{url:"..."}}
                    url = part.get("image")
                    if url is None and isinstance(part.get("image_url"), dict):
                        url = part["image_url"].get("url")
                    if isinstance(url, str) and url.startswith("data:image"):
                        logger.debug("Part #%d image provided as base64 data URL", pidx)
                        img = _decode_base64_image(url)
                        images.append(img)
                        q_content.append({"type": "image", "image": img})
                    else:
                        # URL or non-base64 string
                        logger.debug("Part #%d image provided as URL or non-base64 string: %s", pidx, str(url)[:200])
                        images.append(url)
                        q_content.append({"type": "image", "image": url})
        else:
            # Unknown content; coerce to text
            logger.debug("Message #%d unknown content type; coercing to text", idx)
            q_content.append({"type": "text", "text": str(content)})
        qwen_msgs.append({"role": role, "content": q_content})
    logger.debug("Finished parsing messages: qwen_msgs=%d images=%d", len(qwen_msgs), len(images))
    return qwen_msgs, images
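
# Build a tiny solid-color PNG as a base64 data URL; used only to warm up the vision path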
def _make_tiny_base64_png(size: Tuple[int, int] = (64, 48), color: Tuple[int, int, int] = (128, 128, 128)) -> str:
    buf = io.BytesIO()
    Image.new("RGB", size, color).save(buf, format="PNG")
    data = base64.b64encode(buf.getvalue()).decode("ascii")
    return f"data:image/png;base64,{data}"
@app.on_event("startup")  # run warmup once when the app starts (startup hook assumed)
async def _startup_warmup():
if os.environ.get("DISABLE_WARMUP", "0") == "1": | |
logger.info("Warmup disabled via DISABLE_WARMUP=1") | |
return | |
try: | |
logger.info("Warmup start: compiling kernels (text + tiny image)") | |
# Text-only warmup | |
text_msgs: List[Dict[str, Any]] = [ | |
{"role": "user", "content": "Hello"} | |
] | |
qmsgs_t, _ = _to_qwen_messages_and_images(text_msgs) | |
prompt_t = processor.apply_chat_template(qmsgs_t, tokenize=False, add_generation_prompt=True) | |
inputs_t = processor(text=[prompt_t], images=None, padding=True, return_tensors="pt") | |
inputs_t = inputs_t.to(model.device) | |
_t0 = time.perf_counter() | |
with torch.no_grad(): | |
_ = model.generate(**inputs_t, max_new_tokens=int(os.environ.get("WARMUP_MAX_NEW_TOKENS", "4")), max_time=float(os.environ.get("WARMUP_MAX_TIME_SECONDS", "3"))) | |
logger.info("Text warmup done in %.1f ms", (time.perf_counter() - _t0) * 1000.0) | |
# Tiny image + text warmup | |
tiny_url = _make_tiny_base64_png() | |
viz_msgs: List[Dict[str, Any]] = [ | |
{"role": "user", "content": [ | |
{"type": "text", "text": "Describe the image"}, | |
{"type": "image_url", "image_url": {"url": tiny_url}}, | |
]} | |
] | |
qmsgs_v, images_v = _to_qwen_messages_and_images(viz_msgs) | |
prompt_v = processor.apply_chat_template(qmsgs_v, tokenize=False, add_generation_prompt=True) | |
inputs_v = processor(text=[prompt_v], images=images_v, padding=True, return_tensors="pt") | |
inputs_v = inputs_v.to(model.device) | |
_t1 = time.perf_counter() | |
with torch.no_grad(): | |
_ = model.generate(**inputs_v, max_new_tokens=int(os.environ.get("WARMUP_MAX_NEW_TOKENS", "4")), max_time=float(os.environ.get("WARMUP_MAX_TIME_SECONDS", "3"))) | |
logger.info("Vision warmup done in %.1f ms", (time.perf_counter() - _t1) * 1000.0) | |
logger.info("Warmup complete") | |
except Exception: | |
logger.exception("Warmup failed") | |
@app.post("/v1/chat/completions")  # route path assumed from the OpenAI-style request/response format
async def chat_completions(req: ChatCompletionRequest):
    logger.debug(
        "Request received: model=%s, max_tokens=%s, message_count=%d",
        req.model,
        req.max_tokens,
        len(req.messages) if req.messages is not None else 0,
    )
    if req.messages:
        logger.debug("First message preview: %s", str(req.messages[0])[:300])
    qwen_messages, image_inputs = _to_qwen_messages_and_images(req.messages)
    logger.debug(
        "Converted messages: qwen_count=%d, images_count=%d",
        len(qwen_messages),
        len(image_inputs) if image_inputs is not None else 0,
    )
    if qwen_messages:
        logger.debug("First qwen message preview: %s", str(qwen_messages[0])[:300])
    prompt_text = processor.apply_chat_template(
        qwen_messages, tokenize=False, add_generation_prompt=True
    )
    logger.debug("Prompt length (chars)=%d; preview=%r", len(prompt_text), prompt_text[:200])
    inputs = processor(
        text=[prompt_text],
        images=image_inputs if image_inputs else None,
        padding=True,
        return_tensors="pt",
    )
    try:
        tensor_info_pre = {
            k: (tuple(v.shape), str(getattr(v, "dtype", "<na>")))
            for k, v in inputs.items()
            if hasattr(v, "shape")
        }
        logger.debug("Processor outputs (pre .to): %s", tensor_info_pre)
    except Exception:
        logger.debug("Could not summarize processor outputs before device move")
    inputs = inputs.to(model.device)
    try:
        tensor_info_post = {
            k: (
                tuple(v.shape),
                str(getattr(v, "dtype", "<na>")),
                str(getattr(v, "device", "<na>")),
            )
            for k, v in inputs.items()
            if torch.is_tensor(v)
        }
        logger.debug("Inputs moved to device=%s; tensor_info=%s", getattr(model, "device", "<unknown>"), tensor_info_post)
    except Exception:
        logger.debug("Could not summarize inputs after device move")
    logger.debug("Starting generation: max_new_tokens=%d", req.max_tokens)
    _t0 = time.perf_counter()
    generated_ids = model.generate(**inputs, max_new_tokens=req.max_tokens)
    _elapsed_ms = (time.perf_counter() - _t0) * 1000.0
    try:
        logger.debug(
            "Generation done in %.1f ms; generated_ids shape=%s dtype=%s device=%s",
            _elapsed_ms,
            tuple(generated_ids.shape) if hasattr(generated_ids, "shape") else "<na>",
            str(getattr(generated_ids, "dtype", "<na>")),
            str(getattr(generated_ids, "device", "<na>")),
        )
    except Exception:
        logger.debug("Could not summarize generated_ids")
    trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    try:
        lengths_in = [row.size(0) for row in inputs.input_ids]
        lengths_out = [row.size(0) for row in generated_ids]
        logger.debug("Token lengths: input=%s, output=%s", lengths_in, lengths_out)
    except Exception:
        logger.debug("Could not compute token length summaries")
    output_texts = processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    text = output_texts[0] if output_texts else ""
    logger.debug(
        "Decoded %d sequences; first_text_len=%d",
        len(output_texts),
        len(text) if text else 0,
    )
    if text:
        logger.debug("Output preview: %r", text[:500])
    return {
        "id": "chatcmpl-uground72b",
        "object": "chat.completion",
        "choices": [{
            "index": 0,
            "message": {"role": "assistant", "content": text},
            "finish_reason": "stop"
        }]
    }
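
# Example client call (a minimal sketch; assumes the app is served locally on port 7860 and that
# the base64 image payload below is replaced with real data):
#
#   import requests
#   payload = {
#       "model": "osunlp/UGround-V1-72B",
#       "max_tokens": 64,
#       "messages": [{"role": "user", "content": [
#           {"type": "text", "text": "Where is the search box?"},
#           {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
#       ]}],
#   }
#   print(requests.post("http://localhost:7860/v1/chat/completions", json=payload).json())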