Spaces:
Running
on
Zero
Running
on
Zero
fix md and pipeline
Browse files- app/business_logic.py +12 -12
- app/ui_components.py +7 -7
- ic_custom/pipelines/ic_custom_pipeline.py +14 -1
- ic_custom/utils/model_utils.py +26 -4
app/business_logic.py
CHANGED
@@ -412,8 +412,8 @@ def change_custmization_mode(custmization_mode, input_mask_mode):
|
|
412 |
gr.update(value="<s>Select a input mask mode</s>", visible=False),
|
413 |
gr.update(value="<s>Input target image & mask (Iterate clicking or brushing until the target is covered)</s>", visible=False),
|
414 |
gr.update(value="<s>View or modify the target mask</s>", visible=False),
|
415 |
-
gr.update(value="3
|
416 |
-
gr.update(value="4
|
417 |
gr.update(visible=False),
|
418 |
gr.update(visible=False),
|
419 |
|
@@ -426,11 +426,11 @@ def change_custmization_mode(custmization_mode, input_mask_mode):
|
|
426 |
gr.update(interactive=True, visible=True),
|
427 |
gr.update(interactive=True, visible=True),
|
428 |
gr.update(interactive=True, visible=True),
|
429 |
-
gr.update(value="3
|
430 |
-
gr.update(value="4
|
431 |
-
gr.update(value="6
|
432 |
-
gr.update(value="5
|
433 |
-
gr.update(value="7
|
434 |
gr.update(visible=True, value="Precise mask"),
|
435 |
gr.update(visible=True),
|
436 |
)
|
@@ -441,11 +441,11 @@ def change_custmization_mode(custmization_mode, input_mask_mode):
|
|
441 |
gr.update(interactive=True, visible=True),
|
442 |
gr.update(interactive=True, visible=True),
|
443 |
gr.update(interactive=True, visible=True),
|
444 |
-
gr.update(value="3
|
445 |
-
gr.update(value="4
|
446 |
-
gr.update(value="6
|
447 |
-
gr.update(value="5
|
448 |
-
gr.update(value="7
|
449 |
gr.update(visible=True, value="User-drawn mask"),
|
450 |
gr.update(visible=True),
|
451 |
)
|
|
|
412 |
gr.update(value="<s>Select a input mask mode</s>", visible=False),
|
413 |
gr.update(value="<s>Input target image & mask (Iterate clicking or brushing until the target is covered)</s>", visible=False),
|
414 |
gr.update(value="<s>View or modify the target mask</s>", visible=False),
|
415 |
+
gr.update(value="3\. Input text prompt (necessary)"),
|
416 |
+
gr.update(value="4\. Submit and view the output"),
|
417 |
gr.update(visible=False),
|
418 |
gr.update(visible=False),
|
419 |
|
|
|
426 |
gr.update(interactive=True, visible=True),
|
427 |
gr.update(interactive=True, visible=True),
|
428 |
gr.update(interactive=True, visible=True),
|
429 |
+
gr.update(value="3\. Select a input mask mode", visible=True),
|
430 |
+
gr.update(value="4\. Input target image & mask (Iterate clicking or brushing until the target is covered)", visible=True),
|
431 |
+
gr.update(value="6\. View or modify the target mask", visible=True),
|
432 |
+
gr.update(value="5\. Input text prompt (optional)", visible=True),
|
433 |
+
gr.update(value="7\. Submit and view the output", visible=True),
|
434 |
gr.update(visible=True, value="Precise mask"),
|
435 |
gr.update(visible=True),
|
436 |
)
|
|
|
441 |
gr.update(interactive=True, visible=True),
|
442 |
gr.update(interactive=True, visible=True),
|
443 |
gr.update(interactive=True, visible=True),
|
444 |
+
gr.update(value="3\. Select a input mask mode", visible=True),
|
445 |
+
gr.update(value="4\. Input target image & mask (Iterate clicking or brushing until the target is covered)", visible=True),
|
446 |
+
gr.update(value="6\. View or modify the target mask", visible=True),
|
447 |
+
gr.update(value="5\. Input text prompt (optional)", visible=True),
|
448 |
+
gr.update(value="7\. Submit and view the output", visible=True),
|
449 |
gr.update(visible=True, value="User-drawn mask"),
|
450 |
gr.update(visible=True),
|
451 |
)
|
app/ui_components.py
CHANGED
@@ -44,7 +44,7 @@ def create_customization_section():
|
|
44 |
with gr.Row():
|
45 |
# Add a note to remind users to click Clear before starting
|
46 |
md_custmization_mode = gr.Markdown(
|
47 |
-
"1
|
48 |
)
|
49 |
with gr.Row():
|
50 |
custmization_mode = gr.Radio(
|
@@ -61,7 +61,7 @@ def create_customization_section():
|
|
61 |
def create_image_input_section():
|
62 |
"""Create image input section optimized for left column layout."""
|
63 |
# Reference image section
|
64 |
-
md_image_reference = gr.Markdown("2
|
65 |
with gr.Group():
|
66 |
image_reference = gr.Image(
|
67 |
label="Reference Image",
|
@@ -73,7 +73,7 @@ def create_image_input_section():
|
|
73 |
)
|
74 |
|
75 |
# Input mask mode selection
|
76 |
-
md_input_mask_mode = gr.Markdown("3
|
77 |
with gr.Group():
|
78 |
input_mask_mode = gr.Radio(
|
79 |
["Precise mask", "User-drawn mask"],
|
@@ -84,7 +84,7 @@ def create_image_input_section():
|
|
84 |
)
|
85 |
|
86 |
# Target image section
|
87 |
-
md_target_image = gr.Markdown("4
|
88 |
|
89 |
# Precise mask mode
|
90 |
with gr.Group():
|
@@ -129,7 +129,7 @@ def create_image_input_section():
|
|
129 |
|
130 |
def create_prompt_section():
|
131 |
"""Create the text prompt input section with improved layout."""
|
132 |
-
md_prompt = gr.Markdown("5
|
133 |
with gr.Group():
|
134 |
prompt = gr.Textbox(
|
135 |
placeholder="Please input the description for the target scene.",
|
@@ -243,7 +243,7 @@ def create_advanced_options_section():
|
|
243 |
|
244 |
def create_mask_operation_section():
|
245 |
"""Create mask operation section optimized for right column (outputs)."""
|
246 |
-
md_mask_operation = gr.Markdown("6
|
247 |
|
248 |
with gr.Group():
|
249 |
# Mask gallery with responsive layout
|
@@ -293,7 +293,7 @@ def create_mask_operation_section():
|
|
293 |
|
294 |
def create_output_section():
|
295 |
"""Create the output section optimized for right column."""
|
296 |
-
md_submit = gr.Markdown("7
|
297 |
|
298 |
# Generation controls at top for better workflow
|
299 |
with gr.Group():
|
|
|
44 |
with gr.Row():
|
45 |
# Add a note to remind users to click Clear before starting
|
46 |
md_custmization_mode = gr.Markdown(
|
47 |
+
"1\. Select a Customization Mode\n\n*Tip: Please click the Clear button first to reset all states before starting a new task.*"
|
48 |
)
|
49 |
with gr.Row():
|
50 |
custmization_mode = gr.Radio(
|
|
|
61 |
def create_image_input_section():
|
62 |
"""Create image input section optimized for left column layout."""
|
63 |
# Reference image section
|
64 |
+
md_image_reference = gr.Markdown("2\. Input reference image")
|
65 |
with gr.Group():
|
66 |
image_reference = gr.Image(
|
67 |
label="Reference Image",
|
|
|
73 |
)
|
74 |
|
75 |
# Input mask mode selection
|
76 |
+
md_input_mask_mode = gr.Markdown("3\. Select input mask mode")
|
77 |
with gr.Group():
|
78 |
input_mask_mode = gr.Radio(
|
79 |
["Precise mask", "User-drawn mask"],
|
|
|
84 |
)
|
85 |
|
86 |
# Target image section
|
87 |
+
md_target_image = gr.Markdown("4\. Input target image & mask (Iterate clicking or brushing until the target is covered)")
|
88 |
|
89 |
# Precise mask mode
|
90 |
with gr.Group():
|
|
|
129 |
|
130 |
def create_prompt_section():
|
131 |
"""Create the text prompt input section with improved layout."""
|
132 |
+
md_prompt = gr.Markdown("5\. Input text prompt (optional)")
|
133 |
with gr.Group():
|
134 |
prompt = gr.Textbox(
|
135 |
placeholder="Please input the description for the target scene.",
|
|
|
243 |
|
244 |
def create_mask_operation_section():
|
245 |
"""Create mask operation section optimized for right column (outputs)."""
|
246 |
+
md_mask_operation = gr.Markdown("6\. View or modify the target mask")
|
247 |
|
248 |
with gr.Group():
|
249 |
# Mask gallery with responsive layout
|
|
|
293 |
|
294 |
def create_output_section():
|
295 |
"""Create the output section optimized for right column."""
|
296 |
+
md_submit = gr.Markdown("7\. Submit and view the output")
|
297 |
|
298 |
# Generation controls at top for better workflow
|
299 |
with gr.Group():
|
ic_custom/pipelines/ic_custom_pipeline.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
|
2 |
import re
|
3 |
from typing import List, Optional, Union
|
4 |
|
@@ -128,6 +128,10 @@ class ICCustomPipeline:
|
|
128 |
double_blocks_idx: str = None,
|
129 |
single_blocks_idx: str = None,
|
130 |
):
|
|
|
|
|
|
|
|
|
131 |
lora_path = resolve_model_path(
|
132 |
name=lora_path,
|
133 |
repo_id_field="repo_id",
|
@@ -181,6 +185,9 @@ class ICCustomPipeline:
|
|
181 |
self.load_model_weights(weights, strict=False)
|
182 |
|
183 |
def set_img_txt_in(self, img_txt_in_path: str):
|
|
|
|
|
|
|
184 |
img_txt_in_path = resolve_model_path(
|
185 |
name=img_txt_in_path,
|
186 |
repo_id_field="repo_id",
|
@@ -192,6 +199,9 @@ class ICCustomPipeline:
|
|
192 |
self.load_model_weights(weights, strict=False)
|
193 |
|
194 |
def set_boundary_embeddings(self, boundary_embeddings_path: str):
|
|
|
|
|
|
|
195 |
boundary_embeddings_path = resolve_model_path(
|
196 |
name=boundary_embeddings_path,
|
197 |
repo_id_field="repo_id",
|
@@ -203,6 +213,9 @@ class ICCustomPipeline:
|
|
203 |
self.load_model_weights(weights, strict=False)
|
204 |
|
205 |
def set_task_register_embeddings(self, task_register_embeddings_path: str):
|
|
|
|
|
|
|
206 |
task_register_embeddings_path = resolve_model_path(
|
207 |
name=task_register_embeddings_path,
|
208 |
repo_id_field="repo_id",
|
|
|
1 |
+
import os
|
2 |
import re
|
3 |
from typing import List, Optional, Union
|
4 |
|
|
|
128 |
double_blocks_idx: str = None,
|
129 |
single_blocks_idx: str = None,
|
130 |
):
|
131 |
+
if not os.path.exists(lora_path):
|
132 |
+
lora_path = "dit_lora_0x1561"
|
133 |
+
|
134 |
+
|
135 |
lora_path = resolve_model_path(
|
136 |
name=lora_path,
|
137 |
repo_id_field="repo_id",
|
|
|
185 |
self.load_model_weights(weights, strict=False)
|
186 |
|
187 |
def set_img_txt_in(self, img_txt_in_path: str):
|
188 |
+
if not os.path.exists(img_txt_in_path):
|
189 |
+
img_txt_in_path = "dit_txt_img_in_0x1561"
|
190 |
+
|
191 |
img_txt_in_path = resolve_model_path(
|
192 |
name=img_txt_in_path,
|
193 |
repo_id_field="repo_id",
|
|
|
199 |
self.load_model_weights(weights, strict=False)
|
200 |
|
201 |
def set_boundary_embeddings(self, boundary_embeddings_path: str):
|
202 |
+
if not os.path.exists(boundary_embeddings_path):
|
203 |
+
boundary_embeddings_path = "dit_boundary_embeddings_0x1561"
|
204 |
+
|
205 |
boundary_embeddings_path = resolve_model_path(
|
206 |
name=boundary_embeddings_path,
|
207 |
repo_id_field="repo_id",
|
|
|
213 |
self.load_model_weights(weights, strict=False)
|
214 |
|
215 |
def set_task_register_embeddings(self, task_register_embeddings_path: str):
|
216 |
+
if not os.path.exists(task_register_embeddings_path):
|
217 |
+
task_register_embeddings_path = "dit_task_register_embeddings_0x1561"
|
218 |
+
|
219 |
task_register_embeddings_path = resolve_model_path(
|
220 |
name=task_register_embeddings_path,
|
221 |
repo_id_field="repo_id",
|
ic_custom/utils/model_utils.py
CHANGED
@@ -206,6 +206,9 @@ def load_dit(
|
|
206 |
model: Loaded Flux model
|
207 |
"""
|
208 |
# Loading Flux
|
|
|
|
|
|
|
209 |
logger.info("Initializing Flux model")
|
210 |
|
211 |
# Resolve checkpoint path
|
@@ -249,9 +252,11 @@ def load_ic_custom(
|
|
249 |
model: Loaded IC_Custom model
|
250 |
"""
|
251 |
logger.info("Initializing IC-Custom model")
|
252 |
-
|
253 |
# Resolve checkpoint path
|
254 |
-
|
|
|
|
|
255 |
ckpt_path = resolve_model_path(
|
256 |
name=name,
|
257 |
repo_id_field="repo_id",
|
@@ -312,8 +317,7 @@ def load_embedder(
|
|
312 |
path,
|
313 |
max_length=max_length,
|
314 |
is_clip=is_clip,
|
315 |
-
|
316 |
-
).to(device)
|
317 |
|
318 |
return model
|
319 |
|
@@ -336,7 +340,11 @@ def load_t5(
|
|
336 |
Returns:
|
337 |
model: Loaded T5 model
|
338 |
"""
|
|
|
|
|
|
|
339 |
logger.info(f"Loading T5 model: {name}")
|
|
|
340 |
return load_embedder(
|
341 |
name=name,
|
342 |
is_clip=False,
|
@@ -362,7 +370,11 @@ def load_clip(
|
|
362 |
Returns:
|
363 |
model: Loaded CLIP model
|
364 |
"""
|
|
|
|
|
|
|
365 |
logger.info(f"Loading CLIP model: {name}")
|
|
|
366 |
return load_embedder(
|
367 |
name=name,
|
368 |
is_clip=True,
|
@@ -387,6 +399,10 @@ def load_ae(
|
|
387 |
Returns:
|
388 |
model: Loaded AutoEncoder model
|
389 |
"""
|
|
|
|
|
|
|
|
|
390 |
logger.info(f"Loading AutoEncoder model: {name}")
|
391 |
|
392 |
# Convert device string to torch.device if needed
|
@@ -429,6 +445,12 @@ def load_redux(
|
|
429 |
Returns:
|
430 |
model: Loaded Redux Image Encoder model
|
431 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
432 |
logger.info(f"Loading Redux Image Encoder: redux={redux_name}, siglip={siglip_name}")
|
433 |
|
434 |
# Convert device string to torch.device if needed
|
|
|
206 |
model: Loaded Flux model
|
207 |
"""
|
208 |
# Loading Flux
|
209 |
+
if not os.path.exists(name):
|
210 |
+
name = "flux-fill-dev-dit"
|
211 |
+
|
212 |
logger.info("Initializing Flux model")
|
213 |
|
214 |
# Resolve checkpoint path
|
|
|
252 |
model: Loaded IC_Custom model
|
253 |
"""
|
254 |
logger.info("Initializing IC-Custom model")
|
255 |
+
|
256 |
# Resolve checkpoint path
|
257 |
+
if not os.path.exists(name):
|
258 |
+
name = "flux-fill-dev-dit"
|
259 |
+
|
260 |
ckpt_path = resolve_model_path(
|
261 |
name=name,
|
262 |
repo_id_field="repo_id",
|
|
|
317 |
path,
|
318 |
max_length=max_length,
|
319 |
is_clip=is_clip,
|
320 |
+
).to(device).to(dtype)
|
|
|
321 |
|
322 |
return model
|
323 |
|
|
|
340 |
Returns:
|
341 |
model: Loaded T5 model
|
342 |
"""
|
343 |
+
if not os.path.exists(name):
|
344 |
+
name = "t5-v1_1-xxl"
|
345 |
+
|
346 |
logger.info(f"Loading T5 model: {name}")
|
347 |
+
|
348 |
return load_embedder(
|
349 |
name=name,
|
350 |
is_clip=False,
|
|
|
370 |
Returns:
|
371 |
model: Loaded CLIP model
|
372 |
"""
|
373 |
+
if not os.path.exists(name):
|
374 |
+
name = "clip-vit-large-patch14"
|
375 |
+
|
376 |
logger.info(f"Loading CLIP model: {name}")
|
377 |
+
|
378 |
return load_embedder(
|
379 |
name=name,
|
380 |
is_clip=True,
|
|
|
399 |
Returns:
|
400 |
model: Loaded AutoEncoder model
|
401 |
"""
|
402 |
+
|
403 |
+
if not os.path.exists(name):
|
404 |
+
name = "flux-fill-dev-ae"
|
405 |
+
|
406 |
logger.info(f"Loading AutoEncoder model: {name}")
|
407 |
|
408 |
# Convert device string to torch.device if needed
|
|
|
445 |
Returns:
|
446 |
model: Loaded Redux Image Encoder model
|
447 |
"""
|
448 |
+
|
449 |
+
if not os.path.exists(redux_name):
|
450 |
+
redux_name = "flux1-redux-dev"
|
451 |
+
if not os.path.exists(siglip_name):
|
452 |
+
siglip_name = "siglip-so400m-patch14-384"
|
453 |
+
|
454 |
logger.info(f"Loading Redux Image Encoder: redux={redux_name}, siglip={siglip_name}")
|
455 |
|
456 |
# Convert device string to torch.device if needed
|