prasannareddyp commited on
Commit
4e2c1e5
·
verified ·
1 Parent(s): 5226a32

Update spm.py

Browse files
Files changed (1) hide show
  1. spm.py +13 -25
spm.py CHANGED
@@ -32,9 +32,7 @@ def _to_divisible_by(img, N):
32
  def _edgelogic(i, j, ph, pw, N, overlap):
33
  """
34
  Base (no-overlap) patch is [i*ph:(i+1)*ph, j*pw:(j+1)*pw].
35
- Extend with overlap, biasing inward.
36
- Uses 2*overlap for edges to keep patch areas roughly comparable.
37
- Returns (start_h, end_h, start_w, end_w) BEFORE clamping to image bounds.
38
  """
39
  start_h = i * ph
40
  start_w = j * pw
@@ -70,21 +68,20 @@ def spm_augment(
70
  mix_prob=0.5,
71
  beta_a=2.0,
72
  beta_b=2.0,
73
- overlap_px=0,
74
  seed=None
75
  ):
76
  """
77
  SPM-style augmentation with optional overlap + feathered blending.
78
 
79
- When overlap_px <= 0:
80
  - Standard global shuffle over N×N patches;
81
  - Per-patch mixing with a single alpha ~ Beta(a,b) for the image.
82
 
83
- When overlap_px > 0:
84
- - Each base cell (N×N grid) expands by +/-overlap_px (2*overlap at borders),
85
- clipped to the image. Patches are mixed per location and alpha sampled per-patch
86
- for a bit more stochasticity (can be changed to per-image alpha by editing below).
87
- - Patches are blended into the canvas with a feather mask of size `overlap_px`.
88
  """
89
  # Normalize to PIL and ensure divisibility
90
  if isinstance(image, np.ndarray):
@@ -100,10 +97,10 @@ def spm_augment(
100
  ph = H // N
101
  pw = W // N
102
 
103
- # Clamp overlap to < half patch size
104
- if overlap_px is None:
105
- overlap_px = 0
106
- overlap_px = int(overlap_px)
107
  max_ov = max(0, min(ph, pw) // 2 - 1)
108
  ov = int(np.clip(overlap_px, 0, max_ov))
109
 
@@ -167,7 +164,6 @@ def spm_augment(
167
  total = len(patches)
168
  perm = rng.permutation(total)
169
 
170
- # We'll sample alpha per-patch to echo your overlap snippet
171
  def sample_alpha():
172
  if beta_a > 0 and beta_b > 0:
173
  return float(rng.beta(beta_a, beta_b))
@@ -178,9 +174,7 @@ def spm_augment(
178
 
179
  for k, (sh, eh, sw, ew) in enumerate(coords):
180
  if rng.random() >= float(mix_prob):
181
- # keep original content in that region
182
- src = patches[k]
183
- patch = src
184
  else:
185
  lam = sample_alpha()
186
  src = patches[k].astype(np.float32)
@@ -188,18 +182,12 @@ def spm_augment(
188
  patch = lam * shf + (1.0 - lam) * src
189
 
190
  ph_k, pw_k, _ = patch.shape
191
- # Slice feather mask down if needed (near borders)
192
  mask2d = feather_full[:ph_k, :pw_k]
193
- if arr.shape[2] == 1:
194
- mask3d = mask2d[..., None]
195
- else:
196
- mask3d = np.repeat(mask2d[..., None], arr.shape[2], axis=2)
197
 
198
- # Accumulate
199
  canvas[sh:eh, sw:ew] += patch * mask3d
200
  weight[sh:eh, sw:ew] += mask2d
201
 
202
- # Normalize
203
  weight = np.clip(weight, 1e-8, None)
204
  out = (canvas / weight[..., None])
205
  out = np.clip(out, 0, 255).astype(np.uint8)
 
32
  def _edgelogic(i, j, ph, pw, N, overlap):
33
  """
34
  Base (no-overlap) patch is [i*ph:(i+1)*ph, j*pw:(j+1)*pw].
35
+ Extend with overlap, biasing inward. Uses 2*overlap at borders.
 
 
36
  """
37
  start_h = i * ph
38
  start_w = j * pw
 
68
  mix_prob=0.5,
69
  beta_a=2.0,
70
  beta_b=2.0,
71
+ overlap_pct=0.0, # percentage of patch size (0..49 typically)
72
  seed=None
73
  ):
74
  """
75
  SPM-style augmentation with optional overlap + feathered blending.
76
 
77
+ When overlap_pct <= 0:
78
  - Standard global shuffle over N×N patches;
79
  - Per-patch mixing with a single alpha ~ Beta(a,b) for the image.
80
 
81
+ When overlap_pct > 0:
82
+ - Each base cell (N×N grid) expands by ±overlap_px (derived from percentage),
83
+ clipped to the image. Patches are mixed per location with per-patch alpha.
84
+ - Patches are blended into the canvas with a feather mask of size overlap_px.
 
85
  """
86
  # Normalize to PIL and ensure divisibility
87
  if isinstance(image, np.ndarray):
 
97
  ph = H // N
98
  pw = W // N
99
 
100
+ # Convert percentage to pixel overlap; clamp to < half patch size
101
+ pct = float(overlap_pct)
102
+ pct = max(0.0, min(pct, 49.0)) # keep below 50% for stability
103
+ overlap_px = int(round((pct / 100.0) * min(ph, pw)))
104
  max_ov = max(0, min(ph, pw) // 2 - 1)
105
  ov = int(np.clip(overlap_px, 0, max_ov))
106
 
 
164
  total = len(patches)
165
  perm = rng.permutation(total)
166
 
 
167
  def sample_alpha():
168
  if beta_a > 0 and beta_b > 0:
169
  return float(rng.beta(beta_a, beta_b))
 
174
 
175
  for k, (sh, eh, sw, ew) in enumerate(coords):
176
  if rng.random() >= float(mix_prob):
177
+ patch = patches[k]
 
 
178
  else:
179
  lam = sample_alpha()
180
  src = patches[k].astype(np.float32)
 
182
  patch = lam * shf + (1.0 - lam) * src
183
 
184
  ph_k, pw_k, _ = patch.shape
 
185
  mask2d = feather_full[:ph_k, :pw_k]
186
+ mask3d = mask2d[..., None] if arr.shape[2] == 1 else np.repeat(mask2d[..., None], arr.shape[2], axis=2)
 
 
 
187
 
 
188
  canvas[sh:eh, sw:ew] += patch * mask3d
189
  weight[sh:eh, sw:ew] += mask2d
190
 
 
191
  weight = np.clip(weight, 1e-8, None)
192
  out = (canvas / weight[..., None])
193
  out = np.clip(out, 0, 255).astype(np.uint8)