Yaowei222 committed on
Commit
0da2326
·
1 Parent(s): 9bc4f8c

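Fix Markdown step labels (escape the numbered prefixes) and add default model-name fallbacks in the pipeline and model loaders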

Browse files
app/business_logic.py CHANGED
@@ -412,8 +412,8 @@ def change_custmization_mode(custmization_mode, input_mask_mode):
412
  gr.update(value="<s>Select a input mask mode</s>", visible=False),
413
  gr.update(value="<s>Input target image & mask (Iterate clicking or brushing until the target is covered)</s>", visible=False),
414
  gr.update(value="<s>View or modify the target mask</s>", visible=False),
415
- gr.update(value="3. Input text prompt (necessary)"),
416
- gr.update(value="4. Submit and view the output"),
417
  gr.update(visible=False),
418
  gr.update(visible=False),
419
 
@@ -426,11 +426,11 @@ def change_custmization_mode(custmization_mode, input_mask_mode):
426
  gr.update(interactive=True, visible=True),
427
  gr.update(interactive=True, visible=True),
428
  gr.update(interactive=True, visible=True),
429
- gr.update(value="3. Select a input mask mode", visible=True),
430
- gr.update(value="4. Input target image & mask (Iterate clicking or brushing until the target is covered)", visible=True),
431
- gr.update(value="6. View or modify the target mask", visible=True),
432
- gr.update(value="5. Input text prompt (optional)", visible=True),
433
- gr.update(value="7. Submit and view the output", visible=True),
434
  gr.update(visible=True, value="Precise mask"),
435
  gr.update(visible=True),
436
  )
@@ -441,11 +441,11 @@ def change_custmization_mode(custmization_mode, input_mask_mode):
441
  gr.update(interactive=True, visible=True),
442
  gr.update(interactive=True, visible=True),
443
  gr.update(interactive=True, visible=True),
444
- gr.update(value="3. Select a input mask mode", visible=True),
445
- gr.update(value="4. Input target image & mask (Iterate clicking or brushing until the target is covered)", visible=True),
446
- gr.update(value="6. View or modify the target mask", visible=True),
447
- gr.update(value="5. Input text prompt (optional)", visible=True),
448
- gr.update(value="7. Submit and view the output", visible=True),
449
  gr.update(visible=True, value="User-drawn mask"),
450
  gr.update(visible=True),
451
  )
 
412
  gr.update(value="<s>Select a input mask mode</s>", visible=False),
413
  gr.update(value="<s>Input target image & mask (Iterate clicking or brushing until the target is covered)</s>", visible=False),
414
  gr.update(value="<s>View or modify the target mask</s>", visible=False),
415
+ gr.update(value="3\. Input text prompt (necessary)"),
416
+ gr.update(value="4\. Submit and view the output"),
417
  gr.update(visible=False),
418
  gr.update(visible=False),
419
 
 
426
  gr.update(interactive=True, visible=True),
427
  gr.update(interactive=True, visible=True),
428
  gr.update(interactive=True, visible=True),
429
+ gr.update(value="3\. Select a input mask mode", visible=True),
430
+ gr.update(value="4\. Input target image & mask (Iterate clicking or brushing until the target is covered)", visible=True),
431
+ gr.update(value="6\. View or modify the target mask", visible=True),
432
+ gr.update(value="5\. Input text prompt (optional)", visible=True),
433
+ gr.update(value="7\. Submit and view the output", visible=True),
434
  gr.update(visible=True, value="Precise mask"),
435
  gr.update(visible=True),
436
  )
 
441
  gr.update(interactive=True, visible=True),
442
  gr.update(interactive=True, visible=True),
443
  gr.update(interactive=True, visible=True),
444
+ gr.update(value="3\. Select a input mask mode", visible=True),
445
+ gr.update(value="4\. Input target image & mask (Iterate clicking or brushing until the target is covered)", visible=True),
446
+ gr.update(value="6\. View or modify the target mask", visible=True),
447
+ gr.update(value="5\. Input text prompt (optional)", visible=True),
448
+ gr.update(value="7\. Submit and view the output", visible=True),
449
  gr.update(visible=True, value="User-drawn mask"),
450
  gr.update(visible=True),
451
  )
app/ui_components.py CHANGED
@@ -44,7 +44,7 @@ def create_customization_section():
44
  with gr.Row():
45
  # Add a note to remind users to click Clear before starting
46
  md_custmization_mode = gr.Markdown(
47
- "1. Select a Customization Mode\n\n*Tip: Please click the Clear button first to reset all states before starting a new task.*"
48
  )
49
  with gr.Row():
50
  custmization_mode = gr.Radio(
@@ -61,7 +61,7 @@ def create_customization_section():
61
  def create_image_input_section():
62
  """Create image input section optimized for left column layout."""
63
  # Reference image section
64
- md_image_reference = gr.Markdown("2. Input reference image")
65
  with gr.Group():
66
  image_reference = gr.Image(
67
  label="Reference Image",
@@ -73,7 +73,7 @@ def create_image_input_section():
73
  )
74
 
75
  # Input mask mode selection
76
- md_input_mask_mode = gr.Markdown("3. Select input mask mode")
77
  with gr.Group():
78
  input_mask_mode = gr.Radio(
79
  ["Precise mask", "User-drawn mask"],
@@ -84,7 +84,7 @@ def create_image_input_section():
84
  )
85
 
86
  # Target image section
87
- md_target_image = gr.Markdown("4. Input target image & mask (Iterate clicking or brushing until the target is covered)")
88
 
89
  # Precise mask mode
90
  with gr.Group():
@@ -129,7 +129,7 @@ def create_image_input_section():
129
 
130
  def create_prompt_section():
131
  """Create the text prompt input section with improved layout."""
132
- md_prompt = gr.Markdown("5. Input text prompt (optional)")
133
  with gr.Group():
134
  prompt = gr.Textbox(
135
  placeholder="Please input the description for the target scene.",
@@ -243,7 +243,7 @@ def create_advanced_options_section():
243
 
244
  def create_mask_operation_section():
245
  """Create mask operation section optimized for right column (outputs)."""
246
- md_mask_operation = gr.Markdown("6. View or modify the target mask")
247
 
248
  with gr.Group():
249
  # Mask gallery with responsive layout
@@ -293,7 +293,7 @@ def create_mask_operation_section():
293
 
294
  def create_output_section():
295
  """Create the output section optimized for right column."""
296
- md_submit = gr.Markdown("7. Submit and view the output")
297
 
298
  # Generation controls at top for better workflow
299
  with gr.Group():
 
44
  with gr.Row():
45
  # Add a note to remind users to click Clear before starting
46
  md_custmization_mode = gr.Markdown(
47
+ "1\. Select a Customization Mode\n\n*Tip: Please click the Clear button first to reset all states before starting a new task.*"
48
  )
49
  with gr.Row():
50
  custmization_mode = gr.Radio(
 
61
  def create_image_input_section():
62
  """Create image input section optimized for left column layout."""
63
  # Reference image section
64
+ md_image_reference = gr.Markdown("2\. Input reference image")
65
  with gr.Group():
66
  image_reference = gr.Image(
67
  label="Reference Image",
 
73
  )
74
 
75
  # Input mask mode selection
76
+ md_input_mask_mode = gr.Markdown("3\. Select input mask mode")
77
  with gr.Group():
78
  input_mask_mode = gr.Radio(
79
  ["Precise mask", "User-drawn mask"],
 
84
  )
85
 
86
  # Target image section
87
+ md_target_image = gr.Markdown("4\. Input target image & mask (Iterate clicking or brushing until the target is covered)")
88
 
89
  # Precise mask mode
90
  with gr.Group():
 
129
 
130
  def create_prompt_section():
131
  """Create the text prompt input section with improved layout."""
132
+ md_prompt = gr.Markdown("5\. Input text prompt (optional)")
133
  with gr.Group():
134
  prompt = gr.Textbox(
135
  placeholder="Please input the description for the target scene.",
 
243
 
244
  def create_mask_operation_section():
245
  """Create mask operation section optimized for right column (outputs)."""
246
+ md_mask_operation = gr.Markdown("6\. View or modify the target mask")
247
 
248
  with gr.Group():
249
  # Mask gallery with responsive layout
 
293
 
294
  def create_output_section():
295
  """Create the output section optimized for right column."""
296
+ md_submit = gr.Markdown("7\. Submit and view the output")
297
 
298
  # Generation controls at top for better workflow
299
  with gr.Group():
ic_custom/pipelines/ic_custom_pipeline.py CHANGED
@@ -1,4 +1,4 @@
1
-
2
  import re
3
  from typing import List, Optional, Union
4
 
@@ -128,6 +128,10 @@ class ICCustomPipeline:
128
  double_blocks_idx: str = None,
129
  single_blocks_idx: str = None,
130
  ):
 
 
 
 
131
  lora_path = resolve_model_path(
132
  name=lora_path,
133
  repo_id_field="repo_id",
@@ -181,6 +185,9 @@ class ICCustomPipeline:
181
  self.load_model_weights(weights, strict=False)
182
 
183
  def set_img_txt_in(self, img_txt_in_path: str):
 
 
 
184
  img_txt_in_path = resolve_model_path(
185
  name=img_txt_in_path,
186
  repo_id_field="repo_id",
@@ -192,6 +199,9 @@ class ICCustomPipeline:
192
  self.load_model_weights(weights, strict=False)
193
 
194
  def set_boundary_embeddings(self, boundary_embeddings_path: str):
 
 
 
195
  boundary_embeddings_path = resolve_model_path(
196
  name=boundary_embeddings_path,
197
  repo_id_field="repo_id",
@@ -203,6 +213,9 @@ class ICCustomPipeline:
203
  self.load_model_weights(weights, strict=False)
204
 
205
  def set_task_register_embeddings(self, task_register_embeddings_path: str):
 
 
 
206
  task_register_embeddings_path = resolve_model_path(
207
  name=task_register_embeddings_path,
208
  repo_id_field="repo_id",
 
1
+ import os
2
  import re
3
  from typing import List, Optional, Union
4
 
 
128
  double_blocks_idx: str = None,
129
  single_blocks_idx: str = None,
130
  ):
131
+ if not os.path.exists(lora_path):
132
+ lora_path = "dit_lora_0x1561"
133
+
134
+
135
  lora_path = resolve_model_path(
136
  name=lora_path,
137
  repo_id_field="repo_id",
 
185
  self.load_model_weights(weights, strict=False)
186
 
187
  def set_img_txt_in(self, img_txt_in_path: str):
188
+ if not os.path.exists(img_txt_in_path):
189
+ img_txt_in_path = "dit_txt_img_in_0x1561"
190
+
191
  img_txt_in_path = resolve_model_path(
192
  name=img_txt_in_path,
193
  repo_id_field="repo_id",
 
199
  self.load_model_weights(weights, strict=False)
200
 
201
  def set_boundary_embeddings(self, boundary_embeddings_path: str):
202
+ if not os.path.exists(boundary_embeddings_path):
203
+ boundary_embeddings_path = "dit_boundary_embeddings_0x1561"
204
+
205
  boundary_embeddings_path = resolve_model_path(
206
  name=boundary_embeddings_path,
207
  repo_id_field="repo_id",
 
213
  self.load_model_weights(weights, strict=False)
214
 
215
  def set_task_register_embeddings(self, task_register_embeddings_path: str):
216
+ if not os.path.exists(task_register_embeddings_path):
217
+ task_register_embeddings_path = "dit_task_register_embeddings_0x1561"
218
+
219
  task_register_embeddings_path = resolve_model_path(
220
  name=task_register_embeddings_path,
221
  repo_id_field="repo_id",
ic_custom/utils/model_utils.py CHANGED
@@ -206,6 +206,9 @@ def load_dit(
206
  model: Loaded Flux model
207
  """
208
  # Loading Flux
 
 
 
209
  logger.info("Initializing Flux model")
210
 
211
  # Resolve checkpoint path
@@ -249,9 +252,11 @@ def load_ic_custom(
249
  model: Loaded IC_Custom model
250
  """
251
  logger.info("Initializing IC-Custom model")
252
-
253
  # Resolve checkpoint path
254
-
 
 
255
  ckpt_path = resolve_model_path(
256
  name=name,
257
  repo_id_field="repo_id",
@@ -312,8 +317,7 @@ def load_embedder(
312
  path,
313
  max_length=max_length,
314
  is_clip=is_clip,
315
- torch_dtype=dtype,
316
- ).to(device)
317
 
318
  return model
319
 
@@ -336,7 +340,11 @@ def load_t5(
336
  Returns:
337
  model: Loaded T5 model
338
  """
 
 
 
339
  logger.info(f"Loading T5 model: {name}")
 
340
  return load_embedder(
341
  name=name,
342
  is_clip=False,
@@ -362,7 +370,11 @@ def load_clip(
362
  Returns:
363
  model: Loaded CLIP model
364
  """
 
 
 
365
  logger.info(f"Loading CLIP model: {name}")
 
366
  return load_embedder(
367
  name=name,
368
  is_clip=True,
@@ -387,6 +399,10 @@ def load_ae(
387
  Returns:
388
  model: Loaded AutoEncoder model
389
  """
 
 
 
 
390
  logger.info(f"Loading AutoEncoder model: {name}")
391
 
392
  # Convert device string to torch.device if needed
@@ -429,6 +445,12 @@ def load_redux(
429
  Returns:
430
  model: Loaded Redux Image Encoder model
431
  """
 
 
 
 
 
 
432
  logger.info(f"Loading Redux Image Encoder: redux={redux_name}, siglip={siglip_name}")
433
 
434
  # Convert device string to torch.device if needed
 
206
  model: Loaded Flux model
207
  """
208
  # Loading Flux
209
+ if not os.path.exists(name):
210
+ name = "flux-fill-dev-dit"
211
+
212
  logger.info("Initializing Flux model")
213
 
214
  # Resolve checkpoint path
 
252
  model: Loaded IC_Custom model
253
  """
254
  logger.info("Initializing IC-Custom model")
255
+
256
  # Resolve checkpoint path
257
+ if not os.path.exists(name):
258
+ name = "flux-fill-dev-dit"
259
+
260
  ckpt_path = resolve_model_path(
261
  name=name,
262
  repo_id_field="repo_id",
 
317
  path,
318
  max_length=max_length,
319
  is_clip=is_clip,
320
+ ).to(device).to(dtype)
 
321
 
322
  return model
323
 
 
340
  Returns:
341
  model: Loaded T5 model
342
  """
343
+ if not os.path.exists(name):
344
+ name = "t5-v1_1-xxl"
345
+
346
  logger.info(f"Loading T5 model: {name}")
347
+
348
  return load_embedder(
349
  name=name,
350
  is_clip=False,
 
370
  Returns:
371
  model: Loaded CLIP model
372
  """
373
+ if not os.path.exists(name):
374
+ name = "clip-vit-large-patch14"
375
+
376
  logger.info(f"Loading CLIP model: {name}")
377
+
378
  return load_embedder(
379
  name=name,
380
  is_clip=True,
 
399
  Returns:
400
  model: Loaded AutoEncoder model
401
  """
402
+
403
+ if not os.path.exists(name):
404
+ name = "flux-fill-dev-ae"
405
+
406
  logger.info(f"Loading AutoEncoder model: {name}")
407
 
408
  # Convert device string to torch.device if needed
 
445
  Returns:
446
  model: Loaded Redux Image Encoder model
447
  """
448
+
449
+ if not os.path.exists(redux_name):
450
+ redux_name = "flux1-redux-dev"
451
+ if not os.path.exists(siglip_name):
452
+ siglip_name = "siglip-so400m-patch14-384"
453
+
454
  logger.info(f"Loading Redux Image Encoder: redux={redux_name}, siglip={siglip_name}")
455
 
456
  # Convert device string to torch.device if needed