# app.py
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from deepseek_vl.models import VLChatProcessor, MultiModalityCausalLM
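# Note: the deepseek_vl package ships with the reference repo at
# https://github.com/deepseek-ai/DeepSeek-VL (install from source with pip).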
# ==== Model setup ====
model_path = "deepseek-ai/deepseek-vl-7b-chat"
# BitsAndBytes 4-bit quantization config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)
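# 4-bit weights with double quantization shrink the ~14 GB of fp16 weights
# to roughly 4-5 GB of VRAM (approximate figures; exact usage varies).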
# Load the processor and tokenizer
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
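# The processor wraps both the tokenizer and the image transform, and handles
# expanding <image_placeholder> tokens in the prompt into image embeddings.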
# Load the model
vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
).eval()
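# .eval() disables dropout for inference; device_map="auto" lets Accelerate
# place the quantized layers on whatever GPU(s) are available.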
# ==== Conversation history ====
chat_history = []          # (user_message, answer) tuples shown in the Gradio Chatbot
conversation_history = []  # role/content dicts in the DeepSeek-VL conversation format
# ==== Text + image inference function ====
def chat_with_image(image, user_message):
    global chat_history, conversation_history
    try:
        # Build the conversation: previous turns plus the new user turn.
        # <image_placeholder> must appear once for each attached image.
        conversation = conversation_history.copy()
        conversation.append({
            "role": "User",
            "content": ("<image_placeholder>" + user_message) if image else user_message,
            "images": [image] if image else []
        })
        conversation.append({"role": "Assistant", "content": ""})

        # Collect every image referenced anywhere in the conversation so the
        # placeholder tokens and pixel inputs stay aligned across turns.
        pil_images = [img for turn in conversation for img in turn.get("images", [])]

        # Prepare the model inputs
        prepare_inputs = vl_chat_processor(
            conversations=conversation,
            images=pil_images,
            force_batchify=True
        ).to(vl_gpt.device)

        # Convert the dataclass output to a dict; cast only floating-point
        # tensors (e.g. pixel_values) to float16 to match the quantized
        # compute dtype -- integer ids and boolean masks keep their dtypes.
        prepare_inputs = {k: getattr(prepare_inputs, k) for k in prepare_inputs.__dataclass_fields__.keys()}
        new_inputs = {}
        for k, v in prepare_inputs.items():
            if torch.is_tensor(v) and v.is_floating_point():
                new_inputs[k] = v.to(torch.float16)
            else:
                new_inputs[k] = v
        prepare_inputs = new_inputs

        # Fuse text and image features into a single embedding sequence
        inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

        # Generate the answer
        outputs = vl_gpt.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=prepare_inputs["attention_mask"],
            pad_token_id=tokenizer.eos_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            max_new_tokens=128,
            do_sample=False,
            use_cache=True
        )

        # Decode the generated tokens
        answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)

        # Update both histories: dicts for the model, tuples for the UI
        conversation_history = conversation[:-1]
        conversation_history.append({"role": "Assistant", "content": answer})
        chat_history.append((user_message, answer))
        return answer, chat_history
    except Exception as e:
        return f"Error: {str(e)}", chat_history
def reset_chat():
    global chat_history, conversation_history
    chat_history = []
    conversation_history = []
    return "", []
# ==== Gradio Web UI ====
with gr.Blocks() as demo:
    gr.Markdown("# DeepSeek-VL-7B-Chat Demo (4-bit, float16)")
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Image")
        text_input = gr.Textbox(lines=2, placeholder="Ask about the image...")
    with gr.Row():
        submit_btn = gr.Button("Submit")
        reset_btn = gr.Button("Reset Chat")
    output_text = gr.Textbox(label="Answer")
    chat_display = gr.Chatbot(label="Chat History")

    submit_btn.click(chat_with_image, inputs=[image_input, text_input], outputs=[output_text, chat_display])
    reset_btn.click(reset_chat, inputs=[], outputs=[output_text, chat_display])
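# Locally, `python app.py` serves the UI on http://127.0.0.1:7860 by default;
# a Hugging Face Space's Gradio SDK picks up the `demo` object itself.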
if __name__ == "__main__":
    demo.launch()