PuristanLabs1 commited on
Commit
39910f1
·
verified ·
1 Parent(s): c1f278e

Upload 7 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ distorted/27_2[[:space:]]copy.png filter=lfs diff=lfs merge=lfs -text
37
+ distorted/42_2[[:space:]]copy.png filter=lfs diff=lfs merge=lfs -text
38
+ distorted/48_1[[:space:]]copy.png filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,593 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ import cv2
7
+ import os
8
+ from PIL import Image
9
+ import warnings
10
+ import sys # Added for PyInstaller
11
+
12
+ warnings.filterwarnings('ignore')
13
+
14
+ # --- PyInstaller Helper ---
15
+ # Determines the correct path for bundled data files (models)
16
+ def resource_path(relative_path):
17
+ """ Get absolute path to resource, works for dev and for PyInstaller """
18
+ try:
19
+ # PyInstaller creates a temp folder and stores path in _MEIPASS
20
+ base_path = sys._MEIPASS
21
+ except Exception:
22
+ base_path = os.path.abspath(".")
23
+
24
+ return os.path.join(base_path, relative_path)
25
+
26
+ # --- Model and Helper Class Definitions ---
27
+ # Most of these classes are copied directly from the project's files
28
+ # (extractor.py, update.py, seg.py, model.py, inference.py)
29
+ # to make this Gradio app a self-contained script.
30
+
31
+ # from extractor.py
32
+ class ResidualBlock(nn.Module):
33
+ def __init__(self, in_planes, planes, norm_fn='group', stride=1):
34
+ super(ResidualBlock, self).__init__()
35
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
36
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
37
+ self.relu = nn.ReLU(inplace=True)
38
+ if norm_fn == 'batch':
39
+ self.norm1 = nn.BatchNorm2d(planes)
40
+ self.norm2 = nn.BatchNorm2d(planes)
41
+ if not stride == 1:
42
+ self.norm3 = nn.BatchNorm2d(planes)
43
+ elif norm_fn == 'instance':
44
+ self.norm1 = nn.InstanceNorm2d(planes)
45
+ self.norm2 = nn.InstanceNorm2d(planes)
46
+ if not stride == 1:
47
+ self.norm3 = nn.InstanceNorm2d(planes)
48
+ if stride == 1:
49
+ self.downsample = None
50
+ else:
51
+ self.downsample = nn.Sequential(
52
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
53
+ def forward(self, x):
54
+ y = x
55
+ y = self.relu(self.norm1(self.conv1(y)))
56
+ y = self.relu(self.norm2(self.conv2(y)))
57
+ if self.downsample is not None:
58
+ x = self.downsample(x)
59
+ return self.relu(x + y)
60
+
61
+ class BasicEncoder(nn.Module):
62
+ def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
63
+ super(BasicEncoder, self).__init__()
64
+ self.norm_fn = norm_fn
65
+ if self.norm_fn == 'batch':
66
+ self.norm1 = nn.BatchNorm2d(64)
67
+ elif self.norm_fn == 'instance':
68
+ self.norm1 = nn.InstanceNorm2d(64)
69
+ self.conv1 = nn.Conv2d(3, 80, kernel_size=7, stride=2, padding=3)
70
+ self.relu1 = nn.ReLU(inplace=True)
71
+ self.in_planes = 80
72
+ self.layer1 = self._make_layer(80, stride=1)
73
+ self.layer2 = self._make_layer(160, stride=2)
74
+ self.layer3 = self._make_layer(240, stride=2)
75
+ self.conv2 = nn.Conv2d(240, output_dim, kernel_size=1)
76
+ for m in self.modules():
77
+ if isinstance(m, nn.Conv2d):
78
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
79
+ elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
80
+ if m.weight is not None:
81
+ nn.init.constant_(m.weight, 1)
82
+ if m.bias is not None:
83
+ nn.init.constant_(m.bias, 0)
84
+ def _make_layer(self, dim, stride=1):
85
+ layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
86
+ layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
87
+ layers = (layer1, layer2)
88
+ self.in_planes = dim
89
+ return nn.Sequential(*layers)
90
+ def forward(self, x):
91
+ x = self.conv1(x)
92
+ x = self.norm1(x)
93
+ x = self.relu1(x)
94
+ x = self.layer1(x)
95
+ x = self.layer2(x)
96
+ x = self.layer3(x)
97
+ x = self.conv2(x)
98
+ return x
99
+
100
+ # from update.py
101
+ class FlowHead(nn.Module):
102
+ def __init__(self, input_dim=128, hidden_dim=256):
103
+ super(FlowHead, self).__init__()
104
+ self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
105
+ self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
106
+ self.relu = nn.ReLU(inplace=True)
107
+ def forward(self, x):
108
+ return self.conv2(self.relu(self.conv1(x)))
109
+
110
+ class SepConvGRU(nn.Module):
111
+ def __init__(self, hidden_dim=128, input_dim=192+128):
112
+ super(SepConvGRU, self).__init__()
113
+ self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
114
+ self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
115
+ self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
116
+ self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
117
+ self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
118
+ self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
119
+ def forward(self, h, x):
120
+ hx = torch.cat([h, x], dim=1)
121
+ z = torch.sigmoid(self.convz1(hx))
122
+ r = torch.sigmoid(self.convr1(hx))
123
+ q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
124
+ h = (1-z) * h + z * q
125
+ hx = torch.cat([h, x], dim=1)
126
+ z = torch.sigmoid(self.convz2(hx))
127
+ r = torch.sigmoid(self.convr2(hx))
128
+ q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
129
+ h = (1-z) * h + z * q
130
+ return h
131
+
132
+ class BasicMotionEncoder(nn.Module):
133
+ def __init__(self):
134
+ super(BasicMotionEncoder, self).__init__()
135
+ self.convc1 = nn.Conv2d(320, 240, 1, padding=0)
136
+ self.convc2 = nn.Conv2d(240, 160, 3, padding=1)
137
+ self.convf1 = nn.Conv2d(2, 160, 7, padding=3)
138
+ self.convf2 = nn.Conv2d(160, 80, 3, padding=1)
139
+ self.conv = nn.Conv2d(160+80, 160-2, 3, padding=1)
140
+ def forward(self, flow, corr):
141
+ cor = F.relu(self.convc1(corr))
142
+ cor = F.relu(self.convc2(cor))
143
+ flo = F.relu(self.convf1(flow))
144
+ flo = F.relu(self.convf2(flo))
145
+ cor_flo = torch.cat([cor, flo], dim=1)
146
+ out = F.relu(self.conv(cor_flo))
147
+ return torch.cat([out, flow], dim=1)
148
+
149
+ class BasicUpdateBlock(nn.Module):
150
+ def __init__(self, hidden_dim=128):
151
+ super(BasicUpdateBlock, self).__init__()
152
+ self.encoder = BasicMotionEncoder()
153
+ self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=160+160)
154
+ self.flow_head = FlowHead(hidden_dim, hidden_dim=320)
155
+ self.mask = nn.Sequential(
156
+ nn.Conv2d(hidden_dim, 288, 3, padding=1),
157
+ nn.ReLU(inplace=True),
158
+ nn.Conv2d(288, 64*9, 1, padding=0))
159
+ def forward(self, net, inp, corr, flow):
160
+ motion_features = self.encoder(flow, corr)
161
+ inp = torch.cat([inp, motion_features], dim=1)
162
+ net = self.gru(net, inp)
163
+ delta_flow = self.flow_head(net)
164
+ mask = .25 * self.mask(net)
165
+ return net, mask, delta_flow
166
+
167
+ # from seg.py
168
+ class REBNCONV(nn.Module):
169
+ def __init__(self, in_ch=3, out_ch=3, dirate=1):
170
+ super(REBNCONV, self).__init__()
171
+ self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate)
172
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
173
+ self.relu_s1 = nn.ReLU(inplace=True)
174
+ def forward(self, x):
175
+ return self.relu_s1(self.bn_s1(self.conv_s1(x)))
176
+
177
+ def _upsample_like(src, tar):
178
+ return F.interpolate(src, size=tar.shape[2:], mode='bilinear', align_corners=False)
179
+
180
+ class RSU7(nn.Module):
181
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
182
+ super(RSU7, self).__init__()
183
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
184
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
185
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
186
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
187
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
188
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
189
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
190
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
191
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
192
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
193
+ self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
194
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
195
+ self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
196
+ self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
197
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
198
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
199
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
200
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
201
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
202
+ def forward(self, x):
203
+ hxin = self.rebnconvin(x)
204
+ hx1 = self.rebnconv1(hxin)
205
+ hx = self.pool1(hx1)
206
+ hx2 = self.rebnconv2(hx)
207
+ hx = self.pool2(hx2)
208
+ hx3 = self.rebnconv3(hx)
209
+ hx = self.pool3(hx3)
210
+ hx4 = self.rebnconv4(hx)
211
+ hx = self.pool4(hx4)
212
+ hx5 = self.rebnconv5(hx)
213
+ hx = self.pool5(hx5)
214
+ hx6 = self.rebnconv6(hx)
215
+ hx7 = self.rebnconv7(hx6)
216
+ hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
217
+ hx6dup = _upsample_like(hx6d, hx5)
218
+ hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
219
+ hx5dup = _upsample_like(hx5d, hx4)
220
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
221
+ hx4dup = _upsample_like(hx4d, hx3)
222
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
223
+ hx3dup = _upsample_like(hx3d, hx2)
224
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
225
+ hx2dup = _upsample_like(hx2d, hx1)
226
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
227
+ return hx1d + hxin
228
+
229
+ class RSU6(nn.Module):
230
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
231
+ super(RSU6, self).__init__()
232
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
233
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
234
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
235
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
236
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
237
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
238
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
239
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
240
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
241
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
242
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
243
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
244
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
245
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
246
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
247
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
248
+ def forward(self, x):
249
+ hxin = self.rebnconvin(x)
250
+ hx1 = self.rebnconv1(hxin)
251
+ hx = self.pool1(hx1)
252
+ hx2 = self.rebnconv2(hx)
253
+ hx = self.pool2(hx2)
254
+ hx3 = self.rebnconv3(hx)
255
+ hx = self.pool3(hx3)
256
+ hx4 = self.rebnconv4(hx)
257
+ hx = self.pool4(hx4)
258
+ hx5 = self.rebnconv5(hx)
259
+ hx6 = self.rebnconv6(hx5)
260
+ hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
261
+ hx5dup = _upsample_like(hx5d, hx4)
262
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
263
+ hx4dup = _upsample_like(hx4d, hx3)
264
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
265
+ hx3dup = _upsample_like(hx3d, hx2)
266
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
267
+ hx2dup = _upsample_like(hx2d, hx1)
268
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
269
+ return hx1d + hxin
270
+
271
+ class RSU5(nn.Module):
272
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
273
+ super(RSU5, self).__init__()
274
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
275
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
276
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
277
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
278
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
279
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
280
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
281
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
282
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
283
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
284
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
285
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
286
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
287
+ def forward(self, x):
288
+ hxin = self.rebnconvin(x)
289
+ hx1 = self.rebnconv1(hxin)
290
+ hx = self.pool1(hx1)
291
+ hx2 = self.rebnconv2(hx)
292
+ hx = self.pool2(hx2)
293
+ hx3 = self.rebnconv3(hx)
294
+ hx = self.pool3(hx3)
295
+ hx4 = self.rebnconv4(hx)
296
+ hx5 = self.rebnconv5(hx4)
297
+ hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
298
+ hx4dup = _upsample_like(hx4d, hx3)
299
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
300
+ hx3dup = _upsample_like(hx3d, hx2)
301
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
302
+ hx2dup = _upsample_like(hx2d, hx1)
303
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
304
+ return hx1d + hxin
305
+
306
+ class RSU4(nn.Module):
307
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
308
+ super(RSU4, self).__init__()
309
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
310
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
311
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
312
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
313
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
314
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
315
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
316
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
317
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
318
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
319
+ def forward(self, x):
320
+ hxin = self.rebnconvin(x)
321
+ hx1 = self.rebnconv1(hxin)
322
+ hx = self.pool1(hx1)
323
+ hx2 = self.rebnconv2(hx)
324
+ hx = self.pool2(hx2)
325
+ hx3 = self.rebnconv3(hx)
326
+ hx4 = self.rebnconv4(hx3)
327
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
328
+ hx3dup = _upsample_like(hx3d, hx2)
329
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
330
+ hx2dup = _upsample_like(hx2d, hx1)
331
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
332
+ return hx1d + hxin
333
+
334
+ class RSU4F(nn.Module):
335
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
336
+ super(RSU4F, self).__init__()
337
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
338
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
339
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
340
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
341
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
342
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
343
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
344
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
345
+ def forward(self, x):
346
+ hxin = self.rebnconvin(x)
347
+ hx1 = self.rebnconv1(hxin)
348
+ hx2 = self.rebnconv2(hx1)
349
+ hx3 = self.rebnconv3(hx2)
350
+ hx4 = self.rebnconv4(hx3)
351
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
352
+ hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
353
+ hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
354
+ return hx1d + hxin
355
+
356
+ class U2NETP(nn.Module):
357
+ def __init__(self, in_ch=3, out_ch=1):
358
+ super(U2NETP, self).__init__()
359
+ self.stage1 = RSU7(in_ch, 16, 64)
360
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
361
+ self.stage2 = RSU6(64, 16, 64)
362
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
363
+ self.stage3 = RSU5(64, 16, 64)
364
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
365
+ self.stage4 = RSU4(64, 16, 64)
366
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
367
+ self.stage5 = RSU4F(64, 16, 64)
368
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
369
+ self.stage6 = RSU4F(64, 16, 64)
370
+ self.stage5d = RSU4F(128, 16, 64)
371
+ self.stage4d = RSU4(128, 16, 64)
372
+ self.stage3d = RSU5(128, 16, 64)
373
+ self.stage2d = RSU6(128, 16, 64)
374
+ self.stage1d = RSU7(128, 16, 64)
375
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
376
+ self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
377
+ self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
378
+ self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
379
+ self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
380
+ self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)
381
+ self.outconv = nn.Conv2d(6, out_ch, 1)
382
+ def forward(self, x):
383
+ hx = x
384
+ hx1 = self.stage1(hx)
385
+ hx = self.pool12(hx1)
386
+ hx2 = self.stage2(hx)
387
+ hx = self.pool23(hx2)
388
+ hx3 = self.stage3(hx)
389
+ hx = self.pool34(hx3)
390
+ hx4 = self.stage4(hx)
391
+ hx = self.pool45(hx4)
392
+ hx5 = self.stage5(hx)
393
+ hx = self.pool56(hx5)
394
+ hx6 = self.stage6(hx)
395
+ hx6up = _upsample_like(hx6, hx5)
396
+ hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
397
+ hx5dup = _upsample_like(hx5d, hx4)
398
+ hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
399
+ hx4dup = _upsample_like(hx4d, hx3)
400
+ hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
401
+ hx3dup = _upsample_like(hx3d, hx2)
402
+ hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
403
+ hx2dup = _upsample_like(hx2d, hx1)
404
+ hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
405
+ d1 = self.side1(hx1d)
406
+ d2 = self.side2(hx2d)
407
+ d2 = _upsample_like(d2, d1)
408
+ d3 = self.side3(hx3d)
409
+ d3 = _upsample_like(d3, d1)
410
+ d4 = self.side4(hx4d)
411
+ d4 = _upsample_like(d4, d1)
412
+ d5 = self.side5(hx5d)
413
+ d5 = _upsample_like(d5, d1)
414
+ d6 = self.side6(hx6)
415
+ d6 = _upsample_like(d6, d1)
416
+ d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
417
+ return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)
418
+
419
+ # from model.py
420
+ def bilinear_sampler(img, coords, mode='bilinear', mask=False):
421
+ H, W = img.shape[-2:]
422
+ xgrid, ygrid = coords.split([1, 1], dim=-1)
423
+ xgrid = 2 * xgrid / (W - 1) - 1
424
+ ygrid = 2 * ygrid / (H - 1) - 1
425
+ grid = torch.cat([xgrid, ygrid], dim=-1)
426
+ img = F.grid_sample(img, grid, align_corners=True)
427
+ if mask:
428
+ mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
429
+ return img, mask.float()
430
+ return img
431
+
432
+ def coords_grid(batch, ht, wd):
433
+ coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
434
+ coords = torch.stack(coords[::-1], dim=0).float()
435
+ return coords[None].repeat(batch, 1, 1, 1)
436
+
437
+ class DocScanner(nn.Module):
438
+ def __init__(self):
439
+ super(DocScanner, self).__init__()
440
+ self.hidden_dim = hdim = 160
441
+ self.context_dim = 160
442
+ self.fnet = BasicEncoder(output_dim=320, norm_fn='instance')
443
+ self.update_block = BasicUpdateBlock(hidden_dim=hdim)
444
+ def forward(self, image1, iters=12, flow_init=None, test_mode=False):
445
+ image1 = image1.contiguous()
446
+ fmap1 = self.fnet(image1)
447
+ warpfea = fmap1
448
+ net, inp = torch.split(fmap1, [160, 160], dim=1)
449
+ net = torch.tanh(net)
450
+ inp = torch.relu(inp)
451
+ coodslar, coords0, coords1 = self.initialize_flow(image1)
452
+ if flow_init is not None:
453
+ coords1 = coords1 + flow_init
454
+ flow_predictions = []
455
+ for itr in range(iters):
456
+ coords1 = coords1.detach()
457
+ flow = coords1 - coords0
458
+ net, up_mask, delta_flow = self.update_block(net, inp, warpfea, flow)
459
+ coords1 = coords1 + delta_flow
460
+ flow_up = self.upsample_flow(coords1 - coords0, up_mask)
461
+ bm_up = coodslar + flow_up
462
+ warpfea = bilinear_sampler(fmap1, coords1.permute(0, 2, 3, 1))
463
+ flow_predictions.append(bm_up)
464
+ if test_mode:
465
+ return bm_up
466
+ return flow_predictions
467
+ def initialize_flow(self, img):
468
+ N, C, H, W = img.shape
469
+ coodslar = coords_grid(N, H, W).to(img.device)
470
+ coords0 = coords_grid(N, H // 8, W // 8).to(img.device)
471
+ coords1 = coords_grid(N, H // 8, W // 8).to(img.device)
472
+ return coodslar, coords0, coords1
473
+ def upsample_flow(self, flow, mask):
474
+ N, _, H, W = flow.shape
475
+ mask = mask.view(N, 1, 9, 8, 8, H, W)
476
+ mask = torch.softmax(mask, dim=2)
477
+ up_flow = F.unfold(8 * flow, [3, 3], padding=1)
478
+ up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
479
+ up_flow = torch.sum(mask * up_flow, dim=2)
480
+ up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
481
+ return up_flow.reshape(N, 2, 8 * H, 8 * W)
482
+
483
+ # from inference.py
484
+ class Net(nn.Module):
485
+ def __init__(self):
486
+ super(Net, self).__init__()
487
+ self.msk = U2NETP(3, 1)
488
+ self.bm = DocScanner()
489
+ def forward(self, x):
490
+ msk, _, _, _, _, _, _ = self.msk(x)
491
+ msk = (msk > 0.5).float()
492
+ x = msk * x
493
+ bm = self.bm(x, iters=12, test_mode=True)
494
+ bm = (2 * (bm / 286.8) - 1) * 0.99
495
+ return bm
496
+
497
+ def reload_seg_model(model, path=""):
498
+ if not bool(path) or not os.path.exists(path):
499
+ print("Warning: Segmentation model path not found. Using initial weights.")
500
+ return model
501
+ model_dict = model.state_dict()
502
+ pretrained_dict = torch.load(path, map_location='cuda:0' if torch.cuda.is_available() else 'cpu')
503
+ pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items() if k[6:] in model_dict}
504
+ model_dict.update(pretrained_dict)
505
+ model.load_state_dict(model_dict)
506
+ return model
507
+
508
+ def reload_rec_model(model, path=""):
509
+ if not bool(path) or not os.path.exists(path):
510
+ print("Warning: Rectification model path not found. Using initial weights.")
511
+ return model
512
+ model_dict = model.state_dict()
513
+ pretrained_dict = torch.load(path, map_location='cuda:0' if torch.cuda.is_available() else 'cpu')
514
+ pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
515
+ model_dict.update(pretrained_dict)
516
+ model.load_state_dict(model_dict)
517
+ return model
518
+
519
+ # --- Gradio App Logic ---
520
+
521
+ # Configuration
522
+ SEG_MODEL_PATH = resource_path('model_pretrained/seg.pth')
523
+ REC_MODEL_PATH = resource_path('model_pretrained/DocScanner-L.pth')
524
+ DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
525
+
526
+ # Load models once
527
+ print("Initializing and loading models...")
528
+ net = Net().to(DEVICE)
529
+ reload_seg_model(net.msk, SEG_MODEL_PATH)
530
+ reload_rec_model(net.bm, REC_MODEL_PATH)
531
+ net.eval()
532
+ print("Models loaded successfully.")
533
+
534
+ def rectify_image(distorted_image):
535
+ """
536
+ Takes a distorted image as a numpy array, rectifies it using the DocScanner model,
537
+ and returns the rectified image as a numpy array.
538
+ """
539
+ if distorted_image is None:
540
+ return None
541
+
542
+ im_ori = distorted_image.astype(np.float32) / 255.
543
+ h, w, _ = im_ori.shape
544
+
545
+ # Pre-process
546
+ im = cv2.resize(im_ori, (288, 288))
547
+ im = im.transpose(2, 0, 1)
548
+ im = torch.from_numpy(im).float().unsqueeze(0)
549
+
550
+ with torch.no_grad():
551
+ # Inference
552
+ bm = net(im.to(DEVICE))
553
+ bm = bm.cpu()
554
+
555
+ # Post-process
556
+ bm0 = cv2.resize(bm[0, 0].numpy(), (w, h)) # x flow
557
+ bm1 = cv2.resize(bm[0, 1].numpy(), (w, h)) # y flow
558
+ bm0 = cv2.blur(bm0, (3, 3))
559
+ bm1 = cv2.blur(bm1, (3, 3))
560
+ lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0) # h * w * 2
561
+
562
+ # Warp the original image
563
+ out = F.grid_sample(torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float(), lbl, align_corners=True)
564
+
565
+ # Convert to displayable format
566
+ rectified_image = (out[0].permute(1, 2, 0).numpy() * 255).astype(np.uint8)
567
+
568
+ return rectified_image
569
+
570
+ # --- Gradio Interface ---
571
+
572
+ DESCRIPTION = """
573
+ # DocScanner: Robust Document Image Rectification with Progressive Learning
574
+ This is a Gradio demo for the DocScanner model.
575
+ 1. Upload a distorted document image.
576
+ 2. The model will process it and display the rectified (unwarped) image.
577
+ This demo uses the **DocScanner-L** model as described in the paper. Make sure the pretrained models (`seg.pth`, `DocScanner-L.pth`) are located in the `./model_pretrained/` directory.
578
+ """
579
+
580
+ if __name__ == "__main__":
581
+ iface = gr.Interface(
582
+ fn=rectify_image,
583
+ inputs=gr.Image(type="numpy", label="Upload Distorted Document"),
584
+ outputs=gr.Image(type="numpy", label="Rectified Document"),
585
+ title="DocScanner Document Rectification",
586
+ description=DESCRIPTION,
587
+ examples=[
588
+ ['distorted/27_2 copy.png'],
589
+ ['distorted/42_2 copy.png'],
590
+ ['distorted/48_1 copy.png']
591
+ ]
592
+ )
593
+ iface.launch()
distorted/27_2 copy.png ADDED

Git LFS Details

  • SHA256: 565e9885993b104b04558bdb68f396697cd2351a988cd41b944c8cc64a325eb6
  • Pointer size: 132 Bytes
  • Size of remote file: 1.73 MB
distorted/42_2 copy.png ADDED

Git LFS Details

  • SHA256: 4af2fe45570283d31e3d9b8e6b2589a18ba228b9bad882328f8790a410b1911d
  • Pointer size: 132 Bytes
  • Size of remote file: 2.32 MB
distorted/48_1 copy.png ADDED

Git LFS Details

  • SHA256: ec274605b35194aad1f8cd947ab16c79e0eacf8fe9e12bb9ecaade6fe1bc5af9
  • Pointer size: 132 Bytes
  • Size of remote file: 5.61 MB
hf_requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ numpy
4
+ opencv-python
5
+ Pillow
6
+ scikit-image
model_pretrained/DocScanner-L.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1d907965aa5d8e99ea8d0891fb66d13bc4f23838547bac6f568d01d480ff8c8a
3
+ size 29328510
model_pretrained/seg.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb79fdec55a5ed435dc74d8112aa9285d8213bae475022f711c709744fb19dd4
3
+ size 4715923