rahul7star committed
Commit c9a7977 · verified · 1 Parent(s): 66ac461

Update app.py

Files changed (1)
  1. app.py +119 -14
app.py CHANGED
@@ -2,13 +2,13 @@ import gradio as gr
 import torch
 from PIL import Image
 from transformers import AutoTokenizer, AutoModelForCausalLM
-#import spaces
+import spaces
 
 # Model configuration
 MID = "apple/FastVLM-0.5B"
 IMAGE_TOKEN_INDEX = -200
 
-# Load model and tokenizer
+# Load model and tokenizer (will be loaded on first GPU allocation)
 tok = None
 model = None
 
@@ -37,41 +37,146 @@ def load_model():
 
 #@spaces.GPU(duration=60)
 def caption_image(image, custom_prompt=None):
+    """
+    Generate a caption for the input image.
+
+    Args:
+        image: PIL Image from Gradio
+        custom_prompt: Optional custom prompt to use instead of default
+
+    Returns:
+        Generated caption text
+    """
     if image is None:
         return "Please upload an image first."
 
     try:
+        # Load model if not already loaded
         tok, model = load_model()
+        # Convert image to RGB if needed
         if image.mode != "RGB":
             image = image.convert("RGB")
 
+        # Use custom prompt or default
         prompt = custom_prompt if custom_prompt else "Describe this image in detail."
-        messages = [{"role": "user", "content": f"<image>\n{prompt}"}]
-        rendered = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+
+        # Build chat message
+        messages = [
+            {"role": "user", "content": f"<image>\n{prompt}"}
+        ]
+
+        # Render to string to place <image> token correctly
+        rendered = tok.apply_chat_template(
+            messages, add_generation_prompt=True, tokenize=False
+        )
 
+        # Split at image token
         pre, post = rendered.split("<image>", 1)
+
+        # Tokenize text around the image token
         pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
         post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids
 
+        # Insert IMAGE token id at placeholder position
         img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
         input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(model.device)
         attention_mask = torch.ones_like(input_ids, device=model.device)
 
-        px = model.get_vision_tower().image_processor(images=image, return_tensors="pt")["pixel_values"]
+        # Preprocess image using model's vision tower
+        px = model.get_vision_tower().image_processor(
+            images=image, return_tensors="pt"
+        )["pixel_values"]
         px = px.to(model.device, dtype=model.dtype)
 
+        # Generate caption
         with torch.no_grad():
             out = model.generate(
                 inputs=input_ids,
                 attention_mask=attention_mask,
                 images=px,
                 max_new_tokens=128,
-                do_sample=False,
+                do_sample=False,  # Deterministic generation
                 temperature=1.0,
             )
 
+        # Decode and return the generated text
         generated_text = tok.decode(out[0], skip_special_tokens=True)
-        return generated_text.split("assistant")[-1].strip() if "assistant" in generated_text else generated_text
+
+        # Extract only the assistant's response
+        if "assistant" in generated_text:
+            response = generated_text.split("assistant")[-1].strip()
+        else:
+            response = generated_text
+
+        return response
 
     except Exception as e:
         return f"Error generating caption: {str(e)}"
+
+# Create Gradio interface
+with gr.Blocks(title="FastVLM Image Captioning") as demo:
+    gr.Markdown(
+        """
+        # 🖼️ FastVLM Image Captioning
+
+        Upload an image to generate a detailed caption using Apple's FastVLM-0.5B model.
+        You can use the default prompt or provide your own custom prompt.
+        """
+    )
+
+    with gr.Row():
+        with gr.Column():
+            image_input = gr.Image(
+                type="pil",
+                label="Upload Image",
+                elem_id="image-upload"
+            )
+
+            custom_prompt = gr.Textbox(
+                label="Custom Prompt (Optional)",
+                placeholder="Leave empty for default: 'Describe this image in detail.'",
+                lines=2
+            )
+
+            with gr.Row():
+                clear_btn = gr.ClearButton([image_input, custom_prompt])
+                generate_btn = gr.Button("Generate Caption", variant="primary")
+
+        with gr.Column():
+            output = gr.Textbox(
+                label="Generated Caption",
+                lines=8,
+                max_lines=15,
+                show_copy_button=True
+            )
+
+    # Event handlers
+    generate_btn.click(
+        fn=caption_image,
+        inputs=[image_input, custom_prompt],
+        outputs=output
+    )
+
+    # Also generate on image upload if no custom prompt
+    image_input.change(
+        fn=lambda img, prompt: caption_image(img, prompt) if img is not None and not prompt else None,
+        inputs=[image_input, custom_prompt],
+        outputs=output
+    )
+
+    gr.Markdown(
+        """
+        ---
+        **Model:** [apple/FastVLM-0.5B](https://huggingface.co/apple/FastVLM-0.5B)
+
+        **Note:** This Space uses ZeroGPU for dynamic GPU allocation.
+        """
+    )
+
+if __name__ == "__main__":
+    demo.launch(
+        share=False,
+        show_error=True,
+        server_name="0.0.0.0",
+        server_port=7860
+    )