"""Streamlit app that runs a CRNN model on an uploaded image of printed Sinhala text."""

import pickle
from pathlib import Path

import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

# Converts a PIL image to a float tensor; no further normalisation is applied.
transform = transforms.Compose([
    transforms.ToTensor(),
])


class TextProcessor:
    """Map characters of a fixed alphabet to integer ids and back.

    Id 0 is reserved for the ``[PAD]`` token; alphabet symbols are numbered
    from 1 in iteration order.
    """

    def __init__(self, alphabet):
        self.alphabet = alphabet
        self.pad_token = "[PAD]"
        self.stoi = {s: i for i, s in enumerate(self.alphabet, 1)}
        self.stoi[self.pad_token] = 0
        self.itos = {i: s for s, i in self.stoi.items()}

    def encode(self, label):
        """Return the list of ids for *label*; raises ``KeyError`` on unknown symbols."""
        return [self.stoi[s] for s in label]

    def decode(self, ids):
        """Return the string spelled by *ids*; raises ``KeyError`` on unknown ids."""
        return ''.join(self.itos[i] for i in ids)

    def __len__(self):
        # +1 accounts for the [PAD] token at id 0.
        return len(self.alphabet) + 1


MAX_LENGTH = 32
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Input height fed to the network. With the CRNN below (kernel height 2,
# padding 1, two 2x2 max-pools) 61 rows become 16, which is what the LSTM
# input size of 128 * 16 requires.
IMAGE_HEIGHT = 61

# The convolutions preserve width and each of the two max-pools halves it
# (rounding down), so at least 4 columns are needed to get one time step.
MIN_IMAGE_WIDTH = 4

# Id that post_process() skips; it is the id TextProcessor assigns to [PAD].
BLANK_ID = 0


# Load tokenizer
@st.cache_resource
def load_tokenizer():
    """Load the pickled tokenizer from ``text_process.cls`` (cached by Streamlit).

    SECURITY NOTE: ``pickle.load`` can execute arbitrary code. This is only
    acceptable because the file ships with the app; never point it at a
    user-supplied file.
    """
    # presumably a pickled TextProcessor instance - verify against the
    # script that produced the file
    with open("text_process.cls", 'rb') as f:
        return pickle.load(f)


tokenizer = load_tokenizer()
encode = tokenizer.encode
decode = tokenizer.decode


class CRNN(nn.Module):
    """Two convolutional stages followed by a bidirectional LSTM and a linear head.

    Args:
        num_channels: Number of channels of the input image.
        hidden_size: Hidden size of each LSTM direction.
        num_classes: Number of output classes per time step.
    """

    def __init__(self, num_channels, hidden_size, num_classes):
        super().__init__()

        self.conv1 = nn.Sequential(
            # Previously hard-coded to 1 input channel, ignoring num_channels.
            nn.Conv2d(num_channels, 64, kernel_size=(2, 3), padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=(2, 3), padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )

        # 128 channels * 16 rows: fixes the feature-map height at 16.
        self.rnn = nn.LSTM(128 * 16, hidden_size, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        """Return per-time-step class scores of shape [batch, width, num_classes].

        Args:
            x: Tensor of shape [batch_size, channels, height, width].
        """
        # CNN feature extraction
        conv = self.conv1(x)
        conv = self.conv2(conv)

        batch, channels, height, width = conv.size()

        # Treat every feature-map column as one time step of the sequence.
        conv = conv.permute(0, 3, 1, 2)  # [batch, width, channels, height]
        conv = conv.contiguous().view(batch, width, channels * height)

        rnn, _ = self.rnn(conv)
        return self.fc(rnn)


@st.cache_resource
def load_model(selected_model_path):
    """Build a CRNN, load weights from *selected_model_path* and set eval mode.

    Weights are always mapped to the CPU, so the returned model lives on the
    CPU regardless of ``DEVICE``. The result is cached by Streamlit per path.
    """
    model = CRNN(num_channels=1, hidden_size=256, num_classes=len(tokenizer))
    state_dict = torch.load(selected_model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    return model


def preprocess_image(img):
    """Resize a PIL image to the network's input height and convert it to a tensor.

    The aspect ratio is preserved. The image is not converted to grayscale
    here; the caller is expected to have done so.

    Args:
        img: PIL image.

    Returns:
        Tensor produced by ``transform`` (no batch dimension).
    """
    original_width, original_height = img.size
    # Preserve aspect ratio, but never go below the narrowest width the
    # network can process (very tall inputs would otherwise round to 0).
    new_width = max(MIN_IMAGE_WIDTH, int(IMAGE_HEIGHT * original_width / original_height))
    resized = img.resize((new_width, IMAGE_HEIGHT))
    return transform(resized)


def post_process(preds):
    """Greedy CTC-style decoding of a sequence of per-time-step class ids.

    Consecutive repeats are merged and blank ids (0) are dropped. A blank
    between two identical ids keeps both, so real double characters survive
    (``[a, a, 0, a]`` decodes to ``"aa"``).

    NOTE(review): assumes the model was trained with CTC using blank index 0
    - confirm against the training script.

    Args:
        preds: Iterable of integer class ids, one per time step.

    Returns:
        The decoded string.
    """
    encodings = []
    previous = BLANK_ID
    for pred in preds:
        # Compare against the previous *frame*, not the last kept id, so a
        # blank in between resets the repeat detection.
        if pred != BLANK_ID and pred != previous:
            encodings.append(pred)
        previous = pred
    return decode(encodings)


def inference(model, image):
    """Run *model* on a batched image tensor and return the arg-max class ids.

    The input is moved to whatever device the model's parameters are on, so
    this works whether the model was loaded on CPU or GPU.

    Args:
        model: The CRNN to evaluate.
        image: Tensor of shape [batch, channels, height, width].

    Returns:
        NumPy array of class ids; the batch dimension is dropped when it is 1.
    """
    model_device = next(model.parameters()).device
    with torch.no_grad():
        outputs = model(image.to(model_device))
        # arg-max over raw scores equals arg-max over log-softmax of them.
        pred_chars = torch.argmax(outputs, dim=2)
    return pred_chars.squeeze(0).cpu().numpy()


def predict(image):
    """Recognise the text in a PIL image using the module-level ``model``.

    Args:
        image: PIL image (the app passes a grayscale image).

    Returns:
        The recognised text.
    """
    tensor = preprocess_image(image)
    tensor = tensor.unsqueeze(0)  # add batch dim
    pred_ids = inference(model, tensor).flatten().tolist()
    return post_process(pred_ids)


st.title("CRNN Sinhala Printed Text Recognition")

# Sorted so the option order (and therefore the default choice) is stable.
fp = sorted(Path(".").glob("crnn*.pt"))
if not fp:
    st.error("No model files matching 'crnn*.pt' were found in the working directory.")
    st.stop()

selected_model_path = st.selectbox(label="Select Model...", options=fp)
model = load_model(selected_model_path)

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    image = Image.open(uploaded_file).convert("L")
    st.image(image, caption='Uploaded Image', use_column_width=True)

    if st.button('Predict'):
        predicted_text = predict(image)
        st.write("Predicted Text:")
        st.write(predicted_text)

st.markdown("---")
st.write("Note: This app uses a pre-trained CRNN model for printed Sinhala text recognition.")