Mayanand committed on
Commit
f7fb909
·
1 Parent(s): ca29c83

Create recognition.py

Browse files
Files changed (1) hide show
  1. recognition.py +207 -0
recognition.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import ArgumentParser
2
+ from itertools import groupby
3
+ import os
4
+ import cv2
5
+ import torch
6
+ import torch.nn as nn
7
+ from torchvision import transforms
8
+ import utils_
9
+
10
+
11
class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM followed by a linear projection of every timestep.

    Args:
        nIn: feature size of each input timestep.
        nHidden: hidden size of the LSTM, per direction.
        nOut: feature size of each output timestep.
    """

    def __init__(self, nIn, nHidden, nOut):
        super().__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        # Both directions are concatenated, so the projection sees 2 * nHidden.
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        """Run the LSTM over a ``[T, b, nIn]`` input and return ``[T, b, nOut]``."""
        hidden_states, _ = self.rnn(input)
        seq_len, batch, width = hidden_states.shape

        # Fold time and batch together so the linear layer sees a 2-D matrix,
        # then restore the [T, b, nOut] layout.
        flattened = hidden_states.view(seq_len * batch, width)
        projected = self.embedding(flattened)
        return projected.view(seq_len, batch, -1)
27
+
28
+
29
class CRNN(nn.Module):
    """Convolutional feature extractor followed by two bidirectional LSTMs.

    Args:
        imgH: input image height; must be a multiple of 16.
        nc: number of input channels.
        nclass: number of output classes per timestep.
        nh: hidden size used by both recurrent layers.
        n_rnn: accepted for interface compatibility; not used by this class.
        leakyRelu: use ``LeakyReLU(0.2)`` instead of ``ReLU`` after each conv.
    """

    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
        super().__init__()
        assert imgH % 16 == 0, "imgH has to be a multiple of 16"

        kernel_sizes = (3, 3, 3, 3, 3, 3, 2)
        paddings = (1, 1, 1, 1, 1, 1, 0)
        strides = (1, 1, 1, 1, 1, 1, 1)
        channels = (64, 128, 256, 256, 512, 512, 512)

        # Conv layers that are followed by batch normalisation.
        with_batchnorm = {2, 4, 6}
        # conv index -> (pooling index, MaxPool2d arguments) applied after it.
        # The last two poolings halve the height only (stride 1 along width).
        pool_after = {
            0: (0, (2, 2)),
            1: (1, (2, 2)),
            3: (2, ((2, 2), (2, 1), (0, 1))),
            5: (3, ((2, 2), (2, 1), (0, 1))),
        }

        # Module names (conv0, batchnorm2, relu0, pooling0, ...) and their
        # order determine the state-dict keys, so they are kept unchanged.
        cnn = nn.Sequential()
        in_channels = nc
        for idx, out_channels in enumerate(channels):
            cnn.add_module(
                f"conv{idx}",
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_sizes[idx],
                    strides[idx],
                    paddings[idx],
                ),
            )
            if idx in with_batchnorm:
                cnn.add_module(f"batchnorm{idx}", nn.BatchNorm2d(out_channels))
            if leakyRelu:
                activation = nn.LeakyReLU(0.2, inplace=True)
            else:
                activation = nn.ReLU(True)
            cnn.add_module(f"relu{idx}", activation)
            if idx in pool_after:
                pool_idx, pool_args = pool_after[idx]
                cnn.add_module(f"pooling{pool_idx}", nn.MaxPool2d(*pool_args))
            in_channels = out_channels

        self.cnn = cnn
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass),
        )

    def forward(self, input):
        """Return per-timestep class scores shaped ``[w, b, nclass]``.

        ``input`` is a ``[b, nc, H, W]`` batch; the convolutional stack must
        reduce its height to exactly 1.
        """
        features = self.cnn(input)
        _, _, height, _ = features.size()
        assert height == 1, "the height of conv must be 1"

        # Drop the height axis and make width the sequence axis: [w, b, c].
        sequence = features.squeeze(2).permute(2, 0, 1)
        return self.rnn(sequence)
87
+
88
+
89
# Class index -> symbol table. Index 0 is the "BLANK" token that
# greedy_decode drops; the remaining 36 entries are single characters
# (A-Z and 0-9) in the order given here.
VOCAB = ["BLANK", *"ZB4XR2UDGQSANK0CJPYH7WV5FL81ITM3O9E6"]
128
+
129
+
130
def add_text(image, text, pos):
    """Draw ``text`` on ``image`` slightly above a bounding box.

    Args:
        image: image handed to ``cv2.putText``.
        text: string to draw.
        pos: four values ``(xmin, ymin, xmax, ymax)``; only ``xmin`` and
            ``ymin`` are used to place the text.

    Returns:
        The value returned by ``cv2.putText``.
    """
    left, top, _right, _bottom = pos
    # Anchor the text 15 pixels above the top edge of the box.
    anchor = (left, top - 15)
    font_scale = 0.85
    colour = (0, 0, 255)
    thickness = 2
    return cv2.putText(
        image,
        text,
        anchor,
        cv2.FONT_HERSHEY_COMPLEX,
        font_scale,
        colour,
        thickness,
        cv2.LINE_AA,
    )
143
+
144
+
145
def greedy_decode(preds):
    """Collapse a best-path symbol sequence into the final string.

    Consecutive duplicates are merged into one symbol first, then every
    ``"BLANK"`` symbol is removed and the rest are concatenated.

    Args:
        preds: iterable of symbol strings, one per timestep.

    Returns:
        The decoded text.
    """
    # groupby yields one key per run of equal neighbours.
    run_symbols = (symbol for symbol, _run in groupby(preds))
    return "".join(symbol for symbol in run_symbols if symbol != "BLANK")
150
+
151
+
152
def read_image(file):
    """Load an image from disk and return it in RGB channel order.

    Args:
        file: path of the image to read.

    Returns:
        The image returned by ``cv2.imread`` converted from BGR to RGB.

    Raises:
        ValueError: if OpenCV could not read ``file``.
    """
    bgr = cv2.imread(file)
    # cv2.imread signals failure by returning None instead of raising;
    # fail here with a clear message rather than inside cvtColor.
    if bgr is None:
        raise ValueError(
            f"could not read image {file!r}: missing file or unsupported format"
        )
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
156
+
157
+
158
def idx2char(preds):
    """Translate class indices into their ``VOCAB`` symbols.

    Args:
        preds: iterable of integer indices into ``VOCAB``.

    Returns:
        A list with the symbol for each index, in the same order.
    """
    return list(map(VOCAB.__getitem__, preds))
160
+
161
+
162
def post_process(preds):
    """Pick the highest-scoring class of every timestep and name it.

    Args:
        preds: score tensor of shape ``(seq_len, num_class)``.

    Returns:
        A list of ``VOCAB`` symbols, one per timestep.
    """
    best = torch.max(preds, dim=1)
    return idx2char(best.indices.tolist())
166
+
167
+
168
# Preprocessing applied to every image before it reaches the model:
# convert to a tensor, reduce to a single grayscale channel, resize to
# 32 x 128 (height x width), then normalise with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize(size=(32, 128)),
        transforms.Normalize(mean=0.5, std=0.5),
    ]
)
176
+
177
# Image height 32, one input channel, one output class per VOCAB entry
# (37), and 512 hidden units in the recurrent layers.
model = CRNN(32, 1, len(VOCAB), 512)

# map_location="cpu" lets a checkpoint saved from a GPU load on CPU-only
# hosts; the model itself is never moved off the CPU in this module.
# NOTE(review): torch.load unpickles the file, so only load checkpoints
# from a trusted source.
state = torch.load("./out/ocr_point08.pt", map_location="cpu")
model.load_state_dict(state["model"])
181
+
182
+
183
def recognize(image):
    """Recognise the text in a single image.

    Args:
        image: image accepted by the module-level ``transform`` pipeline.

    Returns:
        The decoded text as a string.
    """
    model.eval()
    # Add a leading batch axis of size 1.
    batch = transform(image).unsqueeze(0)
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        preds = model(batch)
    # Model output is [seq_len, batch, num_class]; take the only batch item.
    symbols = post_process(preds[:, 0, :])
    return greedy_decode(symbols)
189
+
190
+
191
if __name__ == "__main__":
    # Command-line entry point: recognise the text in one image and print it.
    parser = ArgumentParser()
    parser.add_argument(
        "--image",
        required=True,
        type=str,
        help="path to image on which prediction will be made",
    )

    args = parser.parse_args()

    # Report a bad path as a usage error; an assert would be stripped
    # when Python runs with -O.
    if not os.path.exists(args.image):
        parser.error(f"given path {args.image} does not exist")

    im = read_image(args.image)

    text = recognize(im)
    # The recognised text is the script's result, so write it to stdout.
    print(text)
207
+