Spaces:
Running
Running
import gradio as gr | |
from datetime import datetime | |
from huggingface_hub import hf_hub_download | |
import torch | |
import json | |
from PIL import Image | |
from PIL import ImageDraw, ImageFont | |
import numpy as np | |
from model import MIPHEIViT | |
# Load the model and its configuration once at import time so every request reuses them.
repo_id = "Estabousi/MIPHEI-vit"
model = MIPHEIViT.from_pretrained_hf(repo_id=repo_id)
config_path = hf_hub_download(repo_id=repo_id, filename="config_hf.json")
model.eval()  # inference only: disable training-mode behavior

# Standard ImageNet normalization constants, shaped (3, 1, 1) so they broadcast
# over a channel-first (3, H, W) image tensor.
mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).reshape((-1, 1, 1))
std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).reshape((-1, 1, 1))

with open(config_path, "r", encoding="utf-8") as f:
    config = json.load(f)
# Names of the predicted mIF channels; predict() pairs output channel i with channel_names[i].
channel_names = config["targ_channel_names"]
# Display color (R, G, B) for each predicted mIF channel.
# Every component must lie in 0..255: the pseudocoloring step casts results to uint8,
# so an out-of-range component would wrap around instead of saturating.
channel_colors = {
    "Hoechst": (0, 0, 255),  # Blue (DAPI, nuclear stain)
    "CD31": (0, 255, 255),  # Cyan (endothelial)
    "CD45": (255, 255, 0),  # Yellow (leukocyte common antigen)
    "CD68": (255, 165, 0),  # Orange (macrophages)
    "CD4": (255, 0, 0),  # Red (helper T cells)
    "FOXP3": (138, 43, 226),  # Purple/Blue-Violet (regulatory T cells)
    "CD8a": (50, 205, 50),  # Lime Green (cytotoxic T cells)
    "CD45RO": (255, 105, 180),  # Hot Pink (memory T cells)
    "CD20": (0, 191, 255),  # Deep Sky Blue (B cells)
    "PD-L1": (255, 0, 255),  # Magenta
    "CD3e": (95, 95, 94),  # Gray (T cells)
    "CD163": (184, 134, 11),  # Dark Goldenrod (M2 macrophages)
    "E-cadherin": (242, 12, 43),  # Crimson red (epithelial marker)
    "Ki67": (255, 20, 147),  # Deep Pink (proliferation marker)
    "Pan-CK": (255, 0, 0),  # Red (epithelial/carcinoma)
    "SMA": (0, 255, 0),  # Green (smooth muscle, myofibroblasts)
}
# Per-channel display scaling used by predict(): after the model output is mapped to
# [0, 1], each channel is divided by correction / 255 and clamped, so a value of
# correction / 255 or more is shown at full brightness. Hoechst uses 255 (no boost);
# CD8a, CD31, CD4, CD68 and FOXP3 use 100 (strongest boost); every other channel
# uses default_contrast (150).
default_contrast = 150.0
correction_map = {
    "Hoechst": 255.0,
    "CD8a": 100.0,
    "CD31": 100.0,
    "CD4": 100.0,
    "CD68": 100.0,
    "FOXP3": 100.0,
}
# Shape (C, 1, 1) so it broadcasts over a (C, H, W) prediction.
max_contrast_correction_value = torch.tensor(
    [correction_map.get(name, default_contrast) / 255 for name in channel_names],
    dtype=torch.float32,
).reshape(len(channel_names), 1, 1)
# Channels blended into the overlay image, listed in legend order.
overlay_markers = ["Hoechst", "Pan-CK", "SMA", "CD45"]
def preprocess(image):
    """Turn a PIL image into a normalized single-image model batch.

    The image is converted to RGB, resized to 256x256, scaled to [0, 1] and
    standardized with the module-level ``mean`` and ``std``.

    Returns:
        A float tensor of shape [1, 3, 256, 256].
    """
    rgb_image = image.convert("RGB").resize((256, 256))
    pixels = np.array(rgb_image)  # (H, W, 3), channel-last
    chw = torch.from_numpy(pixels).permute(2, 0, 1).float() / 255
    normalized = (chw - mean) / std
    return normalized[None]  # add the batch dimension
def draw_legend_on_image(image, channel_names, channel_colors, indices, box_size=18, spacing=5, top_margin=5):
    """Draw a semi-transparent legend in the bottom-right corner of ``image``.

    Args:
        image: PIL image to annotate (converted to RGBA for compositing).
        channel_names: sequence of channel names, indexed by ``indices``.
        channel_colors: mapping from channel name to an (R, G, B) color.
        indices: positions in ``channel_names`` to list, in display order.
        box_size: side length in pixels of each color swatch.
        spacing: vertical gap in pixels between consecutive entries.
        top_margin: padding in pixels above the first entry.

    Returns:
        A new RGB image with the legend composited on top. When ``indices`` is
        empty, no legend is drawn.
    """
    overlay = image.convert("RGBA")  # alpha channel needed for compositing
    names = [channel_names[idx] for idx in indices]
    if not names:
        return overlay.convert("RGB")

    legend_layer = Image.new("RGBA", overlay.size, (255, 255, 255, 0))
    draw = ImageDraw.Draw(legend_layer)
    font = ImageFont.load_default()

    text_gap = 5
    # Size the legend to its longest label so text never spills past the background.
    text_width = max(draw.textbbox((0, 0), name, font=font)[2] for name in names)
    legend_width = max(60, box_size + text_gap + text_width)
    legend_height = top_margin + box_size * len(names) + spacing * (len(names) - 1)
    x_start = overlay.width - legend_width - 10
    y_start = overlay.height - legend_height - 10

    # Semi-transparent background
    draw.rectangle(
        [x_start - 5, y_start - 5, x_start + legend_width + 5, y_start + legend_height + 5],
        fill=(255, 255, 255, 180),  # semi-transparent white
    )
    for i, name in enumerate(names):
        color = tuple(channel_colors[name])
        # top_margin is applied above the first entry, as its name says.
        y = y_start + top_margin + i * (box_size + spacing)
        draw.rectangle([x_start, y, x_start + box_size, y + box_size], fill=color + (255,))
        draw.text((x_start + box_size + text_gap, y), name, fill=(0, 0, 0, 255), font=font)

    combined = Image.alpha_composite(overlay, legend_layer)
    return combined.convert("RGB")  # back to RGB for display
