File size: 5,429 Bytes
f8d8011
 
 
 
b7bf121
f8d8011
 
 
 
 
b7bf121
 
 
f8d8011
b7bf121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8d8011
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7bf121
 
f8d8011
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7bf121
f8d8011
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import gradio as gr
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM
import spaces

# Checkpoint identifier handed to from_pretrained(), and the placeholder id
# that caption_image() splices into input_ids where "<image>" appeared.
MID = "apple/FastVLM-0.5B"
IMAGE_TOKEN_INDEX = -200

# Lazily populated by load_model(); both stay None until the first call.
tok = model = None

def load_model():
    """Return the shared ``(tokenizer, model)`` pair, loading it on first use.

    The pair is cached in the module-level globals ``tok`` and ``model``.
    When both are already set they are returned as-is; otherwise both are
    (re)loaded from the checkpoint named by ``MID``.

    Returns:
        Tuple ``(tok, model)``.
    """
    global tok, model

    # Fast path: everything is already cached.
    if tok is not None and model is not None:
        return tok, model

    print("Loading model...")
    tok = AutoTokenizer.from_pretrained(MID, trust_remote_code=True)
    model_options = {
        "torch_dtype": torch.float16,
        "device_map": "cuda",
        "trust_remote_code": True,
    }
    model = AutoModelForCausalLM.from_pretrained(MID, **model_options)
    print("Model loaded successfully!")
    return tok, model

@spaces.GPU(duration=60)
def caption_image(image, custom_prompt=None):
    """
    Generate a caption for the input image.

    Args:
        image: PIL Image from Gradio, or None when nothing is uploaded.
        custom_prompt: Optional custom prompt to use instead of the default.
            Empty or whitespace-only values fall back to the default prompt.

    Returns:
        The generated caption text, a hint asking for an image when
        ``image`` is None, or an "Error generating caption: ..." message if
        anything raises while loading the model or generating.
    """
    if image is None:
        return "Please upload an image first."

    try:
        # Load model if not already loaded
        tok, model = load_model()

        # Convert image to RGB if needed
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Use custom prompt or default; a blank/whitespace-only prompt would
        # otherwise reach the model as an empty instruction.
        prompt = (custom_prompt or "").strip() or "Describe this image in detail."

        # Build chat message
        messages = [
            {"role": "user", "content": f"<image>\n{prompt}"}
        ]

        # Render to string to place <image> token correctly
        rendered = tok.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False
        )

        # Split at the first image placeholder (the one inserted above)
        pre, post = rendered.split("<image>", 1)

        # Tokenize text around the image token
        pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
        post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids

        # Insert IMAGE token id at placeholder position
        img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
        input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(model.device)
        attention_mask = torch.ones_like(input_ids, device=model.device)
        prompt_len = input_ids.shape[1]

        # Preprocess image using model's vision tower
        px = model.get_vision_tower().image_processor(
            images=image, return_tensors="pt"
        )["pixel_values"]
        px = px.to(model.device, dtype=model.dtype)

        # Generate caption
        with torch.no_grad():
            out = model.generate(
                inputs=input_ids,
                attention_mask=attention_mask,
                images=px,
                max_new_tokens=128,
                do_sample=False,  # Deterministic generation
                temperature=1.0,
            )

        # Keep only the newly generated tokens. The prompt is removed at the
        # token level, and only when the output really starts with it;
        # splitting the decoded text on the word "assistant" would truncate
        # any caption that happens to contain that word.
        sequence = out[0]
        if sequence.shape[0] >= prompt_len and torch.equal(
            sequence[:prompt_len], input_ids[0]
        ):
            sequence = sequence[prompt_len:]

        return tok.decode(sequence, skip_special_tokens=True).strip()

    except Exception as e:
        # UI boundary: report the failure in the output box instead of raising.
        return f"Error generating caption: {str(e)}"

# Build the Gradio UI
def _caption_on_upload(img, prompt):
    """Caption a newly set image only when no custom prompt has been typed.

    Returns None when the image is missing or a custom prompt is present;
    otherwise returns the result of ``caption_image(img, prompt)``.
    """
    if img is None or prompt:
        return None
    return caption_image(img, prompt)


with gr.Blocks(title="FastVLM Image Captioning") as demo:
    # Page header
    gr.Markdown(
        """
        # 🖼️ FastVLM Image Captioning
        
        Upload an image to generate a detailed caption using Apple's FastVLM-0.5B model.
        You can use the default prompt or provide your own custom prompt.
        """
    )

    with gr.Row():
        # First column: the image, the optional prompt and the action buttons
        with gr.Column():
            image_input = gr.Image(
                label="Upload Image",
                type="pil",
                elem_id="image-upload",
            )

            custom_prompt = gr.Textbox(
                lines=2,
                label="Custom Prompt (Optional)",
                placeholder="Leave empty for default: 'Describe this image in detail.'",
            )

            with gr.Row():
                clear_btn = gr.ClearButton(components=[image_input, custom_prompt])
                generate_btn = gr.Button(value="Generate Caption", variant="primary")

        # Second column: the generated caption
        with gr.Column():
            output = gr.Textbox(
                label="Generated Caption",
                show_copy_button=True,
                lines=8,
                max_lines=15,
            )

    # Explicit generation via the button
    generate_btn.click(
        caption_image,
        inputs=[image_input, custom_prompt],
        outputs=output,
    )

    # Automatic generation whenever the image changes and the prompt is empty
    image_input.change(
        _caption_on_upload,
        inputs=[image_input, custom_prompt],
        outputs=output,
    )

    # Page footer
    gr.Markdown(
        """
        ---
        **Model:** [apple/FastVLM-0.5B](https://huggingface.co/apple/FastVLM-0.5B)
        
        **Note:** This Space uses ZeroGPU for dynamic GPU allocation.
        """
    )

if __name__ == "__main__":
    # Launch only when run as a script: listen on 0.0.0.0:7860 with
    # sharing disabled and show_error enabled.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    demo.launch(**launch_options)