""" | |
CREStereo Gradio Demo with ZeroGPU Integration | |
This demo showcases the CREStereo model for stereo depth estimation. | |
Optimized for Hugging Face Spaces with ZeroGPU support. | |
Key ZeroGPU optimizations: | |
- @spaces.GPU decorators for GPU-intensive functions | |
- CUDA operations only within GPU context | |
- Memory-efficient inference with cleanup | |
- Safe CUDA initialization patterns | |
""" | |
import os
import sys
import logging
import tempfile
import gc
from pathlib import Path
from typing import Optional, Tuple, Union

import numpy as np
import cv2
import gradio as gr
import imageio

# Import spaces BEFORE torch to ensure proper ZeroGPU initialization
import spaces

# Import torch after spaces - avoid any CUDA calls during import
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast

# Completely avoid CUDA operations during the import phase.
# Do not set the default tensor type or modify CUDA settings outside the GPU context:
# torch.set_default_tensor_type('torch.FloatTensor')  # Commented out - causes CUDA init
# Do not modify CUDA settings during import - this can trigger CUDA initialization:
# torch.backends.cudnn.enabled = False  # Commented out
# torch.backends.cudnn.benchmark = False  # Commented out
# Use current directory as base
current_dir = os.path.dirname(os.path.abspath(__file__))
base_dir = current_dir

# Add current directory to path for local imports
sys.path.insert(0, current_dir)

# Import local modules
from nets import Model

# Import Open3D with error handling
OPEN3D_AVAILABLE = False
try:
    # Set Open3D to CPU mode to avoid CUDA initialization
    os.environ['OPEN3D_CPU_RENDERING'] = '1'
    # Don't import open3d here - do it inside functions
    # import open3d as o3d
    OPEN3D_AVAILABLE = True  # Assume available; checked on first use
except Exception as e:
    logging.warning(f"Open3D setup failed: {e}")
    OPEN3D_AVAILABLE = False

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Model configuration
MODEL_VARIANTS = {
    "crestereo_eth3d": {
        "display_name": "CREStereo ETH3D (Pre-trained model)",
        "model_file": "models/crestereo_eth3d.pth",
        "max_disp": 256
    }
}

# Global variables for model caching
_cached_model = None
_cached_device = None
_cached_model_selection = None
class InputPadder:
    """Pads images so that their dimensions are divisible by divis_by."""

    def __init__(self, dims, divis_by=8, force_square=False):
        self.ht, self.wd = dims[-2:]
        pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by
        pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by
        if force_square:
            # Make the padded dimensions square
            max_dim = max(self.ht + pad_ht, self.wd + pad_wd)
            pad_ht = max_dim - self.ht
            pad_wd = max_dim - self.wd
        self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]

    def pad(self, *inputs):
        return [F.pad(x, self._pad, mode='replicate') for x in inputs]

    def unpad(self, x):
        ht, wd = x.shape[-2:]
        c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
        return x[..., c[0]:c[1], c[2]:c[3]]
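
# Usage sketch: for a 1x3x375x1242 tensor, InputPadder pads symmetrically to
# 376x1248 (the next multiples of 8) and unpad() crops predictions back:
#
#     padder = InputPadder(left.shape, divis_by=8)
#     left_p, right_p = padder.pad(left, right)
#     disp = padder.unpad(model_output)  # back to 375x1242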
def aggressive_cleanup():
    """Perform basic cleanup - no CUDA operations outside GPU context."""
    import gc
    gc.collect()
    logging.info("Performed basic memory cleanup")


def initialize_gpu_context():
    """Initialize GPU context safely for ZeroGPU."""
    try:
        # Set CUDA settings safely within GPU context
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
        # Check GPU availability and log info
        if torch.cuda.is_available():
            device_name = torch.cuda.get_device_name(0)
            memory_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
            logging.info(f"GPU initialized: {device_name}, Total memory: {memory_total:.2f}GB")
            return True
        else:
            logging.error("CUDA not available after GPU context initialization")
            return False
    except Exception as e:
        logging.error(f"GPU context initialization failed: {e}")
        return False


def check_gpu_memory():
    """Check and log current GPU memory usage - only call within GPU context."""
    try:
        allocated = torch.cuda.memory_allocated(0) / 1024**3
        reserved = torch.cuda.memory_reserved(0) / 1024**3
        max_allocated = torch.cuda.max_memory_allocated(0) / 1024**3
        total = torch.cuda.get_device_properties(0).total_memory / 1024**3
        logging.info(
            f"GPU Memory - Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB, "
            f"Max: {max_allocated:.2f}GB, Total: {total:.2f}GB"
        )
        return allocated, reserved, max_allocated, total
    except RuntimeError as e:
        logging.warning(f"Failed to get GPU memory info: {e}")
        return None, None, None, None
def get_available_models() -> dict:
    """Get all available models with their display names."""
    models = {}
    # Check for local models
    for variant, info in MODEL_VARIANTS.items():
        model_path = os.path.join(current_dir, info["model_file"])
        if os.path.exists(model_path):
            display_name = info["display_name"]
            models[display_name] = {
                "model_path": model_path,
                "variant": variant,
                "max_disp": info["max_disp"],
                "source": "local"
            }
    return models


def get_model_paths_from_selection(model_selection: str) -> Tuple[Optional[str], Optional[dict]]:
    """Get model path and config from the selected model."""
    models = get_available_models()
    # Check if it's in our models dict
    if model_selection in models:
        model_info = models[model_selection]
        logging.info(f"📂 Using local model: {model_selection}")
        return model_info["model_path"], model_info
    return None, None
def load_model_for_inference(model_path: str, model_info: dict):
    """Load the CREStereo model for inference temporarily (demo-style)."""
    # Set CUDA settings safely within GPU context
    torch.set_default_tensor_type('torch.cuda.FloatTensor')  # Now safe to use CUDA tensors
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    # Check if CUDA is available after ZeroGPU initialization
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. ZeroGPU initialization may have failed.")
    # Use the first available CUDA device
    device = torch.device("cuda")
    # Set CUDA seed safely within GPU context
    try:
        random_seed = 0
        torch.cuda.manual_seed_all(random_seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except Exception as e:
        logging.warning(f"Could not set CUDA seed: {e}")
    try:
        # Create model
        max_disp = model_info.get("max_disp", 256)
        model = Model(max_disp=max_disp, mixed_precision=False, test_mode=True)
        # Load checkpoint
        ckpt = torch.load(model_path, map_location=device)
        model.load_state_dict(ckpt, strict=True)
        model.to(device)
        model.eval()
        logging.info("Loaded CREStereo model weights")
        # Memory optimizations
        torch.set_grad_enabled(False)
        logging.info("Applied memory optimizations")
        return model, device
    except Exception as e:
        logging.error(f"Model loading failed: {e}")
        raise RuntimeError(f"Failed to load model: {e}")
def get_cached_model(model_selection: str):
    """Get the cached model, or load a new one if the selection changed."""
    global _cached_model, _cached_device, _cached_model_selection
    # Get model paths from selection
    model_path, model_info = get_model_paths_from_selection(model_selection)
    if model_path is None or model_info is None:
        raise ValueError(f"Selected model not found: {model_selection}")
    # Check if we need to reload the model
    if _cached_model is None or _cached_model_selection != model_selection:
        # Clear the previous model if it exists
        if _cached_model is not None:
            del _cached_model
            torch.cuda.empty_cache()
            gc.collect()
        logging.info(f"🔄 Loading model: {model_selection}")
        _cached_model, _cached_device = load_model_for_inference(model_path, model_info)
        _cached_model_selection = model_selection
        logging.info(f"✅ Model loaded successfully: {model_selection}")
    else:
        logging.info(f"✅ Using cached model: {model_selection}")
    return _cached_model, _cached_device


def clear_model_cache():
    """Clear the cached model to free memory."""
    global _cached_model, _cached_device, _cached_model_selection
    if _cached_model is not None:
        logging.info("Clearing model cache...")
        del _cached_model
        _cached_model = None
        _cached_device = None
        _cached_model_selection = None
        # Simple cleanup
        import gc
        gc.collect()
        torch.cuda.empty_cache()
        logging.info("Model cache cleared")
    else:
        logging.info("No model in cache to clear")
def inference(left, right, model, device, n_iter=20):
    """Run CREStereo inference on a stereo pair."""
    print("Model Forwarding...")
    imgL = left.transpose(2, 0, 1)
    imgR = right.transpose(2, 0, 1)
    imgL = np.ascontiguousarray(imgL[None, :, :, :])
    imgR = np.ascontiguousarray(imgR[None, :, :, :])
    imgL = torch.tensor(imgL.astype("float32")).to(device)
    imgR = torch.tensor(imgR.astype("float32")).to(device)
    # Use InputPadder to handle any image size
    padder = InputPadder(imgL.shape, divis_by=8)
    imgL_padded, imgR_padded = padder.pad(imgL, imgR)
    # Downsample for the coarse prediction
    imgL_dw2 = F.interpolate(
        imgL_padded,
        size=(imgL_padded.shape[2] // 2, imgL_padded.shape[3] // 2),
        mode="bilinear",
        align_corners=True,
    )
    imgR_dw2 = F.interpolate(
        imgR_padded,
        size=(imgR_padded.shape[2] // 2, imgR_padded.shape[3] // 2),
        mode="bilinear",
        align_corners=True,
    )
    with torch.inference_mode():
        # Coarse-to-fine cascade: the half-resolution pass initializes the
        # full-resolution refinement via flow_init
        pred_flow_dw2 = model(imgL_dw2, imgR_dw2, iters=n_iter, flow_init=None)
        pred_flow = model(imgL_padded, imgR_padded, iters=n_iter, flow_init=pred_flow_dw2)
    # Unpad the result to the original dimensions
    pred_flow = padder.unpad(pred_flow)
    pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()
    return pred_disp
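
# Usage sketch: given two same-sized RGB uint8 arrays of shape (H, W, 3),
#
#     disp = inference(left_rgb, right_rgb, model, device, n_iter=20)
#
# returns an (H, W) float32 disparity map in left-image pixel units.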
def vis_disparity(disparity_map, max_val=None):
    """Visualize a disparity map as a color image."""
    if max_val is None:
        # Normalize to the observed range (the small epsilon guards against a
        # constant map causing division by zero)
        disp_range = disparity_map.max() - disparity_map.min()
        disp_vis = (disparity_map - disparity_map.min()) / (disp_range + 1e-8) * 255.0
    else:
        disp_vis = np.clip(disparity_map / max_val * 255.0, 0, 255)
    disp_vis = disp_vis.astype("uint8")
    disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)
    disp_vis = cv2.cvtColor(disp_vis, cv2.COLOR_BGR2RGB)
    return disp_vis
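
# Note: with COLORMAP_INFERNO, larger disparities (closer surfaces) render
# brighter; when max_val is given, values above it saturate at 255.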
# ZeroGPU entry point: fixed (static) 60-second GPU window for basic processing
@spaces.GPU(duration=60)
def process_stereo_pair(model_selection: str, left_image: str, right_image: str,
                        progress: gr.Progress = gr.Progress()) -> Tuple[Optional[np.ndarray], str]:
    """
    Main processing function for a stereo pair (with model caching).
    """
    logging.info("Starting stereo pair processing...")
    if left_image is None or right_image is None:
        return None, "❌ Please upload both left and right images."
    # Convert image paths to numpy arrays
    logging.info(f"Loading images: left={left_image}, right={right_image}")
    try:
        # Load left image
        if not os.path.exists(left_image):
            logging.error(f"Left image file does not exist: {left_image}")
            return None, f"❌ Left image file not found: {left_image}"
        logging.info(f"Loading left image from: {left_image}")
        left_img = cv2.imread(left_image)
        if left_img is not None:
            left_img = cv2.cvtColor(left_img, cv2.COLOR_BGR2RGB)
        else:
            # Try imageio as a fallback; drop the alpha channel if present
            left_img = imageio.imread(left_image)
            if len(left_img.shape) == 3 and left_img.shape[2] == 4:
                left_img = left_img[:, :, :3]
        # Load right image
        if not os.path.exists(right_image):
            logging.error(f"Right image file does not exist: {right_image}")
            return None, f"❌ Right image file not found: {right_image}"
        logging.info(f"Loading right image from: {right_image}")
        right_img = cv2.imread(right_image)
        if right_img is not None:
            right_img = cv2.cvtColor(right_img, cv2.COLOR_BGR2RGB)
        else:
            # Try imageio as a fallback; drop the alpha channel if present
            right_img = imageio.imread(right_image)
            if len(right_img.shape) == 3 and right_img.shape[2] == 4:
                right_img = right_img[:, :, :3]
        logging.info(f"Images loaded successfully - Left: {left_img.shape}, Right: {right_img.shape}")
    except Exception as e:
        logging.error(f"Failed to load images: {e}")
        return None, f"❌ Failed to load images: {str(e)}"
    try:
        # Get cached model
        variant_name = model_selection.split('(')[0].strip() if '(' in model_selection else model_selection
        progress(0.1, desc=f"Loading cached model ({variant_name})...")
        logging.info("🔄 Getting cached model...")
        model, device = get_cached_model(model_selection)
        logging.info("✅ Cached model loaded successfully")
        progress(0.2, desc="Preprocessing images...")
        # Validate input images
        if left_img.shape != right_img.shape:
            return None, "❌ Left and right images must have the same dimensions."
        H, W = left_img.shape[:2]
        progress(0.5, desc="Running inference...")
        # Process the stereo pair
        torch.cuda.empty_cache()  # Clear any cached memory before inference
        disp_cpu = inference(left_img, right_img, model, device, n_iter=20)
        progress(0.8, desc="Creating visualization...")
        # Create visualization
        disparity_vis = vis_disparity(disp_cpu)
        result_image = disparity_vis
        progress(1.0, desc="Complete!")
        # Create status message
        valid_mask = ~np.isinf(disp_cpu)
        min_disp = disp_cpu[valid_mask].min() if valid_mask.any() else 0
        max_disp = disp_cpu[valid_mask].max() if valid_mask.any() else 0
        mean_disp = disp_cpu[valid_mask].mean() if valid_mask.any() else 0
        # Get model variant for status
        variant = variant_name
        # Check current memory usage
        try:
            current_memory = torch.cuda.memory_allocated(0) / 1024**3
            max_memory = torch.cuda.max_memory_allocated(0) / 1024**3
            memory_info = f" | GPU: {current_memory:.2f}GB/{max_memory:.2f}GB peak"
        except Exception:
            memory_info = ""
        status = f"""✅ Processing successful!
🔧 Model: {variant}{memory_info}
📊 Disparity Statistics:
• Range: {min_disp:.2f} - {max_disp:.2f}
• Mean: {mean_disp:.2f}
• Input size: {W}×{H}
• Valid pixels: {valid_mask.sum()}/{valid_mask.size}"""
        return result_image, status
    except Exception as e:
        logging.error(f"Processing failed: {e}")
        # Clean up GPU memory
        torch.cuda.empty_cache()
        gc.collect()
        return None, f"❌ Error: {str(e)}"
# ZeroGPU entry point: fixed (static) 120-second GPU window for depth processing
@spaces.GPU(duration=120)
def process_with_depth(model_selection: str, left_image: str, right_image: str,
                       camera_matrix: str, baseline: float,
                       progress: gr.Progress = gr.Progress()) -> Tuple[Optional[np.ndarray], Optional[str], Optional[str], str]:
    """
    Process a stereo pair and generate a depth map and point cloud (with model caching).
    """
    # Import Open3D
    global OPEN3D_AVAILABLE
    try:
        import open3d as o3d
        OPEN3D_AVAILABLE = True
    except ImportError as e:
        logging.warning(f"Open3D not available: {e}")
        OPEN3D_AVAILABLE = False
        return None, None, None, "❌ Open3D not available. Point cloud generation disabled."
    if left_image is None or right_image is None:
        return None, None, None, "❌ Please upload both left and right images."
    try:
        progress(0.1, desc="Parsing camera parameters...")
        # Parse camera matrix
        try:
            K_values = list(map(float, camera_matrix.strip().split()))
            if len(K_values) != 9:
                return None, None, None, "❌ Camera matrix must contain exactly 9 values."
            K = np.array(K_values).reshape(3, 3)
        except ValueError:
            return None, None, None, "❌ Invalid camera matrix format. Use space-separated numbers."
        if baseline <= 0:
            return None, None, None, "❌ Baseline must be positive."
        # First get disparity using the same pipeline as the basic function
        disparity_result, status = process_stereo_pair(model_selection, left_image, right_image, progress)
        if disparity_result is None:
            return None, None, None, status
        # Load images again for depth processing
        left_img = cv2.imread(left_image)
        left_img = cv2.cvtColor(left_img, cv2.COLOR_BGR2RGB)
        # Run the model again to get raw disparity values (not the visualization)
        model, device = get_cached_model(model_selection)
        disp_cpu = inference(left_img, cv2.cvtColor(cv2.imread(right_image), cv2.COLOR_BGR2RGB), model, device, n_iter=20)
        progress(0.6, desc="Converting to depth...")
        # Remove points whose match would fall outside the right image
        H, W = disp_cpu.shape
        yy, xx = np.meshgrid(np.arange(H), np.arange(W), indexing='ij')
        us_right = xx - disp_cpu
        invalid = us_right < 0
        disp_cpu[invalid] = np.inf
        # Convert to depth: depth = focal_length * baseline / disparity
        depth = K[0, 0] * baseline / disp_cpu
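        # Worked example (sketch, using the example1 defaults defined below):
        # with fx ≈ 754.67 px and a 0.063 m baseline, a 30 px disparity maps
        # to 754.67 * 0.063 / 30 ≈ 1.58 m of depth.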
        # Visualize depth
        depth_vis = vis_disparity(depth, max_val=10.0)
        progress(0.8, desc="Generating point cloud...")
        # Generate point cloud
        fx, fy = K[0, 0], K[1, 1]
        cx, cy = K[0, 2], K[1, 2]
        # Create coordinate meshgrids
        u, v = np.meshgrid(np.arange(W), np.arange(H))
        # Back-project to 3D camera coordinates (pinhole model)
        valid_depth = ~np.isinf(depth)
        z = depth[valid_depth]               # Z coordinate (depth)
        x = (u[valid_depth] - cx) * z / fx   # X coordinate
        y = (v[valid_depth] - cy) * z / fy   # Y coordinate
        # Stack coordinates (X, Y, Z)
        points = np.stack([x, y, z], axis=-1)
        # Get the corresponding colors
        colors = left_img[valid_depth]
        # Filter points by depth range
        depth_mask = (z > 0) & (z <= 10.0)
        valid_points = points[depth_mask]
        valid_colors = colors[depth_mask]
        if len(valid_points) == 0:
            return depth_vis, None, None, "⚠️ No valid points generated for the point cloud."
        # Subsample points for better performance
        if len(valid_points) > 100000:
            indices = np.random.choice(len(valid_points), 100000, replace=False)
            valid_points = valid_points[indices]
            valid_colors = valid_colors[indices]
        # Transform coordinates for proper visualization
        transformed_points = valid_points.copy()
        transformed_points[:, 1] = -transformed_points[:, 1]  # Flip Y axis
        transformed_points[:, 2] = -transformed_points[:, 2]  # Flip Z axis
        # Generate the Open3D point cloud
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(transformed_points)
        pcd.colors = o3d.utility.Vector3dVector(valid_colors / 255.0)
        # Save to a temporary .ply so it can be offered for download and shown
        # in the 3D viewer (uses the tempfile module imported above)
        ply_path = os.path.join(tempfile.mkdtemp(), "pointcloud.ply")
        o3d.io.write_point_cloud(ply_path, pcd)
        progress(1.0, desc="Complete!")
        # Check current memory usage
        try:
            current_memory = torch.cuda.memory_allocated(0) / 1024**3
            max_memory = torch.cuda.max_memory_allocated(0) / 1024**3
            memory_info = f" | GPU: {current_memory:.2f}GB/{max_memory:.2f}GB peak"
        except Exception:
            memory_info = ""
        variant = model_selection.split('(')[0].strip() if '(' in model_selection else model_selection
        status = f"""✅ Depth processing successful!
🔧 Model: {variant}{memory_info}
📊 Statistics:
• Valid points: {len(valid_points):,}
• Depth range: {z.min():.2f} - {z.max():.2f} m
• Baseline: {baseline} m
• Point cloud saved with {len(valid_points)} points
• 3D visualization available"""
        return depth_vis, ply_path, ply_path, status
    except Exception as e:
        logging.error(f"Depth processing failed: {e}")
        torch.cuda.empty_cache()
        gc.collect()
        return None, None, None, f"❌ Error: {str(e)}"
def create_app() -> gr.Blocks:
    """Create the Gradio application."""
    # Get available models
    try:
        available_models = get_available_models()
        logging.info(f"Successfully got available models: {len(available_models)} found")
    except Exception as e:
        logging.error(f"Failed to get available models: {e}")
        available_models = {}

    with gr.Blocks(
        title="CREStereo - Stereo Depth Estimation",
        theme=gr.themes.Soft(),
        css="footer {visibility: hidden}",
        delete_cache=(60, 60)
    ) as app:
        gr.Markdown("""
        # 🚀 CREStereo: Practical Stereo Matching

        Upload a pair of **rectified** stereo images to get a disparity estimate from CREStereo.

        ⚠️ **Important**: Images should be rectified (epipolar lines are horizontal) and undistorted.
        ⚡ **GPU Powered**: Runs on CUDA-enabled GPUs for fast inference.
        """)

        # Instructions section
        with gr.Accordion("📋 Instructions", open=False):
            gr.Markdown("""
            ## 📖 How to Use This Demo

            ### 🖼️ Input Requirements
            1. **Image Format**: Upload images in JPEG or PNG format.
            2. **Image Size**: Both images must have the same size and resolution.
            3. **Rectification**: Ensure images are rectified (epipolar lines are horizontal) and undistorted.
            4. **Camera Parameters**: For depth processing, provide the camera matrix and baseline distance.

            ### 🚀 Using the Demo
            1. **Select Model**: Choose the CREStereo model variant
            2. **Upload Images**: Provide rectified stereo image pairs
            3. **Basic Processing**: Get a disparity visualization
            4. **Advanced Processing**: Generate depth maps and 3D point clouds (requires camera parameters)

            ### 📄 Original Work
            This demo is based on CREStereo: Practical Stereo Matching via Cascaded Recurrent Network with Adaptive Correlation.
            - **Paper**: [Practical Stereo Matching via Cascaded Recurrent Network with Adaptive Correlation](https://arxiv.org/abs/2203.11483)
            - **Official Repository**: [https://github.com/megvii-research/CREStereo](https://github.com/megvii-research/CREStereo)
            """)

        # Model selection
        with gr.Row():
            all_choices = list(available_models.keys())
            if not all_choices:
                all_choices = ["No models found - Please ensure crestereo_eth3d.pth is in the models/ directory"]
            default_model = all_choices[0] if all_choices else None
            model_selector = gr.Dropdown(
                choices=all_choices,
                value=default_model,
                label="🎯 Select Model",
                info="Choose the CREStereo model variant.",
                interactive=True
            )
        with gr.Tabs():
            # Basic stereo processing tab
            with gr.TabItem("🖼️ Basic Stereo Processing"):
                with gr.Row():
                    with gr.Column():
                        left_input = gr.Image(
                            label="📷 Left Image",
                            type="filepath",
                            height=300
                        )
                        right_input = gr.Image(
                            label="📷 Right Image",
                            type="filepath",
                            height=300
                        )
                        process_btn = gr.Button(
                            "🚀 Process Stereo Pair",
                            variant="primary",
                            size="lg"
                        )
                    with gr.Column():
                        output_image = gr.Image(
                            label="📊 Disparity Visualization",
                            height=400
                        )
                        status_text = gr.Textbox(
                            label="Status",
                            interactive=False,
                            lines=8
                        )
                # Example images
                examples_list = []
                # Example 1
                if os.path.exists(os.path.join(current_dir, "assets", "example1", "left.png")):
                    examples_list.append([
                        os.path.join(current_dir, "assets", "example1", "left.png"),
                        os.path.join(current_dir, "assets", "example1", "right.png")
                    ])
                # Example 2
                if os.path.exists(os.path.join(current_dir, "assets", "example2", "left.png")):
                    examples_list.append([
                        os.path.join(current_dir, "assets", "example2", "left.png"),
                        os.path.join(current_dir, "assets", "example2", "right.png")
                    ])
                if examples_list:
                    gr.Examples(
                        examples=examples_list,
                        inputs=[left_input, right_input],
                        label="📁 Example Images"
                    )
            # Advanced processing with depth
            with gr.TabItem("📐 Advanced Processing (Depth & Point Cloud)"):
                with gr.Row():
                    with gr.Column():
                        left_input_adv = gr.Image(
                            label="📷 Left Image",
                            type="filepath",
                            height=250
                        )
                        right_input_adv = gr.Image(
                            label="📷 Right Image",
                            type="filepath",
                            height=250
                        )
                        # Camera parameters
                        with gr.Group():
                            gr.Markdown("### 📹 Camera Parameters")
                            camera_matrix_input = gr.Textbox(
                                label="Camera Matrix (9 row-major values: fx 0 cx 0 fy cy 0 0 1)",
                                value="",
                            )
                            baseline_input = gr.Number(
                                label="Baseline (meters)",
                                value=None,
                                minimum=0.001,
                                maximum=10.0,
                                step=0.001
                            )
                        process_depth_btn = gr.Button(
                            "🔬 Process with Depth",
                            variant="primary",
                            size="lg"
                        )
                    with gr.Column():
                        depth_output = gr.Image(
                            label="📊 Depth Visualization",
                            height=300
                        )
                        pointcloud_output = gr.File(
                            label="☁️ Point Cloud Download (.ply)",
                            file_types=[".ply"]
                        )
                        status_depth = gr.Textbox(
                            label="Status",
                            interactive=False,
                            lines=6
                        )
                # 3D Point Cloud Visualization
                with gr.Row():
                    pointcloud_3d = gr.Model3D(
                        label="🌐 3D Point Cloud Viewer",
                        clear_color=[0.0, 0.0, 0.0, 0.0],
                        height=400
                    )
                # Example images for advanced processing
                examples_advanced_list = []
                # Try to read camera parameters from K.txt files
                # Example 1
                if os.path.exists(os.path.join(current_dir, "assets", "example1", "left.png")):
                    k_file = os.path.join(current_dir, "assets", "example1", "K.txt")
                    camera_matrix_str = ""
                    baseline_val = 0.063  # default
                    if os.path.exists(k_file):
                        try:
                            with open(k_file, 'r') as f:
                                lines = f.readlines()
                            if len(lines) >= 1:
                                camera_matrix_str = lines[0].strip()
                            if len(lines) >= 2:
                                baseline_val = float(lines[1].strip())
                        except Exception:
                            camera_matrix_str = "754.6680908203125 0.0 489.3794860839844 0.0 754.6680908203125 265.16162109375 0.0 0.0 1.0"
                    examples_advanced_list.append([
                        os.path.join(current_dir, "assets", "example1", "left.png"),
                        os.path.join(current_dir, "assets", "example1", "right.png"),
                        camera_matrix_str,
                        baseline_val
                    ])
                # Example 2
                if os.path.exists(os.path.join(current_dir, "assets", "example2", "left.png")):
                    k_file = os.path.join(current_dir, "assets", "example2", "K.txt")
                    camera_matrix_str = ""
                    baseline_val = 0.537  # default
                    if os.path.exists(k_file):
                        try:
                            with open(k_file, 'r') as f:
                                lines = f.readlines()
                            if len(lines) >= 1:
                                camera_matrix_str = lines[0].strip()
                            if len(lines) >= 2:
                                baseline_val = float(lines[1].strip())
                        except Exception:
                            camera_matrix_str = "1733.74 0.0 792.27 0.0 1733.74 541.89 0.0 0.0 1.0"
                    examples_advanced_list.append([
                        os.path.join(current_dir, "assets", "example2", "left.png"),
                        os.path.join(current_dir, "assets", "example2", "right.png"),
                        camera_matrix_str,
                        baseline_val
                    ])
                if examples_advanced_list:
                    gr.Examples(
                        examples=examples_advanced_list,
                        inputs=[left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
                        label="📁 Example Images with Camera Parameters"
                    )
        # Event handlers
        if available_models:
            process_btn.click(
                fn=process_stereo_pair,
                inputs=[model_selector, left_input, right_input],
                outputs=[output_image, status_text],
                show_progress=True
            )
            if OPEN3D_AVAILABLE:
                process_depth_btn.click(
                    fn=process_with_depth,
                    inputs=[model_selector, left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
                    outputs=[depth_output, pointcloud_output, pointcloud_3d, status_depth],
                    show_progress=True
                )
            else:
                process_depth_btn.click(
                    fn=lambda *args: (None, None, None, "❌ Open3D not available. Install with: pip install open3d"),
                    inputs=[model_selector, left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
                    outputs=[depth_output, pointcloud_output, pointcloud_3d, status_depth]
                )
        else:
            # No models available
            process_btn.click(
                fn=lambda *args: (None, "❌ No models available. Please ensure crestereo_eth3d.pth is in the models/ directory."),
                inputs=[model_selector, left_input, right_input],
                outputs=[output_image, status_text]
            )
            process_depth_btn.click(
                fn=lambda *args: (None, None, None, "❌ No models available. Please ensure crestereo_eth3d.pth is in the models/ directory."),
                inputs=[model_selector, left_input_adv, right_input_adv, camera_matrix_input, baseline_input],
                outputs=[depth_output, pointcloud_output, pointcloud_3d, status_depth]
            )
        # Citation section at the bottom
        with gr.Accordion("📚 Citation", open=False):
            gr.Markdown("""
            ### 📄 Please Cite the Original Paper
            If you use this work in your research, please cite:
            ```bibtex
            @inproceedings{li2022practical,
              title={Practical Stereo Matching via Cascaded Recurrent Network with Adaptive Correlation},
              author={Li, Jiankun and Wang, Peisen and Xiong, Pengfei and Cai, Tao and Yan, Ziwei and Yang, Lei and Liu, Jiangyu and Fan, Haoqiang and Liu, Shuaicheng},
              booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
              pages={16263--16272},
              year={2022}
            }
            ```
            """)

        # Footer
        gr.Markdown("""
        ---
        ### 📝 Notes:
        - **Input images must be rectified stereo pairs** (epipolar lines are horizontal)
        - **⚡ GPU Acceleration**: Requires a CUDA-compatible GPU
        - **📦 Model Caching**: Models are cached for efficient repeated use
        - For best results, use high-quality rectified stereo pairs
        - The model works on RGB images and supports various resolutions

        ### 🔗 References:
        - [CREStereo Paper](https://arxiv.org/abs/2203.11483)
        - [Original GitHub Repository](https://github.com/megvii-research/CREStereo)
        - [This PyTorch Implementation](https://github.com/ibaiGorordo/CREStereo-Pytorch)
        """)
    return app
def main():
    """Main function to launch the app."""
    # Ensure no CUDA operations happen during startup
    if torch.cuda.is_available():
        logging.warning("CUDA detected during startup - this should not happen in ZeroGPU")
    logging.info("🚀 Starting CREStereo Gradio App...")
    # Parse command line arguments
    import argparse
    parser = argparse.ArgumentParser(description="CREStereo Gradio App")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
    parser.add_argument("--port", type=int, default=7860, help="Port to bind to")
    parser.add_argument("--share", action="store_true", help="Create shareable link")
    parser.add_argument("--debug", action="store_true", help="Enable debug mode")
    args = parser.parse_args()
    if args.debug:
        logging.getLogger().setLevel(logging.DEBUG)
    try:
        # Create and launch app
        logging.info("Creating Gradio app...")
        app = create_app()
        logging.info("✅ Gradio app created successfully")
        logging.info(f"Launching app on {args.host}:{args.port}")
        if args.share:
            logging.info("Share link will be created")
        # For ZeroGPU compatibility, launch with appropriate settings
        app.launch(
            server_name=args.host,
            server_port=args.port,
            share=args.share,
            show_error=True,
            favicon_path=None,
            ssr_mode=False,  # Disable SSR for ZeroGPU compatibility
            allowed_paths=["./"]  # Allow access to local files
        )
    except Exception as e:
        logging.error(f"Failed to launch app: {e}")
        raise
if __name__ == "__main__":
    # Additional safety check for the ZeroGPU environment
    if 'SPACE_ID' in os.environ:
        logging.info("Running in Hugging Face Spaces environment")
    # Do not check CUDA status during startup - it can trigger CUDA initialization.
    # CUDA status is checked inside the @spaces.GPU-decorated functions.
    logging.info("✅ CUDA status will be checked within GPU-decorated functions")
    main()