MogensR commited on
Commit
dbe3751
·
1 Parent(s): 967f336

Create utils/mask_bridge.py

Browse files
Files changed (1) hide show
  1. utils/mask_bridge.py +24 -42
utils/mask_bridge.py CHANGED
@@ -1,60 +1,42 @@
1
- #!/usr/bin/env python3
2
- """
3
- mask_bridge.py - SAM2 to MatAnyone mask conversion
4
- Handles shape/dtype/device normalization between models
5
- """
6
-
7
  import torch
8
- import math
9
- from typing import Optional, Tuple
10
 
11
  def log_shape(tag: str, t: torch.Tensor) -> None:
12
- """Debug logging for tensor shapes and values"""
13
- mn = float(t.min()) if t.numel() else math.nan
14
- mx = float(t.max()) if t.numel() else math.nan
15
- print(f"{tag}: shape={tuple(t.shape)} dtype={t.dtype} device={t.device} range=[{mn:.3f},{mx:.3f}]")
 
 
 
16
 
17
  def sam2_to_matanyone_mask(
18
- sam2_masks: torch.Tensor, # shape: (B, M, H, W) from SAM2 post_process_masks
19
- iou_scores: Optional[torch.Tensor] = None, # optional, (B, M)
20
- threshold: float = 0.5, # binarization for hard mask
21
- return_mode: str = "single", # "single" (1,H,W); "multi" (C,H,W)
22
- keep_soft: bool = False, # if True → soft [0,1] alpha channel
23
  ) -> torch.Tensor:
24
- """
25
- Convert SAM2 output masks to MatAnyone-ready format.
26
-
27
- Returns a MatAnyone-ready tensor on the same device:
28
- - "single": (1,H,W) float32 in [0,1]
29
- - "multi": (C,H,W) float32 in [0,1]
30
- """
31
  assert sam2_masks.ndim == 4, f"Expect (B,M,H,W). Got {tuple(sam2_masks.shape)}"
32
  B, M, H, W = sam2_masks.shape
33
- assert B == 1, "We pass one frame to build first-frame mask."
34
-
35
- masks = sam2_masks[0] # (M,H,W)
36
 
37
- # Choose best mask
38
- if iou_scores is not None and iou_scores.ndim == 2:
39
  best_idx = int(torch.argmax(iou_scores[0]).item())
40
  else:
41
- # Fallback: pick the mask with largest foreground area
42
- areas = masks.sum(dim=(1,2))
43
  best_idx = int(torch.argmax(areas).item())
44
 
45
  if return_mode == "multi":
46
- out = masks
47
  else:
48
- out = masks[best_idx:best_idx+1] # (1,H,W)
49
 
50
- # Ensure float32 [0,1]
51
- out = out.to(dtype=torch.float32)
52
  if not keep_soft:
53
  out = (out >= threshold).float()
54
-
55
- # Final sanity: contiguous, shapes
56
- out = out.contiguous()
57
- assert out.ndim == 3, f"Expect (C,H,W); got {tuple(out.shape)}"
58
- assert out.shape[0] >= 1, f"Need at least 1 channel; got {out.shape[0]}"
59
-
60
- return out # (C,H,W) float32 [0,1]
 
1
+ # utils/mask_bridge.py
2
+ from __future__ import annotations
 
 
 
 
3
  import torch
 
 
4
 
5
  def log_shape(tag: str, t: torch.Tensor) -> None:
6
+ try:
7
+ mn = float(t.min()) if t.numel() else float("nan")
8
+ mx = float(t.max()) if t.numel() else float("nan")
9
+ print(f"[mask_bridge] {tag}: shape={tuple(t.shape)} dtype={t.dtype} device={t.device} "
10
+ f"range=[{mn:.4f},{mx:.4f}]")
11
+ except Exception:
12
+ pass
13
 
14
  def sam2_to_matanyone_mask(
15
+ sam2_masks: torch.Tensor, # (B,M,H,W) after post_process
16
+ iou_scores: torch.Tensor | None, # (B,M) or None
17
+ threshold: float = 0.5,
18
+ return_mode: str = "single", # "single"→(1,H,W) or "multi"→(C,H,W)
19
+ keep_soft: bool = False,
20
  ) -> torch.Tensor:
 
 
 
 
 
 
 
21
  assert sam2_masks.ndim == 4, f"Expect (B,M,H,W). Got {tuple(sam2_masks.shape)}"
22
  B, M, H, W = sam2_masks.shape
23
+ assert B == 1, "Bridge expects B=1 for first-frame bootstrapping"
 
 
24
 
25
+ candidates = sam2_masks[0] # (M,H,W)
26
+ if iou_scores is not None and iou_scores.ndim == 2 and iou_scores.shape[0] == 1:
27
  best_idx = int(torch.argmax(iou_scores[0]).item())
28
  else:
29
+ areas = candidates.sum(dim=(-2,-1))
 
30
  best_idx = int(torch.argmax(areas).item())
31
 
32
  if return_mode == "multi":
33
+ out = candidates # (M,H,W) treat as (C,H,W)
34
  else:
35
+ out = candidates[best_idx:best_idx+1] # (1,H,W)
36
 
37
+ out = out.to(torch.float32)
 
38
  if not keep_soft:
39
  out = (out >= threshold).float()
40
+ out = out.clamp_(0.0, 1.0).contiguous()
41
+ log_shape("sam2→mat.mask", out)
42
+ return out