Ransaka commited on
Commit
232505b
·
verified ·
1 Parent(s): 93ea391

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -68
app.py CHANGED
@@ -3,10 +3,30 @@ import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
  from torchvision import transforms
6
- from torchvision.transforms import functional as TF
7
  from PIL import Image
8
- from sinlib import Tokenizer
9
  from pathlib import Path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  MAX_LENGTH = 32
12
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -14,75 +34,84 @@ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
14
  # Load tokenizer
15
  @st.cache_resource
16
  def load_tokenizer():
17
- tokenizer = Tokenizer(max_length=1000).load_from_pretrained("gpt2.json")
18
- tokenizer.max_length = MAX_LENGTH
19
  return tokenizer
20
 
21
  tokenizer = load_tokenizer()
 
 
22
 
23
  class CRNN(nn.Module):
24
- def __init__(self, num_chars):
25
  super(CRNN, self).__init__()
26
 
27
- self.cnn = nn.Sequential(
28
- nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
29
- nn.ReLU(),
30
- nn.MaxPool2d(kernel_size=2, stride=2),
31
- nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
32
- nn.ReLU(),
33
- nn.MaxPool2d(kernel_size=2, stride=2),
34
- nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
35
- nn.BatchNorm2d(256),
36
  nn.ReLU(),
37
- nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
38
- nn.ReLU(),
39
- nn.MaxPool2d(kernel_size=(2, 1)),
40
- nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
41
- nn.BatchNorm2d(512),
42
- nn.ReLU(),
43
- nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
44
- nn.ReLU(),
45
- nn.MaxPool2d(kernel_size=(2, 1)),
46
- nn.Conv2d(512, 512, kernel_size=2, stride=1),
47
- nn.BatchNorm2d(512),
48
- nn.ReLU()
49
  )
50
 
51
- # RNN layers
52
- self.rnn = nn.GRU(512 * 7, 256, bidirectional=True, batch_first=True, num_layers=2)
53
- self.linear = nn.Linear(512, num_chars)
 
 
 
 
 
 
54
 
55
  def forward(self, x):
56
- conv = self.cnn(x)
57
- batch, channel, height, width = conv.size()
58
- conv = conv.permute(0, 3, 1, 2)
59
- conv = conv.contiguous().view(batch, width, channel * height)
60
- output, _ = self.rnn(conv)
61
- output = self.linear(output)
 
 
 
 
 
 
 
 
62
  return output
63
 
 
64
  @st.cache_resource
65
  def load_model(selected_model_path):
66
- model = CRNN(num_chars=len(tokenizer))
67
- model.load_state_dict(torch.load(f'{selected_model_path}', map_location=torch.device('cpu')))
68
  model.eval()
69
  return model
70
 
71
- def preprocess_image(image):
72
- transform = transforms.Compose([
73
- transforms.Grayscale(),
74
- transforms.ToTensor(),
75
- ])
76
-
77
- image = TF.resize(image, (128, 2600), interpolation=Image.BILINEAR)
78
  image = transform(image)
 
79
 
80
- if image.shape[0] != 1:
81
- image = image.mean(dim=0, keepdim=True)
82
 
83
- image = image.unsqueeze(0)
84
- return image
 
 
 
 
 
 
 
 
 
 
 
85
 
 
86
  def inference(model, image):
87
  with torch.no_grad():
88
  image = image.to(DEVICE)
@@ -91,35 +120,28 @@ def inference(model, image):
91
  pred_chars = torch.argmax(log_probs, dim=2)
92
  return pred_chars.squeeze().cpu().numpy()
93
 
94
- st.title("CRNN Printed Text Recognition")
95
- st.warning("**Note**: This model was trained on images with these settings, \
96
- with width ranging from 800 to 2600 pixels and height ranging from 128 to 600 pixels. \
97
- For better results, use images within these limitations."
98
- )
99
- fp = Path(".").glob("*.pth")
 
 
 
 
100
  selected_model_path = st.selectbox(label="Select Model...", options=fp)
101
  model = load_model(selected_model_path)
102
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
103
 
104
  if uploaded_file is not None:
105
- image = Image.open(uploaded_file)
106
  st.image(image, caption='Uploaded Image', use_column_width=True)
107
- w,h = image.size
108
- w_color = h_color = 'green'
109
- if not 800 <= w <= 2600:
110
- w_color = "red"
111
- if not 128 <= h <= 600:
112
- h_color = "red"
113
- with st.expander("Click See Image Details"):
114
- st.write(f"Width = :{w_color}[{w}];",f"Height = :{h_color}[{h}]")
115
 
116
  if st.button('Predict'):
117
- processed_image = preprocess_image(image)
118
- predicted_sequence = inference(model, processed_image)
119
-
120
- decoded_text = tokenizer.decode(predicted_sequence, skip_special_tokens=True)
121
  st.write("Predicted Text:")
122
- st.write(decoded_text)
123
 
124
  st.markdown("---")
125
  st.write("Note: This app uses a pre-trained CRNN model for printed Sinhala text recognition.")
 
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
  from torchvision import transforms
 
6
  from PIL import Image
 
7
  from pathlib import Path
