# seg2med_app/app.py
# streamlit run tutorial8_app.py
# F:\yang_Environments\torch\venv\Scripts\activate.ps1
# streamlit run tutorial8_app.py --server.address=0.0.0.0 --server.port=8501
# http://129.206.168.125:8501 http://169.254.3.1:8501
#import sys
#sys.path.append('./seg2med_app')
import os
# Allow duplicate OpenMP runtimes (a common workaround for the libiomp5 crash
# seen with PyTorch/MKL builds, especially on Windows).
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# seg2med_app/main.py
import streamlit as st
import zipfile
import hashlib
import pandas as pd
import numpy as np
import nibabel as nib
from seg2med_app.simulation.get_labels import get_labels
from seg2med_app.app_utils.image_utils import (
    show_three_planes,
    show_label_overlay,
    show_three_planes_interactive,
    show_single_planes_interactive,
    show_label_overlay_single,
    generate_color_map,
    load_image_canonical,
    global_slice_slider,
    image_to_base64,
    show_single_slice_image,
    show_single_slice_label,
)
from seg2med_app.ui.simulation_and_display import simulation_controls
from seg2med_app.ui.upload_and_prepare import handle_upload, compute_md5
from dataprocesser.simulation_functions import (
    _merge_seg_tissue,
    _create_body_contour_by_tissue_seg,
    _create_body_contour
)
from seg2med_app.simulation.combine_selected_organs import combine_selected_organs
from seg2med_app.ui.inference_controls import inference_controls
from seg2med_app.ui.inference_gradio import call_gradio_gpu_infer
from seg2med_app.frankenstein.frankenstein import frankenstein_control
from seg2med_app.app_utils.titles import *
# ========== CONFIG ==========
app_root = 'seg2med_app'
os.makedirs(os.path.join(app_root, "tmp"), exist_ok=True)
# ========== UI STRUCTURE ==========
st.set_page_config(
    page_title="Frankenstein App",
    page_icon="🧠",
    layout="wide"
)
st.session_state["app_root"] = app_root
from PIL import Image
def reset_app():
    st.session_state.clear()
    st.session_state["authenticated"] = True  # keep the user logged in after a reset
    st.success("App has been reset. Login information is preserved.")
    print("App has been reset. Login information is preserved.")
    st.rerun()
image = Image.open(os.path.join(app_root, "Frankenstein0.png"))
image_to_base64(image)
st.title("\U0001F9E0 Frankenstein - multimodal medical image generation")
st.markdown("""
**Created by**: Zeyu Yang
PhD Student, Computer-assisted Clinical Medicine
University of Heidelberg
🔗 [GitHub Repository](https://github.com/musetee/frankenstein)
📄 [Preprint on arXiv](https://arxiv.org/abs/2504.09182)
✉️ Contact: [Zeyu.Yang@medma.uni-heidelberg.de](mailto:Zeyu.Yang@medma.uni-heidelberg.de)
""")
PASSWORD = "frankenstein"
if "authenticated" not in st.session_state:
    st.session_state.authenticated = True  # set to False to require the access code
if not st.session_state.authenticated:
    st.session_state["app_password"] = st.text_input("Enter access code", type="password")
    if st.session_state["app_password"] == PASSWORD:
        st.session_state.authenticated = True
        st.success("✅ Access granted!")
    else:
        st.warning("🔒 Please enter the correct access code to continue.")
        st.stop()
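# Note: the access code above is stored and compared in plain text. A minimal
# hardening sketch (hashlib is already imported; PASSWORD_SHA256 is a hypothetical
# constant, not part of the original flow) would compare digests instead:
#   PASSWORD_SHA256 = hashlib.sha256(b"frankenstein").hexdigest()
#   if hashlib.sha256(st.session_state["app_password"].encode()).hexdigest() == PASSWORD_SHA256:
#       st.session_state.authenticated = True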
# ========== SIDEBAR (DATASET LOADER) ==========
st.sidebar.title("\U0001F9EC Dataset Loading")
load_method = st.sidebar.radio("Select load method", ["\U0001F3AE Random sample & manual draw", "\U0001F4C1 Upload segmentation"])
if st.button("🔄 Reset App"):
    reset_app()
Begin = "### 🎨 Begin: Choose a colormap to visualize different tissues"
st.write(Begin)
default_cmap = "PiYG"
cmap_options = [default_cmap, "nipy_spectral", "tab20", "Set3", "Paired", "tab10", "gist_rainbow", "custom"]
selected_cmap = st.selectbox("Label colormap", cmap_options, index=0)
# If "custom" is selected, show a text box so the user can type a colormap name
if selected_cmap == "custom":
    custom_cmap = st.text_input("Please type a custom colormap name", value=default_cmap)
    selected_cmap = custom_cmap
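# Optional guard (not in the original flow; assumes matplotlib >= 3.5 for the
# colormaps registry): fall back to the default if the typed colormap name is
# unknown, rather than letting generate_color_map fail on it later.
import matplotlib
if selected_cmap not in matplotlib.colormaps:
    st.warning(f"Unknown colormap '{selected_cmap}', falling back to '{default_cmap}'.")
    selected_cmap = default_cmap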
st.session_state.update({"selected_cmap": selected_cmap}) | |
# ========== select color map for visualization segmentation ============== | |
if "label_ids" in st.session_state: | |
st.session_state["label_to_color"] = generate_color_map(st.session_state["label_ids"], cmap=st.session_state["selected_cmap"]) | |
print('organ label to color: ', list(st.session_state["label_to_color"].items())[:5]) | |
# ========== MAIN: UPLOAD SEGMENTATION ==========
if load_method == "\U0001F4C1 Upload segmentation":
    # ========== FIRST ROW ==========
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        uploaded_file = st.file_uploader("Upload segmentation", type=["zip", "nii.gz", "nii"])
    with col2:
        uploaded_tissue = st.file_uploader("Upload tissue segmentation", type=["zip", "nii.gz", "nii"], key="tissue_upload")
    with col3:
        original_file = st.file_uploader("Upload original image", type=["nii.gz", "nii", "dcm"])
    with col4:
        # Set the body threshold (default depends on the modality, or can be entered manually)
        default_body_threshold = 0
        if "body_threshold" not in st.session_state:
            st.session_state["body_threshold"] = default_body_threshold
        user_input_threshold = st.number_input(
            "Body threshold for contour extraction (used on original image)",
            value=st.session_state["body_threshold"],
            step=1
        )
        use_custom_threshold = st.checkbox("Use custom body threshold", value=False)
        st.session_state["use_custom_threshold"] = use_custom_threshold
        visual_options = ["Only Axial Plane", "Three Planes"]
        st.session_state["selected_visual"] = st.selectbox("Visualization Type", visual_options, index=0)
    # Apply the threshold only when the user opts in via the checkbox; gating on the
    # checkbox (rather than on truthiness of the value) keeps 0 usable as a threshold.
    if use_custom_threshold:
        st.session_state["body_threshold"] = user_input_threshold
        if "orig_img" in st.session_state:
            st.session_state["contour"] = _create_body_contour(st.session_state['orig_img'], st.session_state['body_threshold'], body_mask_value=1)
    # ========== HASH MANAGEMENT ==========
    new_upload_hash = compute_md5(uploaded_file) if uploaded_file else None
    cached_upload_hash = st.session_state.get("uploaded_file_hash", None)
    new_tissue_hash = compute_md5(uploaded_tissue) if uploaded_tissue else None
    cached_tissue_hash = st.session_state.get("uploaded_tissue_hash", None)
    new_origin_hash = compute_md5(original_file) if original_file else None
    cached_origin_hash = st.session_state.get("uploaded_origin_hash", None)
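    # The hashes above are computed but never compared in this file. A minimal sketch
    # (hypothetical "_*_changed" session keys; assumes handle_upload stores the cached
    # "uploaded_*_hash" values) that records whether each upload changed on this rerun:
    st.session_state["_upload_changed"] = new_upload_hash is not None and new_upload_hash != cached_upload_hash
    st.session_state["_tissue_changed"] = new_tissue_hash is not None and new_tissue_hash != cached_tissue_hash
    st.session_state["_origin_changed"] = new_origin_hash is not None and new_origin_hash != cached_origin_hash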
    handle_upload(app_root,
                  uploaded_file, uploaded_tissue, original_file
                  )
    # ========== SIMULATION UI (SHARED) ==========
    simulation_controls(app_root)
    # ========== INFERENCE UI (SHARED) ==========
    inference_controls()
    # ========== VISUALIZE ==========
    if "combined_seg" in st.session_state:
        z_idx, y_idx, x_idx = global_slice_slider(st.session_state["volume_shape"])
        st.session_state.update({
            "z_idx": z_idx,
            "y_idx": y_idx,
            "x_idx": x_idx,
        })
if st.session_state["selected_visual"] == "Three Planes": | |
show_three_planes_interactive(st.session_state["contour"], z_idx, y_idx, x_idx) | |
show_label_overlay(st.session_state["combined_seg"], z_idx, y_idx, x_idx, label_colors=st.session_state["label_to_color"]) | |
else: | |
show_single_planes_interactive(st.session_state["contour"], z_idx, y_idx, x_idx) | |
show_label_overlay_single(st.session_state["combined_seg"], z_idx, y_idx, x_idx, label_colors=st.session_state["label_to_color"]) | |
if "selected_organs" in st.session_state and len(st.session_state["selected_organs"]) > 0: | |
multi_seg = combine_selected_organs(uploaded_file) | |
if st.session_state["selected_visual"] == "Three Planes": | |
show_label_overlay(multi_seg, z_idx, y_idx, x_idx, label_colors=st.session_state["label_to_color"]) | |
else: | |
show_label_overlay_single(multi_seg, z_idx, y_idx, x_idx, label_colors=st.session_state["label_to_color"]) | |
if "orig_img" in st.session_state: | |
if st.session_state["selected_visual"] == "Three Planes": | |
show_three_planes_interactive(st.session_state["orig_img"], z_idx, y_idx, x_idx,) | |
else: | |
show_single_planes_interactive(st.session_state["orig_img"], z_idx, y_idx, x_idx,) | |
if st.session_state.get("processed_img") is not None: | |
st.markdown("🔍 View Simulation Result") | |
if st.session_state["selected_visual"] == "Three Planes": | |
show_three_planes_interactive(st.session_state["processed_img"], | |
st.session_state["z_idx"], | |
st.session_state["y_idx"], | |
st.session_state["x_idx"],) | |
else: | |
show_single_planes_interactive(st.session_state["processed_img"], | |
st.session_state["z_idx"], | |
st.session_state["y_idx"], | |
st.session_state["x_idx"],) | |
if st.session_state.get("output_img") is not None: | |
st.session_state["output_volume_to_save"] = np.expand_dims(st.session_state["output_img"].T, axis=-1) | |
if st.session_state["selected_visual"] == "Three Planes": | |
show_three_planes_interactive(np.expand_dims(st.session_state["output_img"], axis=-1),0,0,0,orientation_type='none',) # model output already in correct orientation | |
else: | |
show_single_planes_interactive(np.expand_dims(st.session_state["output_img"], axis=-1),0,0,0,orientation_type='none',) # model output already in correct orientation | |
#st.success(f"Saved to {filename_output}") | |
# ========== RANDOM DRAW PAGE PLACEHOLDER ==========
elif load_method == "\U0001F3AE Random sample & manual draw":
    st.markdown("## 🎮 Frankenstein interactive creation tool")
    frankenstein_control()
    make_step_renderer(step5_frankenstein)
    simulation_controls(app_root)
    make_step_renderer(step7_frankenstein)
    inference_controls()
if st.button("⚙️ Run inference by Gradio"): | |
st.info("Running inference...") | |
modality = st.session_state["modality_idx"] | |
image_slice = st.session_state["processed_img"][:, :, st.session_state["z_idx"]] | |
result = call_gradio_gpu_infer(modality, image_slice) | |
st.image(result, caption="Predicted Image") | |
import matplotlib.pyplot as plt | |
if "output_img" in st.session_state: | |
output_img = st.session_state["output_img"] | |
plt.figure() | |
plt.imshow(output_img, cmap="gray") | |
plt.grid(False) | |
plt.savefig(r'seg2med_app\modeloutput.png') | |
plt.close() | |
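    # Note: the figure above is only written to disk. If an in-app preview is wanted
    # (assumption: the saved PNG is not displayed elsewhere), st.pyplot could show it directly:
    #   fig = plt.figure(); plt.imshow(output_img, cmap="gray"); st.pyplot(fig)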
    width = 400
    col1, col2, col3, col4 = st.columns([1, 1, 1, 1])
    with col1:
        if "contour" in st.session_state:
            show_single_slice_image(st.session_state["contour"].squeeze(), title="contour")
    with col2:
        if "combined_seg" in st.session_state:
            show_single_slice_label(st.session_state["combined_seg"].squeeze(),
                                    st.session_state["label_to_color"],
                                    title="combined segs")
    with col3:
        if st.session_state.get("processed_img") is not None:
            print(np.unique(st.session_state["processed_img"]))
            show_single_slice_image(st.session_state["processed_img"].squeeze(), title="image prior")
    with col4:
        if st.session_state.get("output_img") is not None:
            st.session_state["output_volume_to_save"] = np.expand_dims(st.session_state["output_img"].T, axis=-1)
            # No orientation handling needed: the model output is already correctly oriented.
            show_single_slice_image(st.session_state["output_img"], title="inference image", orientation_type='none')
    make_step_renderer(step8_frankenstein)
# ========== SAVE ==========
output_folder = os.path.join(app_root, 'output')
os.makedirs(output_folder, exist_ok=True)
col1, col2, col3, col4 = st.columns([1, 1, 1, 1])
with col1:
    filename_contour = st.text_input("Filename (.nii.gz)", value="contour.nii.gz", key="filename_contour")
    contour_save_path = os.path.join(output_folder, filename_contour)
    if st.session_state.get("contour") is not None:  # st.button("💾 Save Image Prior") and
        img_to_save = nib.Nifti1Image(st.session_state["contour"], st.session_state["orig_affine"])
        nib.save(img_to_save, contour_save_path)
        if os.path.exists(contour_save_path):
            with open(contour_save_path, "rb") as f:
                st.download_button(
                    label="⬇️ Download Contour",
                    data=f,
                    file_name=filename_contour,
                    mime="application/gzip"
                )
    # st.success(f"Saved to {filename_contour}")
with col2:
    filename_combined = st.text_input("Filename (.nii.gz)", value="combined_seg.nii.gz", key="filename_combined")
    combined_save_path = os.path.join(output_folder, filename_combined)
    if st.session_state.get("combined_seg") is not None:  # and st.button("💾 Save Output")
        img_to_save = nib.Nifti1Image(st.session_state["combined_seg"], st.session_state["orig_affine"])
        nib.save(img_to_save, combined_save_path)
        if os.path.exists(combined_save_path):
            with open(combined_save_path, "rb") as f:
                st.download_button(
                    label="⬇️ Download Combined Segmentation",
                    data=f,
                    file_name=filename_combined,
                    mime="application/gzip"
                )
with col3:
    filename_prior = st.text_input("Filename (.nii.gz)", value="prior_image.nii.gz", key="filename_prior")
    prior_save_path = os.path.join(output_folder, filename_prior)
    if st.session_state.get("processed_img") is not None:  # st.button("💾 Save Image Prior") and
        img_to_save = nib.Nifti1Image(st.session_state["processed_img"], st.session_state["orig_affine"])
        nib.save(img_to_save, prior_save_path)
        if os.path.exists(prior_save_path):
            with open(prior_save_path, "rb") as f:
                st.download_button(
                    label="⬇️ Download Prior Image",
                    data=f,
                    file_name=filename_prior,
                    mime="application/gzip"
                )
with col4:
    filename_output = st.text_input("Filename (.nii.gz)", value="model_output.nii.gz", key="filename_output")
    output_save_path = os.path.join(output_folder, filename_output)
    if st.session_state.get("output_volume_to_save") is not None:  # and st.button("💾 Save Output")
        img_to_save = nib.Nifti1Image(st.session_state["output_volume_to_save"], st.session_state["orig_affine"])
        nib.save(img_to_save, output_save_path)
        if os.path.exists(output_save_path):
            with open(output_save_path, "rb") as f:
                st.download_button(
                    label="⬇️ Download Output Image",
                    data=f,
                    file_name=filename_output,
                    mime="application/gzip"
                )
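# A possible refactor (sketch only, not wired into the columns above): the four
# save/download columns repeat the same pattern, which could be factored into a
# helper such as this hypothetical function.
def _save_and_offer_download(volume, affine, save_path, button_label, filename):
    """Write `volume` as a NIfTI file to `save_path` and offer it for download."""
    nib.save(nib.Nifti1Image(volume, affine), save_path)
    with open(save_path, "rb") as f:
        st.download_button(label=button_label, data=f, file_name=filename, mime="application/gzip")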