Create utils/mask_bridge.py
Browse files- utils/mask_bridge.py +24 -42
utils/mask_bridge.py
CHANGED
@@ -1,60 +1,42 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
mask_bridge.py - SAM2 to MatAnyone mask conversion
|
4 |
-
Handles shape/dtype/device normalization between models
|
5 |
-
"""
|
6 |
-
|
7 |
import torch
|
8 |
-
import math
|
9 |
-
from typing import Optional, Tuple
|
10 |
|
11 |
def log_shape(tag: str, t: torch.Tensor) -> None:
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
|
|
|
|
|
|
16 |
|
17 |
def sam2_to_matanyone_mask(
|
18 |
-
sam2_masks: torch.Tensor, #
|
19 |
-
iou_scores:
|
20 |
-
threshold: float = 0.5,
|
21 |
-
return_mode: str = "single", # "single"
|
22 |
-
keep_soft: bool = False,
|
23 |
) -> torch.Tensor:
|
24 |
-
"""
|
25 |
-
Convert SAM2 output masks to MatAnyone-ready format.
|
26 |
-
|
27 |
-
Returns a MatAnyone-ready tensor on the same device:
|
28 |
-
- "single": (1,H,W) float32 in [0,1]
|
29 |
-
- "multi": (C,H,W) float32 in [0,1]
|
30 |
-
"""
|
31 |
assert sam2_masks.ndim == 4, f"Expect (B,M,H,W). Got {tuple(sam2_masks.shape)}"
|
32 |
B, M, H, W = sam2_masks.shape
|
33 |
-
assert B == 1, "
|
34 |
-
|
35 |
-
masks = sam2_masks[0] # (M,H,W)
|
36 |
|
37 |
-
|
38 |
-
if iou_scores is not None and iou_scores.ndim == 2:
|
39 |
best_idx = int(torch.argmax(iou_scores[0]).item())
|
40 |
else:
|
41 |
-
|
42 |
-
areas = masks.sum(dim=(1,2))
|
43 |
best_idx = int(torch.argmax(areas).item())
|
44 |
|
45 |
if return_mode == "multi":
|
46 |
-
out =
|
47 |
else:
|
48 |
-
out =
|
49 |
|
50 |
-
|
51 |
-
out = out.to(dtype=torch.float32)
|
52 |
if not keep_soft:
|
53 |
out = (out >= threshold).float()
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
assert out.ndim == 3, f"Expect (C,H,W); got {tuple(out.shape)}"
|
58 |
-
assert out.shape[0] >= 1, f"Need at least 1 channel; got {out.shape[0]}"
|
59 |
-
|
60 |
-
return out # (C,H,W) float32 [0,1]
|
|
|
1 |
+
# utils/mask_bridge.py
|
2 |
+
from __future__ import annotations
|
|
|
|
|
|
|
|
|
3 |
import torch
|
|
|
|
|
4 |
|
5 |
def log_shape(tag: str, t: torch.Tensor) -> None:
|
6 |
+
try:
|
7 |
+
mn = float(t.min()) if t.numel() else float("nan")
|
8 |
+
mx = float(t.max()) if t.numel() else float("nan")
|
9 |
+
print(f"[mask_bridge] {tag}: shape={tuple(t.shape)} dtype={t.dtype} device={t.device} "
|
10 |
+
f"range=[{mn:.4f},{mx:.4f}]")
|
11 |
+
except Exception:
|
12 |
+
pass
|
13 |
|
14 |
def sam2_to_matanyone_mask(
|
15 |
+
sam2_masks: torch.Tensor, # (B,M,H,W) after post_process
|
16 |
+
iou_scores: torch.Tensor | None, # (B,M) or None
|
17 |
+
threshold: float = 0.5,
|
18 |
+
return_mode: str = "single", # "single"→(1,H,W) or "multi"→(C,H,W)
|
19 |
+
keep_soft: bool = False,
|
20 |
) -> torch.Tensor:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
assert sam2_masks.ndim == 4, f"Expect (B,M,H,W). Got {tuple(sam2_masks.shape)}"
|
22 |
B, M, H, W = sam2_masks.shape
|
23 |
+
assert B == 1, "Bridge expects B=1 for first-frame bootstrapping"
|
|
|
|
|
24 |
|
25 |
+
candidates = sam2_masks[0] # (M,H,W)
|
26 |
+
if iou_scores is not None and iou_scores.ndim == 2 and iou_scores.shape[0] == 1:
|
27 |
best_idx = int(torch.argmax(iou_scores[0]).item())
|
28 |
else:
|
29 |
+
areas = candidates.sum(dim=(-2,-1))
|
|
|
30 |
best_idx = int(torch.argmax(areas).item())
|
31 |
|
32 |
if return_mode == "multi":
|
33 |
+
out = candidates # (M,H,W) treat as (C,H,W)
|
34 |
else:
|
35 |
+
out = candidates[best_idx:best_idx+1] # (1,H,W)
|
36 |
|
37 |
+
out = out.to(torch.float32)
|
|
|
38 |
if not keep_soft:
|
39 |
out = (out >= threshold).float()
|
40 |
+
out = out.clamp_(0.0, 1.0).contiguous()
|
41 |
+
log_shape("sam2→mat.mask", out)
|
42 |
+
return out
|
|
|
|
|
|
|
|