import gradio as gr
from datetime import datetime
from huggingface_hub import hf_hub_download
import torch
import json
from PIL import Image, ImageDraw, ImageFont
import numpy as np
from model import MIPHEIViT
# Load model once
repo_id = "Estabousi/MIPHEI-vit"
model = MIPHEIViT.from_pretrained_hf(repo_id=repo_id)
config_path = hf_hub_download(repo_id=repo_id, filename="config_hf.json")
model.eval()
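# Note: inference runs on the default device (CPU on most Spaces). A minimal
# sketch for GPU use, assuming CUDA is available, would be:
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model = model.to(device)
# (tensors from preprocess() would then need .to(device) as well)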
# ImageNet normalization statistics (RGB), reshaped to broadcast over (C, H, W)
mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).reshape(-1, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).reshape(-1, 1, 1)
with open(config_path, "r") as f:
    config = json.load(f)
channel_names = config["targ_channel_names"]
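# "targ_channel_names" lists the predicted mIF markers (16 for this checkpoint)
# in the model's output order.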
channel_colors = {
    "Hoechst": (0, 0, 255),       # Blue (DAPI-like nuclear stain)
    "CD31": (0, 255, 255),        # Cyan (endothelial)
    "CD45": (255, 255, 0),        # Yellow (leukocyte common antigen)
    "CD68": (255, 165, 0),        # Orange (macrophages)
    "CD4": (255, 0, 0),           # Red (helper T cells)
    "FOXP3": (138, 43, 226),      # Blue-Violet (regulatory T cells)
    "CD8a": (255, 100, 100),      # Salmon (cytotoxic T cells)
    "CD45RO": (255, 105, 180),    # Hot Pink (memory T cells)
    "CD20": (0, 191, 255),        # Deep Sky Blue (B cells)
    "PD-L1": (255, 0, 255),       # Magenta
    "CD3e": (95, 95, 94),         # Gray (T cells)
    "CD163": (184, 134, 11),      # Dark Goldenrod (M2 macrophages)
    "E-cadherin": (242, 12, 43),  # Crimson (epithelial marker)
    "Ki67": (255, 20, 147),       # Deep Pink (proliferation marker)
    "Pan-CK": (255, 0, 0),        # Red (epithelial/carcinoma)
    "SMA": (0, 255, 0),           # Green (smooth muscle, myofibroblasts)
}
# Contrast correction factors per channel (255 for Hoechst, 100 for a few dim
# markers, 150 otherwise), expressed as a fraction of the 8-bit range
default_contrast = 150.0
correction_map = {"Hoechst": 255.0, "CD8a": 100.0, "CD31": 100.0, "CD4": 100.0, "CD68": 100.0, "FOXP3": 100.0}
max_contrast_correction_value = torch.tensor([
    correction_map.get(name, default_contrast) / 255 for name in channel_names
]).reshape(len(channel_names), 1, 1)
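# Worked example: with the default factor of 150, a predicted intensity of 0.45
# is displayed as 0.45 / (150 / 255) ≈ 0.77, brightening dim channels; values
# above the cap are clipped to 1.0 in predict() below.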
overlay_markers = ["Hoechst", "Pan-CK", "SMA", "CD45"]
def preprocess(image):
    """Resize to the model's 256x256 input, scale to [0, 1], then apply ImageNet normalization."""
    image = image.convert("RGB").resize((256, 256))
    tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255
    tensor = (tensor - mean) / std
    return tensor.unsqueeze(0)  # [1, 3, H, W]
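# Shape contract assumed by predict() below:
#   preprocess(Image.new("RGB", (512, 512))).shape == torch.Size([1, 3, 256, 256])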
def draw_legend_on_image(image, channel_names, channel_colors, indices, box_size=18, spacing=5, top_margin=5):
    """Draw a semi-transparent legend on the bottom-right of the image."""
    overlay = image.convert("RGBA")  # to allow alpha compositing
    legend_layer = Image.new("RGBA", overlay.size, (255, 255, 255, 0))
    draw = ImageDraw.Draw(legend_layer)
    font = ImageFont.load_default()
    legend_height = top_margin + box_size * len(indices) + spacing * (len(indices) - 1)
    legend_width = 60  # adjust as needed
    x_start = overlay.width - legend_width - 10
    y_start = overlay.height - legend_height - 10
    # Semi-transparent white background behind the legend entries
    draw.rectangle(
        [x_start - 5, y_start - 5, x_start + legend_width + 5, y_start + legend_height + 5],
        fill=(255, 255, 255, 180),
    )
    for i, idx in enumerate(indices):
        name = channel_names[idx]
        color = channel_colors[name]
        y = y_start + i * (box_size + spacing)
        draw.rectangle([x_start, y, x_start + box_size, y + box_size], fill=color + (255,))
        draw.text((x_start + box_size + 5, y), name, fill=(0, 0, 0, 255), font=font)
    # Composite legend onto the overlay
    combined = Image.alpha_composite(overlay, legend_layer)
    return combined.convert("RGB")  # back to RGB for display
def merge_colored_images(color_imgs, top4_idx):
    """Additively blend the selected pseudocolored channels into one RGB image."""
    # Accumulate as float32 NumPy arrays to avoid uint8 overflow
    accum = np.zeros_like(np.array(color_imgs[0]), dtype=np.float32)
    for idx in top4_idx:
        img = np.array(color_imgs[idx]).astype(np.float32)
        accum += img  # additive blending
    accum = np.clip(accum, 0, 255).astype(np.uint8)
    return Image.fromarray(accum, mode="RGB")
def apply_color_map(gray_img, rgb_color):
    """Map a grayscale image to RGB using a fixed pseudocolor."""
    gray = np.asarray(gray_img).astype(np.float32) / 255.0
    rgb = np.stack([gray * rgb_color[i] for i in range(3)], axis=-1).astype(np.uint8)
    return Image.fromarray(rgb, mode="RGB")
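# Worked example: a gray value of 128 with rgb_color (0, 255, 0) maps to
# (0, 128, 0), i.e. the marker color scaled by the predicted intensity.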
def predict(image):
    print(f"[{datetime.now().isoformat()}] Inference run")
    input_tensor = preprocess(image)
    with torch.inference_mode():
        output = model(input_tensor)[0]  # [16, H, W]
    # Map the clamped output from [-0.9, 0.9] to [0, 1]
    output = (output.clamp(-0.9, 0.9) + 0.9) / 1.8
    # Apply the per-channel contrast boost, then convert to 8-bit for display
    output_vis = output / max_contrast_correction_value.to(output.device).clamp(min=1e-6)
    output_vis = output_vis.clamp(0, 1) * 255
    output_vis = output_vis.cpu().numpy().astype(np.uint8)
    # Pseudocolor each predicted mIF channel with its marker color
    channel_imgs = []
    for i in range(output_vis.shape[0]):
        ch_name = channel_names[i]
        ch_gray = Image.fromarray(output_vis[i], mode="L")
        ch_colored = apply_color_map(ch_gray, channel_colors[ch_name])
        channel_imgs.append(ch_colored)
    fixed_idx = [channel_names.index(name) for name in overlay_markers]
    overlay = merge_colored_images(channel_imgs, fixed_idx)
    overlay_with_legend = draw_legend_on_image(overlay, channel_names, channel_colors, fixed_idx)
    return [overlay_with_legend] + channel_imgs
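# Note: predict() returns the overlay first, then one image per channel,
# matching the `output_images` list wired to run_btn.click below.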
# Markdown header
with open("HEADER.md", "r", encoding="utf-8") as f:
HEADER_MD = f.read()
# Build interface using Blocks
with gr.Blocks() as demo:
    gr.Markdown(HEADER_MD)
    with gr.Row():
        # LEFT: input + examples + button
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Input H&E")
            run_btn = gr.Button("Run Prediction")
            gr.Examples(
                examples=[
                    ["examples/crc100k_val.jpg"],
                    ["examples/orion_test_1.jpg"],
                    ["examples/orion_test_2.jpg"],
                    ["examples/orion_test_3.jpg"],
                    ["examples/orion_test_4.jpg"],
                    ["examples/orion_test_5.jpg"],
                    ["examples/tcga.jpg"],
                    ["examples/hemit.jpg"],
                ],
                inputs=[input_image],
                label="Example H&E tiles (TCGA, ORION Test, CRC100K, HEMIT)",
            )
        # RIGHT: outputs
        with gr.Column(scale=2):
            overlay_image = gr.Image(type="pil", label="mIF Overlay")
            channel_images = [
                gr.Image(type="pil", label=f"mIF Channel {channel_names[i]}")
                for i in range(len(channel_names))
            ]
    output_images = [overlay_image] + channel_images
    run_btn.click(fn=predict, inputs=input_image, outputs=output_images)
if __name__ == "__main__":
    demo.launch(ssr_mode=False)