vorkna committed
Commit 991e502 · verified · Parent: 595d367

Upload 3 files

Files changed (3)
  1. model/palocr.pth +3 -0
  2. model/palocr.py +78 -0
  3. model/palocr.yaml +9 -0
model/palocr.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:83c450861f064af31ee4c309c34e8712ac953527fdb533bd7ca9d70b00e7fa09
+ size 15213813
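Note that model/palocr.pth is a Git LFS pointer, so the roughly 15 MB of weights are not stored in the repository itself; cloning with git-lfs installed (or running `git lfs pull` afterwards) replaces the pointer with the real file. Once fetched, the checkpoint can be inspected as an ordinary PyTorch object. A minimal sketch, assuming the file holds a plain state_dict for the Model class defined in model/palocr.py below:

```python
import torch

# Assumption: the .pth is a plain state_dict (keys may carry a "module."
# prefix if it was saved from nn.DataParallel). Loaded on CPU for inspection.
state_dict = torch.load("model/palocr.pth", map_location="cpu")

# List every parameter tensor and report the total parameter count.
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape))
print("total parameters:", sum(t.numel() for t in state_dict.values()))
```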
model/palocr.py ADDED
@@ -0,0 +1,78 @@
+ import torch.nn as nn
+
+ class BidirectionalLSTM(nn.Module):
+
+     def __init__(self, input_size, hidden_size, output_size):
+         super(BidirectionalLSTM, self).__init__()
+         self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)
+         self.linear = nn.Linear(hidden_size * 2, output_size)
+
+     def forward(self, input):
+         """
+         input : visual feature [batch_size x T x input_size]
+         output : contextual feature [batch_size x T x output_size]
+         """
+         try:  # multi gpu needs this
+             self.rnn.flatten_parameters()
+         except:  # quantization doesn't work with this
+             pass
+         recurrent, _ = self.rnn(input)  # batch_size x T x input_size -> batch_size x T x (2*hidden_size)
+         output = self.linear(recurrent)  # batch_size x T x output_size
+         return output
+
+ class VGG_FeatureExtractor(nn.Module):
+
+     def __init__(self, input_channel, output_channel=256):
+         super(VGG_FeatureExtractor, self).__init__()
+         self.output_channel = [int(output_channel / 8), int(output_channel / 4),
+                                int(output_channel / 2), output_channel]
+         self.ConvNet = nn.Sequential(
+             nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), nn.ReLU(True),
+             nn.MaxPool2d(2, 2),
+             nn.Conv2d(self.output_channel[0], self.output_channel[1], 3, 1, 1), nn.ReLU(True),
+             nn.MaxPool2d(2, 2),
+             nn.Conv2d(self.output_channel[1], self.output_channel[2], 3, 1, 1), nn.ReLU(True),
+             nn.Conv2d(self.output_channel[2], self.output_channel[2], 3, 1, 1), nn.ReLU(True),
+             nn.MaxPool2d((2, 1), (2, 1)),
+             nn.Conv2d(self.output_channel[2], self.output_channel[3], 3, 1, 1, bias=False),
+             nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True),
+             nn.Conv2d(self.output_channel[3], self.output_channel[3], 3, 1, 1, bias=False),
+             nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True),
+             nn.MaxPool2d((2, 1), (2, 1)),
+             nn.Conv2d(self.output_channel[3], self.output_channel[3], 2, 1, 0), nn.ReLU(True))
+
+     def forward(self, input):
+         return self.ConvNet(input)
+
+ class Model(nn.Module):
+
+     def __init__(self, input_channel, output_channel, hidden_size, num_class):
+         super(Model, self).__init__()
+         """ FeatureExtraction """
+         self.FeatureExtraction = VGG_FeatureExtractor(input_channel, output_channel)
+         self.FeatureExtraction_output = output_channel
+         self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))
+
+         """ Sequence modeling """
+         self.SequenceModeling = nn.Sequential(
+             BidirectionalLSTM(self.FeatureExtraction_output, hidden_size, hidden_size),
+             BidirectionalLSTM(hidden_size, hidden_size, hidden_size))
+         self.SequenceModeling_output = hidden_size
+
+         """ Prediction """
+         self.Prediction = nn.Linear(self.SequenceModeling_output, num_class)
+
+
+     def forward(self, input, text):
+         """ Feature extraction stage """
+         visual_feature = self.FeatureExtraction(input)
+         visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2))
+         visual_feature = visual_feature.squeeze(3)
+
+         """ Sequence modeling stage """
+         contextual_feature = self.SequenceModeling(visual_feature)
+
+         """ Prediction stage """
+         prediction = self.Prediction(contextual_feature.contiguous())
+
+         return prediction
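For orientation, the three classes compose a standard CRNN-style text recognizer: VGG_FeatureExtractor turns a grayscale text-line image into a convolutional feature map, AdaptiveAvgPool2d collapses the height axis so each horizontal position becomes one time step, two stacked BidirectionalLSTMs add sequence context, and the final Linear layer scores every character class at each step (the `text` argument of `forward` is accepted but unused, as is typical for CTC-style decoding). A minimal smoke-test sketch, with the channel and hidden sizes taken from model/palocr.yaml and num_class chosen as a placeholder:

```python
import torch
from palocr import Model  # hypothetical import; assumes model/palocr.py is on the path

# Sizes mirror model/palocr.yaml; num_class=100 is a placeholder for this sketch
# (in practice it would be len(character_list) + 1, leaving room for the CTC blank).
net = Model(input_channel=1, output_channel=256, hidden_size=256, num_class=100)
net.eval()

# One grayscale text-line image: batch 1, 1 channel, 64 px tall (imgH), 256 px wide.
dummy = torch.randn(1, 1, 64, 256)
with torch.no_grad():
    logits = net(dummy, text=None)  # `text` is accepted but unused by this forward()

print(logits.shape)  # torch.Size([1, 63, 100]): (batch, time steps, num_class)
```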
model/palocr.yaml ADDED
@@ -0,0 +1,9 @@
+ network_params:
+   input_channel: 1
+   output_channel: 256
+   hidden_size: 256
+ imgH: 64
+ lang_list:
+   - 'en'
+   - 'th'
+ character_list: 0123456789!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~ €กขฃคฅฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรลวศษสหฬอฮฤฦะาำเแโใไๆ๏๐๑๒๓๔๕๖๗๘๙๚๛ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
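The yaml ties the upload together: network_params feeds the Model constructor, imgH is the height text-line images are resized to, lang_list names the supported languages (English and Thai), and character_list is the alphabet the final Linear layer scores. The weights/definition/config trio resembles the layout EasyOCR uses for custom recognition networks, though that is an inference from the file naming. A sketch of wiring the config and weights together by hand, assuming the checkpoint is a plain state_dict and that the class count is len(character_list) + 1 for the CTC blank:

```python
import torch
import yaml
from palocr import Model  # hypothetical import; assumes model/palocr.py is on the path

# Read the network parameters and alphabet from the yaml config.
with open("model/palocr.yaml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

num_class = len(cfg["character_list"]) + 1  # +1 for the CTC blank (assumption)
net = Model(num_class=num_class, **cfg["network_params"])

# Load the LFS-fetched weights; stripping a possible "module." prefix left over
# from nn.DataParallel training is an assumption about how the model was saved.
state_dict = torch.load("model/palocr.pth", map_location="cpu")
state_dict = {k.replace("module.", "", 1): v for k, v in state_dict.items()}
net.load_state_dict(state_dict)
net.eval()
```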