def merge_colored_images(color_imgs, top4_idx):
    """Additively blend selected pseudocolored channel images.

    Args:
        color_imgs: sequence of RGB PIL images; they must all share one size,
            since their pixel arrays are summed element-wise.
        top4_idx: indices into ``color_imgs`` of the images to blend.

    Returns:
        An RGB image holding the per-pixel sum of the selected images,
        saturated to 0..255.
    """
    canvas_shape = np.array(color_imgs[0]).shape
    total = np.zeros(canvas_shape, dtype=np.float32)
    selected_layers = (np.asarray(color_imgs[k], dtype=np.float32) for k in top4_idx)
    for layer in selected_layers:
        total += layer  # additive blending in float to avoid uint8 overflow
    blended = np.clip(total, 0, 255).astype(np.uint8)
    return Image.fromarray(blended, mode='RGB')
def apply_color_map(gray_img, rgb_color):
    """Map a grayscale image to RGB using a fixed pseudocolor.

    Each pixel intensity ``v`` (0..255) becomes ``(v / 255) * rgb_color``.
    Results are clipped to 0..255 before the uint8 cast, so a color component
    outside the 8-bit range saturates instead of wrapping around. Only the
    first three components of ``rgb_color`` are used, so RGBA tuples are accepted.

    Args:
        gray_img: single-channel image (PIL "L" image or 2-D array).
        rgb_color: (R, G, B) color that full intensity maps to.

    Returns:
        An RGB PIL image with the same height and width as ``gray_img``.
    """
    gray = np.asarray(gray_img).astype(np.float32) / 255.0
    color = np.asarray(rgb_color, dtype=np.float32)[:3]
    rgb = np.clip(gray[..., None] * color, 0, 255).astype(np.uint8)
    return Image.fromarray(rgb, mode='RGB')
def predict(image):
    """Run MIPHEI-ViT on one H&E tile and build the images shown in the UI.

    Args:
        image: input H&E tile as a PIL image, or None when nothing was provided.

    Returns:
        A list whose first element is the blend of ``overlay_markers`` with a
        legend, followed by one pseudocolored image per predicted channel.

    Raises:
        gr.Error: if no input image was provided.
    """
    if image is None:
        # Show a readable message in the UI instead of failing on image.convert.
        raise gr.Error("Please upload or select an H&E image first.")
    print(f"[{datetime.now().isoformat()}] Inference run")
    input_tensor = preprocess(image)
    with torch.inference_mode():
        output = model(input_tensor)[0]  # [C, H, W], one map per target channel
    # Clamp to [-0.9, 0.9] and rescale linearly onto [0, 1].
    output = (output.clamp(-0.9, 0.9) + 0.9) / 1.8
    # Per-channel contrast stretch; the clamp guards against division by zero.
    scale = max_contrast_correction_value.to(output.device).clamp(min=1e-6)
    output_vis = (output / scale).clamp(0, 1) * 255
    output_vis = np.uint8(output_vis.cpu().numpy())

    # Pseudocolor each channel: grayscale map -> channel-specific RGB tint.
    channel_imgs = [
        apply_color_map(Image.fromarray(ch_vis, mode='L'), channel_colors[channel_names[i]])
        for i, ch_vis in enumerate(output_vis)
    ]
    fixed_idx = [channel_names.index(name) for name in overlay_markers]
    overlay = merge_colored_images(channel_imgs, fixed_idx)
    overlay_with_legend = draw_legend_on_image(overlay, channel_names, channel_colors, fixed_idx)
    return [overlay_with_legend] + channel_imgs
# Page header markdown, read from HEADER.md (path relative to the working directory).
with open("HEADER.md", encoding="utf-8") as header_file:
    HEADER_MD = header_file.read()
# Build interface using Blocks
with gr.Blocks() as demo:
    gr.Markdown(HEADER_MD)
    with gr.Row():
        # LEFT: input + examples + button
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Input H&E")
            run_btn = gr.Button("Run Prediction")
            gr.Examples(
                examples=[
                    ["examples/crc100k_val.jpg"],
                    ["examples/orion_test_1.jpg"],
                    ["examples/orion_test_2.jpg"],
                    ["examples/orion_test_3.jpg"],
                    ["examples/orion_test_4.jpg"],
                    ["examples/orion_test_5.jpg"],
                    ["examples/tcga.jpg"],
                    ["examples/hemit.jpg"],
                ],
                inputs=[input_image],
                label="Example H&E tile (TCGA, ORION Test, CRC100K, HEMIT)",
            )
        # RIGHT: outputs
        with gr.Column(scale=2):
            overlay_image = gr.Image(type="pil", label="mIF Overlay")
            # One output per configured channel. This must match the count that
            # predict() returns, so derive it from channel_names instead of hard-coding 16.
            channel_images = [
                gr.Image(type="pil", label=f"mIF Channel {name}")
                for name in channel_names
            ]
    output_images = [overlay_image] + channel_images
    run_btn.click(fn=predict, inputs=input_image, outputs=output_images)

if __name__ == "__main__":
    demo.launch(ssr_mode=False)