echarlaix (HF Staff) committed
Commit 3586102 · Parent: ce41e3e

add model choices

Files changed (1): app.py (+29 -6)
app.py CHANGED
@@ -12,12 +12,22 @@ from transformers import AutoModelForImageTextToText, AutoProcessor
 from transformers.generation.streamers import TextIteratorStreamer
 from optimum.intel import OVModelForVisualCausalLM
 
-model_id = "echarlaix/SmolVLM2-2.2B-Instruct-openvino"
-# model_id = "echarlaix/SmolVLM-256M-Instruct-openvino"
-# model_id = "echarlaix/SmolVLM2-500M-Video-Instruct-openvino"
 
-processor = AutoProcessor.from_pretrained(model_id)
-model = OVModelForVisualCausalLM.from_pretrained(model_id)
+default_model_id = "echarlaix/SmolVLM2-2.2B-Instruct-openvino"
+
+model_cache = {
+    "model_id" : default_model_id,
+    "processor" : AutoProcessor.from_pretrained(default_model_id),
+    "model" : OVModelForVisualCausalLM.from_pretrained(default_model_id),
+}
+
+def update_model(model_id):
+    if model_cache["model_id"] != model_id:
+        model_cache["model_id"] = model_id
+        model_cache["processor"] = AutoProcessor.from_pretrained(model_id)
+        model_cache["model"] = OVModelForVisualCausalLM.from_pretrained(model_id)
+
+
 
 IMAGE_FILE_TYPES = (".jpg", ".jpeg", ".png", ".webp")
 VIDEO_FILE_TYPES = (".mp4", ".mov", ".webm")
@@ -152,7 +162,12 @@ def process_history(history: list[dict]) -> list[dict]:
 
 
 @torch.inference_mode()
-def generate(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
+def generate(message: dict, history: list[dict], model_id: str, system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
+
+    update_model(model_id)
+    processor = model_cache["processor"]
+    model = model_cache["model"]
+
     if not validate_media_constraints(message):
         yield ""
         return
@@ -238,6 +253,13 @@ examples = [
     ],
 ]
 
+
+model_choices = [
+    "echarlaix/SmolVLM2-2.2B-Instruct-openvino",
+    "echarlaix/SmolVLM-256M-Instruct-openvino",
+    "echarlaix/SmolVLM2-500M-Video-Instruct-openvino",
+]
+
 demo = gr.ChatInterface(
     fn=generate,
     type="messages",
@@ -248,6 +270,7 @@ demo = gr.ChatInterface(
     ),
     multimodal=True,
     additional_inputs=[
+        gr.Dropdown(model_choices, value=model_choices[0], label="Model ID"),
        gr.Textbox(label="System Prompt", value="You are a helpful assistant."),
        gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
    ],
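
As context for the wiring above: gr.ChatInterface passes the current values of additional_inputs to fn in order, after (message, history), so the new Dropdown selection arrives as the third positional argument of generate, which is why the model_id parameter is inserted before system_prompt. The snippet below is a minimal, self-contained sketch of the same reload-on-change cache pattern; load_processor and load_model are hypothetical stand-ins for AutoProcessor.from_pretrained and OVModelForVisualCausalLM.from_pretrained so the example runs without downloading any weights.

# Minimal sketch of the reload-on-change cache pattern used in this commit.
# load_processor / load_model are hypothetical stand-ins for
# AutoProcessor.from_pretrained / OVModelForVisualCausalLM.from_pretrained.

def load_processor(model_id: str) -> str:
    return f"processor for {model_id}"

def load_model(model_id: str) -> str:
    return f"model for {model_id}"

default_model_id = "echarlaix/SmolVLM2-2.2B-Instruct-openvino"

# Single-entry cache: holds whichever model was selected last.
model_cache = {
    "model_id": default_model_id,
    "processor": load_processor(default_model_id),
    "model": load_model(default_model_id),
}

def update_model(model_id: str) -> None:
    # Reload only when the requested id differs from the cached one,
    # so repeated requests with the same dropdown selection cost nothing.
    if model_cache["model_id"] != model_id:
        model_cache["model_id"] = model_id
        model_cache["processor"] = load_processor(model_id)
        model_cache["model"] = load_model(model_id)

update_model("echarlaix/SmolVLM-256M-Instruct-openvino")
print(model_cache["model"])  # model for echarlaix/SmolVLM-256M-Instruct-openvino

Because the cache holds a single entry, selecting a different model in the dropdown reloads both the processor and the model and discards the previously loaded ones.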