Spaces:
Sleeping
Sleeping
Update spm.py
Browse files
spm.py
CHANGED
@@ -32,9 +32,7 @@ def _to_divisible_by(img, N):
|
|
32 |
def _edgelogic(i, j, ph, pw, N, overlap):
|
33 |
"""
|
34 |
Base (no-overlap) patch is [i*ph:(i+1)*ph, j*pw:(j+1)*pw].
|
35 |
-
Extend with overlap, biasing inward.
|
36 |
-
Uses 2*overlap for edges to keep patch areas roughly comparable.
|
37 |
-
Returns (start_h, end_h, start_w, end_w) BEFORE clamping to image bounds.
|
38 |
"""
|
39 |
start_h = i * ph
|
40 |
start_w = j * pw
|
@@ -70,21 +68,20 @@ def spm_augment(
|
|
70 |
mix_prob=0.5,
|
71 |
beta_a=2.0,
|
72 |
beta_b=2.0,
|
73 |
-
|
74 |
seed=None
|
75 |
):
|
76 |
"""
|
77 |
SPM-style augmentation with optional overlap + feathered blending.
|
78 |
|
79 |
-
When
|
80 |
- Standard global shuffle over N×N patches;
|
81 |
- Per-patch mixing with a single alpha ~ Beta(a,b) for the image.
|
82 |
|
83 |
-
When
|
84 |
-
- Each base cell (N×N grid) expands by
|
85 |
-
clipped to the image. Patches are mixed per location
|
86 |
-
|
87 |
-
- Patches are blended into the canvas with a feather mask of size `overlap_px`.
|
88 |
"""
|
89 |
# Normalize to PIL and ensure divisibility
|
90 |
if isinstance(image, np.ndarray):
|
@@ -100,10 +97,10 @@ def spm_augment(
|
|
100 |
ph = H // N
|
101 |
pw = W // N
|
102 |
|
103 |
-
#
|
104 |
-
|
105 |
-
|
106 |
-
overlap_px = int(
|
107 |
max_ov = max(0, min(ph, pw) // 2 - 1)
|
108 |
ov = int(np.clip(overlap_px, 0, max_ov))
|
109 |
|
@@ -167,7 +164,6 @@ def spm_augment(
|
|
167 |
total = len(patches)
|
168 |
perm = rng.permutation(total)
|
169 |
|
170 |
-
# We'll sample alpha per-patch to echo your overlap snippet
|
171 |
def sample_alpha():
|
172 |
if beta_a > 0 and beta_b > 0:
|
173 |
return float(rng.beta(beta_a, beta_b))
|
@@ -178,9 +174,7 @@ def spm_augment(
|
|
178 |
|
179 |
for k, (sh, eh, sw, ew) in enumerate(coords):
|
180 |
if rng.random() >= float(mix_prob):
|
181 |
-
|
182 |
-
src = patches[k]
|
183 |
-
patch = src
|
184 |
else:
|
185 |
lam = sample_alpha()
|
186 |
src = patches[k].astype(np.float32)
|
@@ -188,18 +182,12 @@ def spm_augment(
|
|
188 |
patch = lam * shf + (1.0 - lam) * src
|
189 |
|
190 |
ph_k, pw_k, _ = patch.shape
|
191 |
-
# Slice feather mask down if needed (near borders)
|
192 |
mask2d = feather_full[:ph_k, :pw_k]
|
193 |
-
if arr.shape[2] == 1
|
194 |
-
mask3d = mask2d[..., None]
|
195 |
-
else:
|
196 |
-
mask3d = np.repeat(mask2d[..., None], arr.shape[2], axis=2)
|
197 |
|
198 |
-
# Accumulate
|
199 |
canvas[sh:eh, sw:ew] += patch * mask3d
|
200 |
weight[sh:eh, sw:ew] += mask2d
|
201 |
|
202 |
-
# Normalize
|
203 |
weight = np.clip(weight, 1e-8, None)
|
204 |
out = (canvas / weight[..., None])
|
205 |
out = np.clip(out, 0, 255).astype(np.uint8)
|
|
|
32 |
def _edgelogic(i, j, ph, pw, N, overlap):
|
33 |
"""
|
34 |
Base (no-overlap) patch is [i*ph:(i+1)*ph, j*pw:(j+1)*pw].
|
35 |
+
Extend with overlap, biasing inward. Uses 2*overlap at borders.
|
|
|
|
|
36 |
"""
|
37 |
start_h = i * ph
|
38 |
start_w = j * pw
|
|
|
68 |
mix_prob=0.5,
|
69 |
beta_a=2.0,
|
70 |
beta_b=2.0,
|
71 |
+
overlap_pct=0.0, # percentage of patch size (0..49 typically)
|
72 |
seed=None
|
73 |
):
|
74 |
"""
|
75 |
SPM-style augmentation with optional overlap + feathered blending.
|
76 |
|
77 |
+
When overlap_pct <= 0:
|
78 |
- Standard global shuffle over N×N patches;
|
79 |
- Per-patch mixing with a single alpha ~ Beta(a,b) for the image.
|
80 |
|
81 |
+
When overlap_pct > 0:
|
82 |
+
- Each base cell (N×N grid) expands by ±overlap_px (derived from percentage),
|
83 |
+
clipped to the image. Patches are mixed per location with per-patch alpha.
|
84 |
+
- Patches are blended into the canvas with a feather mask of size overlap_px.
|
|
|
85 |
"""
|
86 |
# Normalize to PIL and ensure divisibility
|
87 |
if isinstance(image, np.ndarray):
|
|
|
97 |
ph = H // N
|
98 |
pw = W // N
|
99 |
|
100 |
+
# Convert percentage to pixel overlap; clamp to < half patch size
|
101 |
+
pct = float(overlap_pct)
|
102 |
+
pct = max(0.0, min(pct, 49.0)) # keep below 50% for stability
|
103 |
+
overlap_px = int(round((pct / 100.0) * min(ph, pw)))
|
104 |
max_ov = max(0, min(ph, pw) // 2 - 1)
|
105 |
ov = int(np.clip(overlap_px, 0, max_ov))
|
106 |
|
|
|
164 |
total = len(patches)
|
165 |
perm = rng.permutation(total)
|
166 |
|
|
|
167 |
def sample_alpha():
|
168 |
if beta_a > 0 and beta_b > 0:
|
169 |
return float(rng.beta(beta_a, beta_b))
|
|
|
174 |
|
175 |
for k, (sh, eh, sw, ew) in enumerate(coords):
|
176 |
if rng.random() >= float(mix_prob):
|
177 |
+
patch = patches[k]
|
|
|
|
|
178 |
else:
|
179 |
lam = sample_alpha()
|
180 |
src = patches[k].astype(np.float32)
|
|
|
182 |
patch = lam * shf + (1.0 - lam) * src
|
183 |
|
184 |
ph_k, pw_k, _ = patch.shape
|
|
|
185 |
mask2d = feather_full[:ph_k, :pw_k]
|
186 |
+
mask3d = mask2d[..., None] if arr.shape[2] == 1 else np.repeat(mask2d[..., None], arr.shape[2], axis=2)
|
|
|
|
|
|
|
187 |
|
|
|
188 |
canvas[sh:eh, sw:ew] += patch * mask3d
|
189 |
weight[sh:eh, sw:ew] += mask2d
|
190 |
|
|
|
191 |
weight = np.clip(weight, 1e-8, None)
|
192 |
out = (canvas / weight[..., None])
|
193 |
out = np.clip(out, 0, 255).astype(np.uint8)
|