MogensR commited on
Commit
967f336
·
1 Parent(s): e94d263

Create utils/interop.py

Browse files
Files changed (1) hide show
  1. utils/interop.py +99 -0
utils/interop.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # utils/interop.py
2
+ from __future__ import annotations
3
+ import torch
4
+
5
+ def log_shape(tag: str, t: torch.Tensor) -> None:
6
+ try:
7
+ mn = float(t.min()) if t.numel() else float("nan")
8
+ mx = float(t.max()) if t.numel() else float("nan")
9
+ print(f"[interop] {tag}: shape={tuple(t.shape)} dtype={t.dtype} device={t.device} "
10
+ f"range=[{mn:.4f},{mx:.4f}]")
11
+ except Exception as e:
12
+ print(f"[interop] {tag}: <log failed: {e!r}>")
13
+
14
+ def _to_float01(x: torch.Tensor) -> torch.Tensor:
15
+ x = x.to(torch.float32)
16
+ if x.max() > 1.0:
17
+ x = x / 255.0
18
+ return x.clamp_(0.0, 1.0)
19
+
20
+ def _squeeze_bt(x: torch.Tensor) -> torch.Tensor:
21
+ # Drop singleton Time and extra Batch: (B,T,C,H,W) → (B,C,H,W) or (C,H,W)
22
+ if x.ndim == 5:
23
+ if x.shape[1] == 1:
24
+ x = x.squeeze(1) # drop T
25
+ if x.ndim == 5 and x.shape[0] == 1:
26
+ x = x.squeeze(0) # drop B
27
+ # Edge case: (1,1,3,H,W)
28
+ if x.ndim == 4 and x.shape[0] == 1 and x.shape[1] == 1 and x.shape[-3] == 3:
29
+ x = x.squeeze(1) # → (1,3,H,W)
30
+ return x
31
+
32
+ def ensure_image_nchw(
33
+ img: torch.Tensor,
34
+ device: torch.device | str = "cuda",
35
+ want_batched: bool = True,
36
+ ) -> torch.Tensor:
37
+ img = img.to(device)
38
+ img = _squeeze_bt(img)
39
+ if img.ndim == 3:
40
+ # CHW or HWC
41
+ if img.shape[0] in (1,3):
42
+ chw = img
43
+ else:
44
+ chw = img.permute(2,0,1) # HWC→CHW
45
+ chw = _to_float01(chw.contiguous())
46
+ return chw.unsqueeze(0) if want_batched else chw
47
+ if img.ndim == 4:
48
+ N,A,B,C = img.shape
49
+ if A == 3:
50
+ nchw = img
51
+ elif C == 3:
52
+ nchw = img.permute(0,3,1,2) # NHWC→NCHW
53
+ else:
54
+ raise AssertionError(f"Cannot infer channels in image: {tuple(img.shape)}")
55
+ return _to_float01(nchw.contiguous())
56
+ raise AssertionError(f"Image must be 3D/4D; got {tuple(img.shape)}")
57
+
58
+ def ensure_mask_for_matanyone(
59
+ mask: torch.Tensor,
60
+ *,
61
+ idx_mask: bool = False,
62
+ threshold: float = 0.5,
63
+ keep_soft: bool = False,
64
+ device: torch.device | str = "cuda",
65
+ ) -> torch.Tensor:
66
+ mask = mask.to(device)
67
+ mask = _squeeze_bt(mask)
68
+
69
+ if idx_mask:
70
+ # Return (H,W) labels {0,1}
71
+ if mask.ndim == 3:
72
+ if mask.shape[0] == 1:
73
+ idx = (mask[0] >= threshold).to(torch.long)
74
+ else:
75
+ idx = torch.argmax(mask, dim=0).to(torch.long)
76
+ idx = (idx > 0).to(torch.long)
77
+ elif mask.ndim == 2:
78
+ idx = (mask >= threshold).to(torch.long)
79
+ else:
80
+ raise AssertionError(f"idx mask must be 2D/3D; got {tuple(mask.shape)}")
81
+ return idx
82
+
83
+ # Channel mask path → (1,H,W) float [0,1]
84
+ if mask.ndim == 2:
85
+ out = mask.unsqueeze(0)
86
+ elif mask.ndim == 3:
87
+ if mask.shape[0] == 1:
88
+ out = mask
89
+ else:
90
+ # choose largest area channel
91
+ areas = mask.sum(dim=(-2,-1))
92
+ out = mask[areas.argmax():areas.argmax()+1]
93
+ else:
94
+ raise AssertionError(f"mask must be 2D/3D; got {tuple(mask.shape)}")
95
+
96
+ out = out.to(torch.float32)
97
+ if not keep_soft:
98
+ out = (out >= threshold).to(torch.float32)
99
+ return out.clamp_(0.0, 1.0).contiguous()