import os
import torch
import gradio as gr
from PIL import Image
from typing import List, Dict, Any
from transformers import AutoModel, AutoTokenizer

""" | |
Gradio app to run MiniCPM-V-4_5 int4 on CPU for image+text chat. | |
- Requires: pip install transformers accelerate gradio pillow | |
- Model: openbmb/MiniCPM-V-4_5-int4 (quantized, CPU-friendly) | |
- This script is self-contained and uses a simple multi-turn chat interface. | |
""" | |
MODEL_ID = os.environ.get("MINICPM_MODEL_ID", "openbmb/MiniCPM-V-4_5-int4")

# Global model/tokenizer, loaded once
model = None
tokenizer = None

def load_model():
    global model, tokenizer
    if model is not None and tokenizer is not None:
        return
    # For CPU inference, keep it simple and avoid .cuda() / bfloat16.
    # trust_remote_code is required because MiniCPM implements a custom .chat().
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    model = AutoModel.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        attn_implementation="sdpa",  # SDPA is fine on CPU; avoid flash-attn on CPU
        torch_dtype=torch.float32,   # Safer default for CPU
        device_map="cpu",            # Ensure CPU execution
        quantization_config=None,
    )
    model.eval()
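
# Minimal smoke test outside Gradio (a sketch, not part of the app flow;
# "example.jpg" is a hypothetical local file, and the first load_model() call
# downloads the model weights):
#
#   load_model()
#   img = Image.open("example.jpg").convert("RGB")
#   msgs = [{"role": "user", "content": [img, "Describe this image."]}]
#   print(model.chat(msgs=msgs, tokenizer=tokenizer))
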
def build_messages(history: List[List[str]], image: Image.Image, user_input: str) -> List[Dict[str, Any]]:
    """
    Convert Gradio chat history + current inputs into the message format expected by MiniCPM's .chat().
    history: list of (user_text, assistant_text) pairs from gr.Chatbot (text-only transcript).
    image: PIL.Image or None for the current turn.
    user_input: current user text.
    Returns a msgs list with roles and content arrays [image?, text].
    """
    msgs = []
    # Reconstruct the multi-turn context by interleaving user/assistant turns.
    # We assume every user message and assistant reply in history is text-only;
    # only the current turn may carry an image.
    for user_text, assistant_text in history:
        if user_text is not None:
            msgs.append({"role": "user", "content": [user_text]})
        if assistant_text is not None:
            msgs.append({"role": "assistant", "content": [assistant_text]})
    # Append the current user turn (with optional image)
    content = []
    if image is not None:
        # Ensure RGB
        if image.mode != "RGB":
            image = image.convert("RGB")
        content.append(image)
    if user_input and user_input.strip():
        content.append(user_input.strip())
    else:
        # Ensure the content list is never empty
        content.append("")
    msgs.append({"role": "user", "content": content})
    return msgs
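
# Illustrative shape of the msgs list for a two-turn conversation (the
# <PIL.Image> placeholder stands for an actual PIL.Image.Image object):
#
#   [
#       {"role": "user", "content": ["What is in this photo?"]},
#       {"role": "assistant", "content": ["A cat on a sofa."]},
#       {"role": "user", "content": [<PIL.Image>, "And what breed is it?"]},
#   ]
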
def respond(user_text: str, image: Image.Image, chat_history: List[List[str]], enable_thinking: bool):
    """
    Inference handler for Gradio. Returns the updated chat history and clears the user textbox.
    """
    load_model()
    # Build MiniCPM messages
    msgs = build_messages(chat_history or [], image, user_text)
    # Run model.chat
    with torch.inference_mode():
        answer = model.chat(
            msgs=msgs,
            tokenizer=tokenizer,
            enable_thinking=enable_thinking,
        )
    # Update the history shown in the Chatbot: append (user_text, answer).
    # If user_text is empty but an image was provided, show a placeholder.
    shown_user_msg = user_text.strip() if (user_text and user_text.strip()) else "[Image]"
    chat_history = chat_history + [[shown_user_msg, answer]]
    return chat_history, ""
def clear_history():
    return [], None, ""
def demo_app():
    with gr.Blocks(title="MiniCPM-V-4_5-int4 (CPU) - Gradio", theme="soft") as demo:
        gr.Markdown("## MiniCPM-V-4_5-int4 (CPU) Demo\nUpload an image (optional) and ask a question.")
        with gr.Row():
            with gr.Column(scale=3):
                # Tuple-format history matches build_messages()/respond(), which
                # expect (user_text, assistant_text) pairs rather than message dicts.
                chatbot = gr.Chatbot(height=420, type="tuples", avatar_images=(None, None))
                with gr.Row():
                    img = gr.Image(type="pil", label="Image (optional)", height=240)
                    user_in = gr.Textbox(
                        label="Your message",
                        placeholder="Ask something about the image or chat without an image...",
                        lines=3,
                    )
                with gr.Row():
                    enable_thinking = gr.Checkbox(value=False, label="Enable thinking mode")
                    send_btn = gr.Button("Send", variant="primary")
                    clear_btn = gr.Button("Clear")
            with gr.Column(scale=1):
                gr.Markdown("### Model")
                gr.Markdown(f"- ID: `{MODEL_ID}`\n- Device: CPU\n- Quant: int4")
        # Events
        send_btn.click(
            fn=respond,
            inputs=[user_in, img, chatbot, enable_thinking],
            outputs=[chatbot, user_in],
        )
        user_in.submit(
            fn=respond,
            inputs=[user_in, img, chatbot, enable_thinking],
            outputs=[chatbot, user_in],
        )
        clear_btn.click(
            fn=clear_history,
            inputs=[],
            outputs=[chatbot, img, user_in],
        )
    return demo
if __name__ == "__main__":
    # Make sure we don't accidentally spawn a CUDA context
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
    demo = demo_app()
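    # Optional tweak (a suggestion, not in the original flow): enable Gradio's
    # request queue so long CPU generations don't time out:
    # demo = demo.queue()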
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))