8
+ import pickle
9
+
10
+ transform = transforms.Compose([
11
+ transforms.ToTensor()
12
+ ])
13
+
14
+ class TextProcessor:
15
+ def __init__(self, alphabet):
16
+ self.alphabet = alphabet
17
+ self.pad_token = "[PAD]"
18
+ self.stoi = {s: i for i, s in enumerate(self.alphabet,1)}
19
+ self.stoi[self.pad_token] = 0
20
+ self.itos = {i: s for s, i in self.stoi.items()}
21
+
22
+ def encode(self, label):
23
+ return [self.stoi[s] for s in label]
24
+
25
+ def decode(self, ids):
26
+ return ''.join([self.itos[i] for i in ids])
27
+
28
+ def __len__(self):
29
+ return len(self.alphabet) + 1
30
 
31
  MAX_LENGTH = 32
32
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 
34
  # Load tokenizer
35
  @st.cache_resource
36
  def load_tokenizer():
37
+ with open("text_process.cls",'rb') as f:
38
+ tokenizer = pickle.load(f)
39
  return tokenizer
40
 
41
  tokenizer = load_tokenizer()
42
+ encode = tokenizer.encode
43
+ decode = tokenizer.decode
44
 
45
  class CRNN(nn.Module):
46
+ def __init__(self, num_channels, hidden_size, num_classes):
47
  super(CRNN, self).__init__()
48
 
49
+ self.conv1 = nn.Sequential(
50
+ nn.Conv2d(1, 64, kernel_size=(2,3), padding=1),
 
 
 
 
 
 
 
51
  nn.ReLU(),
52
+ nn.MaxPool2d(2, 2)
 
 
 
 
 
 
 
 
 
 
 
53
  )
54
 
55
+ self.conv2 = nn.Sequential(
56
+ nn.Conv2d(64, 128, kernel_size=(2,3), padding=1),
57
+ nn.ReLU(),
58
+ nn.MaxPool2d(2, 2)
59
+ )
60
+
61
+ self.rnn = nn.LSTM(128 * 16, hidden_size, bidirectional=True, batch_first=True)
62
+
63
+ self.fc = nn.Linear(hidden_size * 2, num_classes)
64
 
65
  def forward(self, x):
66
+ # x shape: [batch_size, channels, height, width]
67
+
68
+ # CNN feature extraction
69
+ conv = self.conv1(x)
70
+ conv = self.conv2(conv)
71
+ batch, channels, height, width = conv.size()
72
+
73
+ conv = conv.permute(0, 3, 1, 2) # [batch, width, channels, height]
74
+ conv = conv.contiguous().view(batch, width, channels * height)
75
+
76
+ rnn, _ = self.rnn(conv)
77
+
78
+ output = self.fc(rnn)
79
+
80
  return output
81
 
82
+
83
  @st.cache_resource
84
  def load_model(selected_model_path):
85
+ model = CRNN(num_channels=1, hidden_size=256, num_classes=len(tokenizer))
86
+ model.load_state_dict(torch.load(selected_model_path, map_location=torch.device('cpu')))
87
  model.eval()
88
  return model
89
 
90
+
91
+ def preprocess_image(img):
92
+ # img = image.convert("L") # Ensuring image is in grayscale
93
+ original_width, original_height = img.size
94
+ new_width = int(61 * original_width / original_height) # Calculate width to preserve aspect ratio
95
+ image = img.resize((new_width, 61))
 
96
  image = transform(image)
97
+ return image
98
 
 
 
99
 
100
+ def post_process(preds):
101
+ encodings = []
102
+ is_previous_zero = False
103
+ for pred in preds:
104
+ #only considering >0 tokens
105
+ if pred==0:
106
+ zero_found = True
107
+ pass
108
+ elif not encodings:
109
+ encodings.append(pred)
110
+ elif encodings[-1] != pred:
111
+ encodings.append(pred)
112
+ return decode(encodings)
113
 
114
+
115
  def inference(model, image):
116
  with torch.no_grad():
117
  image = image.to(DEVICE)
 
120
  pred_chars = torch.argmax(log_probs, dim=2)
121
  return pred_chars.squeeze().cpu().numpy()
122
 
123
+ def predict(image):
124
+ image = preprocess_image(image)
125
+ image = image.unsqueeze(0) #remove batch dim
126
+ predictions = model(image)
127
+ pred_ids = torch.argmax(predictions, dim=-1).detach().flatten().tolist()
128
+ text = post_process(pred_ids)
129
+ return text
130
+
131
+ st.title("CRNN Sinhala Printed Text Recognition")
132
+ fp = Path(".").glob("crnn*.pt")
133
  selected_model_path = st.selectbox(label="Select Model...", options=fp)
134
  model = load_model(selected_model_path)
135
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
136
 
137
  if uploaded_file is not None:
138
+ image = Image.open(uploaded_file).convert("L")
139
  st.image(image, caption='Uploaded Image', use_column_width=True)
 
 
 
 
 
 
 
 
140
 
141
  if st.button('Predict'):
142
+ predicted_text = predict(image)
 
 
 
143
  st.write("Predicted Text:")
144
+ st.write(predicted_text)
145
 
146
  st.markdown("---")
147
  st.write("Note: This app uses a pre-trained CRNN model for printed Sinhala text recognition.")