|
|
""" |
|
|
CREStereo Gradio Demo with ZeroGPU Integration |
|
|
|
|
|
This demo showcases the CREStereo model for stereo depth estimation. |
|
|
Optimized for Hugging Face Spaces with ZeroGPU support. |
|
|
|
|
|
Key ZeroGPU optimizations: |
|
|
- @spaces.GPU decorators for GPU-intensive functions |
|
|
- CUDA operations only within GPU context |
|
|
- Memory-efficient inference with cleanup |
|
|
- Safe CUDA initialization patterns |
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import logging |
|
|
import tempfile |
|
|
import gc |
|
|
from pathlib import Path |
|
|
from typing import Optional, Tuple, Union |
|
|
import numpy as np |
|
|
import cv2 |
|
|
import gradio as gr |
|
|
import imageio |
|
|
|
|
|
|
|
|
import spaces |
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
current_dir = os.path.dirname(os.path.abspath(__file__)) |
|
|
base_dir = current_dir |
|
|
|
|
|
|
|
|
sys.path.insert(0, current_dir) |
|
|
|
|
|
|
|
|
from nets import Model |
|
|
|
|
|
|
|
|
OPEN3D_AVAILABLE = False |
|
|
try:
    # Request CPU rendering before importing Open3D (Spaces containers are headless)
    os.environ['OPEN3D_CPU_RENDERING'] = '1'
    import open3d as o3d
    OPEN3D_AVAILABLE = True
|
|
except Exception as e: |
|
|
logging.warning(f"Open3D setup failed: {e}") |
|
|
OPEN3D_AVAILABLE = False |
|
|
|
|
|
|
|
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') |
|
|
|
|
|
|
|
|
MODEL_VARIANTS = { |
|
|
"crestereo_eth3d": { |
|
|
"display_name": "CREStereo ETH3D (Pre-trained model)", |
|
|
"model_file": "models/crestereo_eth3d.pth", |
|
|
"max_disp": 256 |
|
|
} |
|
|
} |
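
# Additional checkpoints can be exposed in the UI by adding entries here.
# Illustrative only -- the file name below is hypothetical:
#   "crestereo_middlebury": {
#       "display_name": "CREStereo Middlebury (hypothetical)",
#       "model_file": "models/crestereo_middlebury.pth",
#       "max_disp": 256,
#   },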
|
|
|
|
|
|
|
|
_cached_model = None |
|
|
_cached_device = None |
|
|
_cached_model_selection = None |
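
# Module-level cache: within one worker process the loaded model survives
# across Gradio requests, so repeated calls with the same selection skip the
# checkpoint load entirely (see get_cached_model below).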
|
|
|
|
|
|
|
|
class InputPadder: |
|
|
""" Pads images such that dimensions are divisible by divis_by """ |
|
|
def __init__(self, dims, divis_by=8, force_square=False): |
|
|
self.ht, self.wd = dims[-2:] |
|
|
pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by |
|
|
pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by |
|
|
|
|
|
if force_square: |
|
|
|
|
|
max_dim = max(self.ht + pad_ht, self.wd + pad_wd) |
|
|
pad_ht = max_dim - self.ht |
|
|
pad_wd = max_dim - self.wd |
|
|
|
|
|
self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] |
|
|
|
|
|
def pad(self, *inputs): |
|
|
return [F.pad(x, self._pad, mode='replicate') for x in inputs] |
|
|
|
|
|
def unpad(self, x): |
|
|
ht, wd = x.shape[-2:] |
|
|
c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] |
|
|
return x[..., c[0]:c[1], c[2]:c[3]] |
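

# Illustrative sketch (not called by the app): InputPadder wraps a forward pass
# by padding both images to multiples of `divis_by`, then cropping the output
# back. The tensor shape below is hypothetical.
def _input_padder_example():
    dummy_left = torch.zeros(1, 3, 375, 1242)
    dummy_right = torch.zeros(1, 3, 375, 1242)
    padder = InputPadder(dummy_left.shape, divis_by=8)
    left_p, right_p = padder.pad(dummy_left, dummy_right)
    assert left_p.shape[-2] % 8 == 0 and left_p.shape[-1] % 8 == 0
    # ... run the model on (left_p, right_p), then undo the padding:
    restored = padder.unpad(left_p)
    assert restored.shape == dummy_left.shape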
|
|
|
|
|
|
|
|
def aggressive_cleanup(): |
|
|
"""Perform basic cleanup - no CUDA operations outside GPU context""" |
|
|
|
|
gc.collect() |
|
|
logging.info("Performed basic memory cleanup") |
|
|
|
|
|
|
|
|
@spaces.GPU |
|
|
def initialize_gpu_context(): |
|
|
"""Initialize GPU context safely for ZeroGPU""" |
|
|
try: |
|
|
|
|
|
torch.set_default_tensor_type('torch.cuda.FloatTensor') |
|
|
torch.backends.cudnn.enabled = True |
|
|
torch.backends.cudnn.benchmark = True |
|
|
|
|
|
|
|
|
if torch.cuda.is_available(): |
|
|
device_name = torch.cuda.get_device_name(0) |
|
|
memory_total = torch.cuda.get_device_properties(0).total_memory / 1024**3 |
|
|
logging.info(f"GPU initialized: {device_name}, Total memory: {memory_total:.2f}GB") |
|
|
return True |
|
|
else: |
|
|
logging.error("CUDA not available after GPU context initialization") |
|
|
return False |
|
|
except Exception as e: |
|
|
logging.error(f"GPU context initialization failed: {e}") |
|
|
return False |
|
|
|
|
|
|
|
|
@spaces.GPU |
|
|
def check_gpu_memory(): |
|
|
"""Check and log current GPU memory usage - only call within GPU context""" |
|
|
try: |
|
|
allocated = torch.cuda.memory_allocated(0) / 1024**3 |
|
|
reserved = torch.cuda.memory_reserved(0) / 1024**3 |
|
|
max_allocated = torch.cuda.max_memory_allocated(0) / 1024**3 |
|
|
total = torch.cuda.get_device_properties(0).total_memory / 1024**3 |
|
|
|
|
|
logging.info(f"GPU Memory - Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB, Max: {max_allocated:.2f}GB, Total: {total:.2f}GB") |
|
|
return allocated, reserved, max_allocated, total |
|
|
except RuntimeError as e: |
|
|
logging.warning(f"Failed to get GPU memory info: {e}") |
|
|
return None, None, None, None |
|
|
|
|
|
|
|
|
def get_available_models() -> dict: |
|
|
"""Get all available models with their display names""" |
|
|
models = {} |
|
|
|
|
|
|
|
|
for variant, info in MODEL_VARIANTS.items(): |
|
|
model_path = os.path.join(current_dir, info["model_file"]) |
|
|
|
|
|
if os.path.exists(model_path): |
|
|
display_name = info["display_name"] |
|
|
models[display_name] = { |
|
|
"model_path": model_path, |
|
|
"variant": variant, |
|
|
"max_disp": info["max_disp"], |
|
|
"source": "local" |
|
|
} |
|
|
|
|
|
return models |
|
|
|
|
|
|
|
|
def get_model_paths_from_selection(model_selection: str) -> Tuple[Optional[str], Optional[dict]]: |
|
|
"""Get model path and config from the selected model""" |
|
|
models = get_available_models() |
|
|
|
|
|
|
|
|
if model_selection in models: |
|
|
model_info = models[model_selection] |
|
|
logging.info(f"π Using local model: {model_selection}") |
|
|
return model_info["model_path"], model_info |
|
|
|
|
|
return None, None |
|
|
|
|
|
|
|
|
@spaces.GPU |
|
|
def load_model_for_inference(model_path: str, model_info: dict): |
|
|
"""Load CREStereo model for inference temporarily (demo-style)""" |
|
|
|
|
|
torch.set_default_tensor_type('torch.cuda.FloatTensor') |
|
|
torch.backends.cudnn.enabled = True |
|
|
torch.backends.cudnn.benchmark = True |
|
|
|
|
|
|
|
|
if not torch.cuda.is_available(): |
|
|
raise RuntimeError("CUDA is not available. ZeroGPU initialization may have failed.") |
|
|
|
|
|
|
|
|
device = torch.device("cuda") |
|
|
|
|
|
|
|
|
try: |
|
|
random_seed = 0 |
|
|
torch.cuda.manual_seed_all(random_seed) |
|
|
torch.backends.cudnn.deterministic = True |
|
|
torch.backends.cudnn.benchmark = False |
|
|
except Exception as e: |
|
|
logging.warning(f"Could not set CUDA seed: {e}") |
|
|
|
|
|
try: |
|
|
|
|
|
max_disp = model_info.get("max_disp", 256) |
|
|
model = Model(max_disp=max_disp, mixed_precision=False, test_mode=True) |
|
|
|
|
|
|
|
|
ckpt = torch.load(model_path, map_location=device) |
|
|
model.load_state_dict(ckpt, strict=True) |
|
|
model.to(device) |
|
|
model.eval() |
|
|
|
|
|
logging.info("Loaded CREStereo model weights") |
|
|
|
|
|
|
|
|
torch.set_grad_enabled(False) |
|
|
logging.info("Applied memory optimizations") |
|
|
|
|
|
return model, device |
|
|
|
|
|
except Exception as e: |
|
|
logging.error(f"Model loading failed: {e}") |
|
|
raise RuntimeError(f"Failed to load model: {e}") |
|
|
|
|
|
|
|
|
def get_cached_model(model_selection: str): |
|
|
"""Get cached model or load new one if selection changed""" |
|
|
global _cached_model, _cached_device, _cached_model_selection |
|
|
|
|
|
|
|
|
model_path, model_info = get_model_paths_from_selection(model_selection) |
|
|
|
|
|
if model_path is None or model_info is None: |
|
|
raise ValueError(f"Selected model not found: {model_selection}") |
|
|
|
|
|
|
|
|
if (_cached_model is None or |
|
|
_cached_model_selection != model_selection): |
|
|
|
|
|
|
|
|
        if _cached_model is not None:
            del _cached_model
            # empty_cache touches CUDA; guard it since this helper runs outside @spaces.GPU
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()
|
|
|
|
|
logging.info(f"π Loading model: {model_selection}") |
|
|
_cached_model, _cached_device = load_model_for_inference(model_path, model_info) |
|
|
_cached_model_selection = model_selection |
|
|
|
|
|
logging.info(f"β
Model loaded successfully: {model_selection}") |
|
|
else: |
|
|
logging.info(f"β
Using cached model: {model_selection}") |
|
|
|
|
|
return _cached_model, _cached_device |
|
|
|
|
|
|
|
|
def clear_model_cache(): |
|
|
"""Clear the cached model to free memory""" |
|
|
global _cached_model, _cached_device, _cached_model_selection |
|
|
|
|
|
if _cached_model is not None: |
|
|
logging.info("Clearing model cache...") |
|
|
del _cached_model |
|
|
_cached_model = None |
|
|
_cached_device = None |
|
|
_cached_model_selection = None |
|
|
|
|
|
|
|
|
        # gc is already imported at module level; guard CUDA cleanup because
        # this helper can run outside a @spaces.GPU context on ZeroGPU
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
|
|
logging.info("Model cache cleared") |
|
|
else: |
|
|
logging.info("No model in cache to clear") |
|
|
|
|
|
|
|
|
def inference(left, right, model, device, n_iter=20): |
|
|
"""Run CREStereo inference on stereo pair""" |
|
|
print("Model Forwarding...") |
|
|
imgL = left.transpose(2, 0, 1) |
|
|
imgR = right.transpose(2, 0, 1) |
|
|
imgL = np.ascontiguousarray(imgL[None, :, :, :]) |
|
|
imgR = np.ascontiguousarray(imgR[None, :, :, :]) |
|
|
|
|
|
imgL = torch.tensor(imgL.astype("float32")).to(device) |
|
|
imgR = torch.tensor(imgR.astype("float32")).to(device) |
|
|
|
|
|
|
|
|
padder = InputPadder(imgL.shape, divis_by=8) |
|
|
imgL_padded, imgR_padded = padder.pad(imgL, imgR) |
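
    # Cascaded (coarse-to-fine) inference: estimate flow at half resolution
    # first, then feed it as flow_init to a full-resolution refinement pass.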
|
|
|
|
|
|
|
|
imgL_dw2 = F.interpolate( |
|
|
imgL_padded, |
|
|
size=(imgL_padded.shape[2] // 2, imgL_padded.shape[3] // 2), |
|
|
mode="bilinear", |
|
|
align_corners=True, |
|
|
) |
|
|
imgR_dw2 = F.interpolate( |
|
|
imgR_padded, |
|
|
size=(imgL_padded.shape[2] // 2, imgL_padded.shape[3] // 2), |
|
|
mode="bilinear", |
|
|
align_corners=True, |
|
|
) |
|
|
|
|
|
with torch.inference_mode(): |
|
|
pred_flow_dw2 = model(imgL_dw2, imgR_dw2, iters=n_iter, flow_init=None) |
|
|
pred_flow = model(imgL_padded, imgR_padded, iters=n_iter, flow_init=pred_flow_dw2) |
|
|
|
|
|
|
|
|
pred_flow = padder.unpad(pred_flow) |
|
|
pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy() |
|
|
|
|
|
return pred_disp |
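

# Typical call (illustrative): `left` and `right` are HxWx3 uint8 RGB arrays of
# equal size; the result is an HxW float32 disparity map in pixels, e.g.
#   disp = inference(left_rgb, right_rgb, model, device, n_iter=20)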
|
|
|
|
|
|
|
|
def vis_disparity(disparity_map, max_val=None): |
|
|
"""Visualize disparity map""" |
|
|
    if max_val is None:
        # Normalize to the map's own range; guard against division by zero on flat maps
        disp_range = max(float(disparity_map.max() - disparity_map.min()), 1e-6)
        disp_vis = (disparity_map - disparity_map.min()) / disp_range * 255.0
    else:
        disp_vis = np.clip(disparity_map / max_val * 255.0, 0, 255)
|
|
|
|
|
disp_vis = disp_vis.astype("uint8") |
|
|
disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO) |
|
|
disp_vis = cv2.cvtColor(disp_vis, cv2.COLOR_BGR2RGB) |
|
|
return disp_vis |
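

# Example (illustrative): with max_val=None the map is normalized to its own
# range; pass a fixed max_val to keep colors comparable across frames, e.g.
#   vis = vis_disparity(disp)              # per-image normalization
#   vis = vis_disparity(disp, max_val=64)  # fixed scale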
|
|
|
|
|
|
|
|
|
|
|
@spaces.GPU(duration=60) |
|
|
def process_stereo_pair(model_selection: str, left_image: str, right_image: str, |
|
|
progress: gr.Progress = gr.Progress()) -> Tuple[Optional[np.ndarray], str]: |
|
|
""" |
|
|
Main processing function for stereo pair (with model caching) |
|
|
""" |
|
|
logging.info("Starting stereo pair processing...") |
|
|
|
|
|
if left_image is None or right_image is None: |
|
|
return None, "β Please upload both left and right images." |
|
|
|
|
|
|
|
|
logging.info(f"Loading images: left={left_image}, right={right_image}") |
|
|
|
|
|
try: |
|
|
|
|
|
if not os.path.exists(left_image): |
|
|
logging.error(f"Left image file does not exist: {left_image}") |
|
|
return None, f"β Left image file not found: {left_image}" |
|
|
|
|
|
logging.info(f"Loading left image from: {left_image}") |
|
|
left_img = cv2.imread(left_image) |
|
|
if left_img is not None: |
|
|
left_img = cv2.cvtColor(left_img, cv2.COLOR_BGR2RGB) |
|
|
else: |
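            # cv2.imread returns None when it cannot decode the file (e.g. some
            # PNG variants); fall back to imageio and drop any alpha channel.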
|
|
|
|
|
left_img = imageio.imread(left_image) |
|
|
if len(left_img.shape) == 3 and left_img.shape[2] == 4: |
|
|
left_img = left_img[:, :, :3] |
|
|
|
|
|
|
|
|
if not os.path.exists(right_image): |
|
|
logging.error(f"Right image file does not exist: {right_image}") |
|
|
return None, f"β Right image file not found: {right_image}" |
|
|
|
|
|
logging.info(f"Loading right image from: {right_image}") |
|
|
right_img = cv2.imread(right_image) |
|
|
if right_img is not None: |
|
|
right_img = cv2.cvtColor(right_img, cv2.COLOR_BGR2RGB) |
|
|
else: |
|
|
|
|
|
right_img = imageio.imread(right_image) |
|
|
if len(right_img.shape) == 3 and right_img.shape[2] == 4: |
|
|
right_img = right_img[:, :, :3] |
|
|
|
|
|
logging.info(f"Images loaded successfully - Left: {left_img.shape}, Right: {right_img.shape}") |
|
|
|
|
|
except Exception as e: |
|
|
logging.error(f"Failed to load images: {e}") |
|
|
return None, f"β Failed to load images: {str(e)}" |
|
|
|
|
|
try: |
|
|
|
|
|
variant_name = model_selection.split('(')[0].strip() if '(' in model_selection else model_selection |
|
|
progress(0.1, desc=f"Loading cached model ({variant_name})...") |
|
|
logging.info("π Getting cached model...") |
|
|
model, device = get_cached_model(model_selection) |
|
|
logging.info("β
Cached model loaded successfully") |
|
|
|
|
|
progress(0.2, desc="Preprocessing images...") |
|
|
|
|
|
|
|
|
if left_img.shape != right_img.shape: |
|
|
return None, "β Left and right images must have the same dimensions." |
|
|
|
|
|
H, W = left_img.shape[:2] |
|
|
|
|
|
progress(0.5, desc="Running inference...") |
|
|
|
|
|
|
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
disp_cpu = inference(left_img, right_img, model, device, n_iter=20) |
|
|
|
|
|
progress(0.8, desc="Creating visualization...") |
|
|
|
|
|
|
|
|
disparity_vis = vis_disparity(disp_cpu) |
|
|
result_image = disparity_vis |
|
|
|
|
|
progress(1.0, desc="Complete!") |
|
|
|
|
|
|
|
|
valid_mask = ~np.isinf(disp_cpu) |
|
|
min_disp = disp_cpu[valid_mask].min() if valid_mask.any() else 0 |
|
|
max_disp = disp_cpu[valid_mask].max() if valid_mask.any() else 0 |
|
|
mean_disp = disp_cpu[valid_mask].mean() if valid_mask.any() else 0 |
|
|
|
|
|
|
|
|
variant = variant_name |
|
|
|
|
|
|
|
|
try: |
|
|
current_memory = torch.cuda.memory_allocated(0) / 1024**3 |
|
|
max_memory = torch.cuda.max_memory_allocated(0) / 1024**3 |
|
|
memory_info = f" | GPU: {current_memory:.2f}GB/{max_memory:.2f}GB peak" |
|
|
        except Exception:
|
|
memory_info = "" |
|
|
|
|
|
status = f"""β
Processing successful! |
|
|
π§ Model: {variant}{memory_info} |
|
|
π Disparity Statistics: |
|
|
β’ Range: {min_disp:.2f} - {max_disp:.2f} |
|
|
β’ Mean: {mean_disp:.2f} |
|
|
β’ Input size: {W}Γ{H} |
|
|
β’ Valid pixels: {valid_mask.sum()}/{valid_mask.size}""" |
|
|
|
|
|
return result_image, status |
|
|
|
|
|
except Exception as e: |
|
|
logging.error(f"Processing failed: {e}") |
|
|
|
|
|
torch.cuda.empty_cache() |
|
|
gc.collect() |
|
|
return None, f"β Error: {str(e)}" |
|
|
|
|
|
|
|
|
|
|
|
@spaces.GPU(duration=120) |
|
|
def process_with_depth(model_selection: str, left_image: str, right_image: str, |
|
|
camera_matrix: str, baseline: float, |
|
|
progress: gr.Progress = gr.Progress()) -> Tuple[Optional[np.ndarray], Optional[str], Optional[str], str]: |
|
|
""" |
|
|
Process stereo pair and generate depth map and point cloud (with model caching) |
|
|
""" |
|
|
|
|
|
global OPEN3D_AVAILABLE |
|
|
try: |
|
|
import open3d as o3d |
|
|
OPEN3D_AVAILABLE = True |
|
|
except ImportError as e: |
|
|
logging.warning(f"Open3D not available: {e}") |
|
|
OPEN3D_AVAILABLE = False |
|
|
return None, None, None, "β Open3D not available. Point cloud generation disabled." |
|
|
|
|
|
if left_image is None or right_image is None: |
|
|
return None, None, None, "β Please upload both left and right images." |
|
|
|
|
|
try: |
|
|
progress(0.1, desc="Parsing camera parameters...") |
|
|
|
|
|
|
|
|
try: |
|
|
K_values = list(map(float, camera_matrix.strip().split())) |
|
|
if len(K_values) != 9: |
|
|
return None, None, None, "β Camera matrix must contain exactly 9 values." |
|
|
K = np.array(K_values).reshape(3, 3) |
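            # Row-major 3x3 intrinsics, e.g. "700 0 320 0 700 240 0 0 1" gives
            #   [[700,   0, 320],
            #    [  0, 700, 240],
            #    [  0,   0,   1]]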
|
|
except ValueError: |
|
|
return None, None, None, "β Invalid camera matrix format. Use space-separated numbers." |
|
|
|
|
|
if baseline <= 0: |
|
|
return None, None, None, "β Baseline must be positive." |
|
|
|
|
|
|
|
|
disparity_result, status = process_stereo_pair(model_selection, left_image, right_image, progress) |
|
|
|
|
|
if disparity_result is None: |
|
|
return None, None, None, status |
|
|
|
|
|
|
|
|
left_img = cv2.imread(left_image) |
|
|
left_img = cv2.cvtColor(left_img, cv2.COLOR_BGR2RGB) |
|
|
|
|
|
|
|
|
model, device = get_cached_model(model_selection) |
|
|
disp_cpu = inference(left_img, cv2.cvtColor(cv2.imread(right_image), cv2.COLOR_BGR2RGB), model, device, n_iter=20) |
|
|
|
|
|
progress(0.6, desc="Converting to depth...") |
|
|
|
|
|
|
|
|
H, W = disp_cpu.shape |
|
|
yy, xx = np.meshgrid(np.arange(H), np.arange(W), indexing='ij') |
|
|
us_right = xx - disp_cpu |
|
|
invalid = us_right < 0 |
|
|
disp_cpu[invalid] = np.inf |
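
        # A disparity d at column u matches column u - d in the right image;
        # matches falling outside the frame (u - d < 0) are impossible, hence
        # marked inf above so they drop out of the depth conversion.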
|
|
|
|
|
|
|
|
depth = K[0, 0] * baseline / disp_cpu |
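
        # Rectified-stereo depth: Z = fx * B / d, with fx in pixels, baseline B
        # in meters, and disparity d in pixels. Pixels marked inf above map to
        # Z = 0 and are dropped by the z > 0 mask further down.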
|
|
|
|
|
|
|
|
depth_vis = vis_disparity(depth, max_val=10.0) |
|
|
|
|
|
progress(0.8, desc="Generating point cloud...") |
|
|
|
|
|
|
|
|
fx, fy = K[0, 0], K[1, 1] |
|
|
cx, cy = K[0, 2], K[1, 2] |
|
|
|
|
|
|
|
|
u, v = np.meshgrid(np.arange(W), np.arange(H)) |
|
|
|
|
|
|
|
|
valid_depth = ~np.isinf(depth) |
|
|
z = depth[valid_depth] |
|
|
x = (u[valid_depth] - cx) * z / fx |
|
|
y = (v[valid_depth] - cy) * z / fy |
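
        # Inverse pinhole projection: pixel (u, v) at depth Z back-projects to
        #   X = (u - cx) * Z / fx,  Y = (v - cy) * Z / fy
        # giving metric coordinates in the left camera frame.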
|
|
|
|
|
|
|
|
points = np.stack([x, y, z], axis=-1) |
|
|
|
|
|
|
|
|
colors = left_img[valid_depth] |
|
|
|
|
|
|
|
|
depth_mask = (z > 0) & (z <= 10.0) |
|
|
valid_points = points[depth_mask] |
|
|
valid_colors = colors[depth_mask] |
|
|
|
|
|
if len(valid_points) == 0: |
|
|
return depth_vis, None, None, "β οΈ No valid points generated for point cloud." |
|
|
|
|
|
|
|
|
if len(valid_points) > 100000: |
|
|
indices = np.random.choice(len(valid_points), 100000, replace=False) |
|
|
valid_points = valid_points[indices] |
|
|
valid_colors = valid_colors[indices] |
|
|
|
|
|
|
|
|
transformed_points = valid_points.copy() |
|
|
transformed_points[:, 1] = -transformed_points[:, 1] |
|
|
transformed_points[:, 2] = -transformed_points[:, 2] |
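
        # Negating Y and Z converts the camera frame (x right, y down,
        # z forward) to a y-up orientation that displays upright in the viewer.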
|
|
|
|
|
|
|
|
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(transformed_points)
        pcd.colors = o3d.utility.Vector3dVector(valid_colors / 255.0)

        # Persist the cloud so it can be downloaded and shown in the 3D viewer
        # (recent Gradio releases accept .ply paths in gr.Model3D)
        ply_path = os.path.join(tempfile.mkdtemp(), "pointcloud.ply")
        o3d.io.write_point_cloud(ply_path, pcd)

        progress(1.0, desc="Complete!")
|
|
|
|
|
|
|
|
try: |
|
|
current_memory = torch.cuda.memory_allocated(0) / 1024**3 |
|
|
max_memory = torch.cuda.max_memory_allocated(0) / 1024**3 |
|
|
memory_info = f" | GPU: {current_memory:.2f}GB/{max_memory:.2f}GB peak" |
|
|
        except Exception:
|
|
memory_info = "" |
|
|
|
|
|
variant = model_selection.split('(')[0].strip() if '(' in model_selection else model_selection |
|
|
|
|
|
status = f"""β
Depth processing successful! |
|
|
π§ Model: {variant}{memory_info} |
|
|
π Statistics: |
|
|
β’ Valid points: {len(valid_points):,} |
|
|
β’ Depth range: {z.min():.2f} - {z.max():.2f} m |
|
|
β’ Baseline: {baseline} m |
|
|
β’ Point cloud generated with {len(valid_points)} points |
|
|
β’ 3D visualization available""" |
|
|
|
|
|
        return depth_vis, ply_path, ply_path, status
|
|
|
|
|
except Exception as e: |
|
|
logging.error(f"Depth processing failed: {e}") |
|
|
torch.cuda.empty_cache() |
|
|
gc.collect() |
|
|
return None, None, None, f"β Error: {str(e)}" |
|
|
|
|
|
|
|
|
def create_app() -> gr.Blocks: |
|
|
"""Create the Gradio application""" |
|
|
|
|
|
|
|
|
try: |
|
|
available_models = get_available_models() |
|
|
logging.info(f"Successfully got available models: {len(available_models)} found") |
|
|
except Exception as e: |
|
|
logging.error(f"Failed to get available models: {e}") |
|
|
available_models = {} |
|
|
|
|
|
with gr.Blocks( |
|
|
title="CREStereo - Stereo Depth Estimation", |
|
|
theme=gr.themes.Soft(), |
|
|
css="footer {visibility: hidden}", |
|
|
delete_cache=(60, 60) |
|
|
) as app: |
|
|
|
|
|
gr.Markdown(""" |
|
|
        # 🚀 CREStereo: Practical Stereo Matching
|
|
|
|
|
Upload a pair of **rectified** stereo images to get disparity estimation using CREStereo. |
|
|
|
|
|
        ⚠️ **Important**: Images should be rectified (epipolar lines are horizontal) and undistorted.
|
|
        ⚡ **GPU Powered**: Runs on CUDA-enabled GPUs for fast inference.
|
|
""") |
|
|
|
|
|
|
|
|
with gr.Accordion("π Instructions", open=False): |
|
|
gr.Markdown(""" |
|
|
            ## 📋 How to Use This Demo
|
|
|
|
|
            ### 🖼️ Input Requirements
|
|
1. **Image Format**: Upload images in JPEG or PNG format. |
|
|
2. **Image Size**: Images should be of the same size and resolution. |
|
|
3. **Rectification**: Ensure images are rectified (epipolar lines are horizontal) and undistorted. |
|
|
4. **Camera Parameters**: For depth processing, provide camera matrix and baseline distance. |
|
|
|
|
|
            ### 🚀 Using the Demo
|
|
1. **Select Model**: Choose the CREStereo model variant |
|
|
2. **Upload Images**: Provide rectified stereo image pairs |
|
|
3. **Basic Processing**: Get disparity visualization |
|
|
4. **Advanced Processing**: Generate depth maps and 3D point clouds (requires camera parameters) |
|
|
|
|
|
            ### 📄 Original Work
|
|
This demo is based on CREStereo: Practical Stereo Matching via Cascaded Recurrent Network. |
|
|
- **Paper**: [CREStereo: Practical Stereo Matching via Cascaded Recurrent Network](https://arxiv.org/abs/2203.11483) |
|
|
- **Official Repository**: [https://github.com/megvii-research/CREStereo](https://github.com/megvii-research/CREStereo) |
|
|
""") |
|
|
|
|
|
|
|
|
with gr.Row(): |
|
|
all_choices = list(available_models.keys()) |
|
|
|
|
|
if not all_choices: |
|
|
all_choices = ["No models found - Please ensure crestereo_eth3d.pth is in models/ directory"] |
|
|
|
|
|
default_model = all_choices[0] if all_choices else None |
|
|
|
|
|
model_selector = gr.Dropdown( |
|
|
choices=all_choices, |
|
|
value=default_model, |
|
|
label="π― Select Model", |
|
|
info="Choose the CREStereo model variant.", |
|
|
interactive=True |
|
|
) |
|
|
|
|
|
with gr.Tabs(): |
|
|
|
|
|
with gr.TabItem("πΌοΈ Basic Stereo Processing"): |
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
left_input = gr.Image( |
|
|
label="π· Left Image", |
|
|
type="filepath", |
|
|
height=300 |
|
|
) |
|
|
right_input = gr.Image( |
|
|
label="π· Right Image", |
|
|
type="filepath", |
|
|
height=300 |
|
|
) |
|
|
|
|
|
process_btn = gr.Button( |
|
|
"π Process Stereo Pair", |
|
|
variant="primary", |
|
|
size="lg" |
|
|
) |
|
|
|
|
|
with gr.Column(): |
|
|
output_image = gr.Image( |
|
|
label="π Disparity Visualization", |
|
|
height=400 |
|
|
) |
|
|
status_text = gr.Textbox( |
|
|
label="Status", |
|
|
interactive=False, |
|
|
lines=8 |
|
|
) |
|
|
|
|
|
|
|
|
examples_list = [] |
|
|
|
|
|
|
|
|
if os.path.exists(os.path.join(current_dir, "assets", "example1", "left.png")): |
|
|
examples_list.append([ |
|
|
os.path.join(current_dir, "assets", "example1", "left.png"), |
|
|
os.path.join(current_dir, "assets", "example1", "right.png") |
|
|
]) |
|
|
|
|
|
|
|
|
if os.path.exists(os.path.join(current_dir, "assets", "example2", "left.png")): |
|
|
examples_list.append([ |
|
|
os.path.join(current_dir, "assets", "example2", "left.png"), |
|
|
os.path.join(current_dir, "assets", "example2", "right.png") |
|
|
]) |
|
|
|
|
|
if examples_list: |
|
|
gr.Examples( |
|
|
examples=examples_list, |
|
|
inputs=[left_input, right_input], |
|
|
label="π Example Images" |
|
|
) |
|
|
|
|
|
|
|
|
with gr.TabItem("π Advanced Processing (Depth & Point Cloud)"): |
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
left_input_adv = gr.Image( |
|
|
label="π· Left Image", |
|
|
type="filepath", |
|
|
height=250 |
|
|
) |
|
|
right_input_adv = gr.Image( |
|
|
label="π· Right Image", |
|
|
type="filepath", |
|
|
height=250 |
|
|
) |
|
|
|
|
|
|
|
|
with gr.Group(): |
|
|
gr.Markdown("### πΉ Camera Parameters") |
|
|
camera_matrix_input = gr.Textbox( |
|
|
label="Camera Matrix (9 values: fx 0 cx 0 fy cy 0 0 1)", |
|
|
value="", |
|
|
) |
|
|
baseline_input = gr.Number( |
|
|
label="Baseline (meters)", |
|
|
value=None, |
|
|
minimum=0.001, |
|
|
maximum=10.0, |
|
|
step=0.001 |
|
|
) |
|
|
|
|
|
process_depth_btn = gr.Button( |
|
|
"π¬ Process with Depth", |
|
|
variant="primary", |
|
|
size="lg" |
|
|
) |
|
|
|
|
|
with gr.Column(): |
|
|
depth_output = gr.Image( |
|
|
label="π Depth Visualization", |
|
|
height=300 |
|
|
) |
|
|
pointcloud_output = gr.File( |
|
|
label="βοΈ Point Cloud Download (.ply)", |
|
|
file_types=[".ply"] |
|
|
) |
|
|
status_depth = gr.Textbox( |
|
|
label="Status", |
|
|
interactive=False, |
|
|
lines=6 |
|
|
) |
|
|
|
|
|
|
|
|
with gr.Row(): |
|
|
pointcloud_3d = gr.Model3D( |
|
|
label="π 3D Point Cloud Viewer", |
|
|
clear_color=[0.0, 0.0, 0.0, 0.0], |
|
|
height=400 |
|
|
) |
|
|
|
|
|
|
|
|
examples_advanced_list = [] |
|
|
|
|
|
|
|
|
|
|
|
if os.path.exists(os.path.join(current_dir, "assets", "example1", "left.png")): |
|
|
k_file = os.path.join(current_dir, "assets", "example1", "K.txt") |
|
|
camera_matrix_str = "" |
|
|
baseline_val = 0.063 |
|
|
|
|
|
if os.path.exists(k_file): |
|
|
try: |
|
|
with open(k_file, 'r') as f: |
|
|
lines = f.readlines() |
|
|
if len(lines) >= 1: |
|
|
camera_matrix_str = lines[0].strip() |
|
|
if len(lines) >= 2: |
|
|
baseline_val = float(lines[1].strip()) |
|
|
                        except Exception:
|
|
camera_matrix_str = "754.6680908203125 0.0 489.3794860839844 0.0 754.6680908203125 265.16162109375 0.0 0.0 1.0" |
|
|
|
|
|
examples_advanced_list.append([ |
|
|
os.path.join(current_dir, "assets", "example1", "left.png"), |
|
|
os.path.join(current_dir, "assets", "example1", "right.png"), |
|
|
camera_matrix_str, |
|
|
baseline_val |
|
|
]) |
|
|
|
|
|
|
|
|
if os.path.exists(os.path.join(current_dir, "assets", "example2", "left.png")): |
|
|
k_file = os.path.join(current_dir, "assets", "example2", "K.txt") |
|
|
camera_matrix_str = "" |
|
|
baseline_val = 0.537 |
|
|
|
|
|
if os.path.exists(k_file): |
|
|
try: |
|
|
with open(k_file, 'r') as f: |
|
|
lines = f.readlines() |
|
|
if len(lines) >= 1: |
|
|
camera_matrix_str = lines[0].strip() |
|
|
if len(lines) >= 2: |
|
|
baseline_val = float(lines[1].strip()) |
|
|
                        except Exception:
|
|
camera_matrix_str = "1733.74 0.0 792.27 0.0 1733.74 541.89 0.0 0.0 1.0" |
|
|
|
|
|
examples_advanced_list.append([ |
|
|
os.path.join(current_dir, "assets", "example2", "left.png"), |
|
|
os.path.join(current_dir, "assets", "example2", "right.png"), |
|
|
camera_matrix_str, |
|
|
baseline_val |
|
|
]) |
|
|
|
|
|
if examples_advanced_list: |
|
|
gr.Examples( |
|
|
examples=examples_advanced_list, |
|
|
inputs=[left_input_adv, right_input_adv, camera_matrix_input, baseline_input], |
|
|
label="π Example Images with Camera Parameters" |
|
|
) |
|
|
|
|
|
|
|
|
if available_models: |
|
|
process_btn.click( |
|
|
fn=process_stereo_pair, |
|
|
inputs=[model_selector, left_input, right_input], |
|
|
outputs=[output_image, status_text], |
|
|
show_progress=True |
|
|
) |
|
|
|
|
|
if OPEN3D_AVAILABLE: |
|
|
process_depth_btn.click( |
|
|
fn=process_with_depth, |
|
|
inputs=[model_selector, left_input_adv, right_input_adv, camera_matrix_input, baseline_input], |
|
|
outputs=[depth_output, pointcloud_output, pointcloud_3d, status_depth], |
|
|
show_progress=True |
|
|
) |
|
|
else: |
|
|
process_depth_btn.click( |
|
|
fn=lambda *args: (None, None, None, "β Open3D not available. Install with: pip install open3d"), |
|
|
inputs=[model_selector, left_input_adv, right_input_adv, camera_matrix_input, baseline_input], |
|
|
outputs=[depth_output, pointcloud_output, pointcloud_3d, status_depth] |
|
|
) |
|
|
else: |
|
|
|
|
|
process_btn.click( |
|
|
fn=lambda *args: (None, "β No models available. Please ensure crestereo_eth3d.pth is in models/ directory."), |
|
|
inputs=[model_selector, left_input, right_input], |
|
|
outputs=[output_image, status_text] |
|
|
) |
|
|
|
|
|
process_depth_btn.click( |
|
|
fn=lambda *args: (None, None, None, "β No models available. Please ensure crestereo_eth3d.pth is in models/ directory."), |
|
|
inputs=[model_selector, left_input_adv, right_input_adv, camera_matrix_input, baseline_input], |
|
|
outputs=[depth_output, pointcloud_output, pointcloud_3d, status_depth] |
|
|
) |
|
|
|
|
|
|
|
|
with gr.Accordion("π Citation", open=False): |
|
|
gr.Markdown(""" |
|
|
            ### 📄 Please Cite the Original Paper
|
|
|
|
|
If you use this work in your research, please cite: |
|
|
|
|
|
```bibtex |
|
|
            @inproceedings{li2022practical,
|
|
title={Practical Stereo Matching via Cascaded Recurrent Network with Adaptive Correlation}, |
|
|
author={Li, Jiankun and Wang, Peisen and Xiong, Pengfei and Cai, Tao and Yan, Ziwei and Yang, Lei and Liu, Jiangyu and Fan, Haoqiang and Liu, Shuaicheng}, |
|
|
              booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
|
|
pages={16263--16272}, |
|
|
year={2022} |
|
|
} |
|
|
``` |
|
|
""") |
|
|
|
|
|
|
|
|
gr.Markdown(""" |
|
|
--- |
|
|
        ### 📝 Notes:
|
|
- **Input images must be rectified stereo pairs** (epipolar lines are horizontal) |
|
|
        - **⚡ GPU Acceleration**: Requires CUDA-compatible GPU
|
|
        - **📦 Model Caching**: Models are cached for efficient repeated usage
|
|
- For best results, use high-quality rectified stereo pairs |
|
|
- Model works on RGB images and supports various resolutions |
|
|
|
|
|
        ### 🔗 References:
|
|
- [CREStereo Paper](https://arxiv.org/abs/2203.11483) |
|
|
- [Original GitHub Repository](https://github.com/megvii-research/CREStereo) |
|
|
- [This PyTorch Implementation](https://github.com/ibaiGorordo/CREStereo-Pytorch) |
|
|
""") |
|
|
|
|
|
return app |
|
|
|
|
|
|
|
|
def main(): |
|
|
"""Main function to launch the app""" |
|
|
|
|
|
|
|
|
if torch.cuda.is_available(): |
|
|
logging.warning("CUDA detected during startup - this should not happen in ZeroGPU") |
|
|
|
|
|
logging.info("π Starting CREStereo Gradio App...") |
|
|
|
|
|
|
|
|
import argparse |
|
|
parser = argparse.ArgumentParser(description="CREStereo Gradio App") |
|
|
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to") |
|
|
parser.add_argument("--port", type=int, default=7860, help="Port to bind to") |
|
|
parser.add_argument("--share", action="store_true", help="Create shareable link") |
|
|
parser.add_argument("--debug", action="store_true", help="Enable debug mode") |
|
|
|
|
|
args = parser.parse_args() |
|
|
|
|
|
if args.debug: |
|
|
logging.getLogger().setLevel(logging.DEBUG) |
|
|
|
|
|
try: |
|
|
|
|
|
logging.info("Creating Gradio app...") |
|
|
app = create_app() |
|
|
logging.info("β
Gradio app created successfully") |
|
|
|
|
|
logging.info(f"Launching app on {args.host}:{args.port}") |
|
|
if args.share: |
|
|
logging.info("Share link will be created") |
|
|
|
|
|
|
|
|
app.launch( |
|
|
server_name=args.host, |
|
|
server_port=args.port, |
|
|
share=args.share, |
|
|
show_error=True, |
|
|
favicon_path=None, |
|
|
ssr_mode=False, |
|
|
allowed_paths=["./"] |
|
|
) |
|
|
except Exception as e: |
|
|
logging.error(f"Failed to launch app: {e}") |
|
|
raise |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
if 'SPACE_ID' in os.environ: |
|
|
logging.info("Running in Hugging Face Spaces environment") |
|
|
|
|
|
|
|
|
|
|
|
logging.info("β
CUDA status will be checked within GPU-decorated functions") |
|
|
|
|
|
main() |
|
|
|