"""Streamlit demo: Bangla-to-English translation with a Transformer model."""
# NOTE: import order is kept exactly as in the original file.
import os
import torch
import argparse
import streamlit as st
import sentencepiece as spm
from utils import utils_cls
from model import BanglaTransformer
from config import config as cfg

torch.manual_seed(0)

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
uobj = utils_cls(device=device)

__MODULE__ = "Bangla Language Translation"
__MAIL__ = "saifulbrur79@gmail.com"
__MODIFICAIOTN__ = "28/03/2023"
__LICENSE__ = "MIT"

st.write(""" Bangla to English Translation """)

BASE_URL = "./model"


class Bn2EnTranslation:
    """Loaders and a greedy decoder for Bangla-to-English translation.

    The constructor only records artifact paths under ``BASE_URL``; nothing is
    read from disk until :meth:`get_bntoen_model` is called.
    """

    # NOTE(review): in the reviewed source these literals were empty strings
    # (``src_vocab['']`` / ``.replace("", "")``), which makes BOS and EOS the
    # same index and the cleanup a no-op. That looks like angle-bracket tokens
    # stripped by a text extractor; the conventional torchtext names are
    # restored here. Confirm against the script that built the vocab files.
    BOS_TOKEN = "<bos>"
    EOS_TOKEN = "<eos>"

    def __init__(self):
        """Record the paths of the tokenizers, vocabularies and checkpoint."""
        self.bn_tokenizer = os.path.join(BASE_URL, "bn_model.model")
        self.en_tokenizer = os.path.join(BASE_URL, 'en_model.model')
        self.bn_vocab = os.path.join(BASE_URL, 'bn_vocab.pkl')
        self.en_vocab = os.path.join(BASE_URL, 'en_vocab.pkl')
        self.model = os.path.join(BASE_URL, 'pytorch_model.pt')

    def read_data(self, data_path):
        """Read a tab-separated file into ``[first_column, second_column]`` pairs.

        Newline characters are removed from the second column. Columns after
        the second are ignored.

        Raises:
            IndexError: if a line contains no tab character.
        """
        # Explicit UTF-8: the locale default encoding cannot be relied on to
        # decode Bangla text on every platform.
        with open(data_path, "r", encoding="utf-8") as f:
            rows = [line.split("\t") for line in f]
        return [[cols[0], cols[1].replace("\n", "")] for cols in rows]

    def load_tokenizer(self, tokenizer_path: str = "") -> object:
        """Return a ``SentencePieceProcessor`` loaded from ``tokenizer_path``."""
        return spm.SentencePieceProcessor(model_file=tokenizer_path)

    def get_vocab(self, BN_VOCAL_PATH: str = "", EN_VOCAL_PATH: str = ""):
        """Load both vocabularies through ``uobj`` and return ``(bn, en)``."""
        bn_vocal = uobj.load_bn_vocal(BN_VOCAL_PATH)
        en_vocal = uobj.load_en_vocal(EN_VOCAL_PATH)
        return bn_vocal, en_vocal

    def load_model(self, model_path: str = "", SRC_VOCAB_SIZE: int = 0, TGT_VOCAB_SIZE: int = 0):
        """Build a ``BanglaTransformer`` from ``cfg`` and load its weights.

        The checkpoint at ``model_path`` must be a mapping with a
        ``'model_state_dict'`` entry. The model is returned in eval mode on
        the module-level ``device``.
        """
        model = BanglaTransformer(
            cfg.NUM_ENCODER_LAYERS,
            cfg.NUM_DECODER_LAYERS,
            cfg.EMB_SIZE,
            SRC_VOCAB_SIZE,
            TGT_VOCAB_SIZE,
            cfg.FFN_HID_DIM,
            nhead=cfg.NHEAD,
        )
        model.to(device)
        # map_location remaps tensors saved from another device (e.g. CUDA)
        # onto ``device``; without it a GPU-saved checkpoint cannot be loaded
        # on a CPU-only host.
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        return model

    def greedy_decode(self, model, src, src_mask, max_len, start_symbol, eos_index):
        """Decode one sequence by repeatedly taking the highest-scoring token.

        Starts from a ``(1, 1)`` tensor holding ``start_symbol`` and appends
        one token per step along dim 0 until ``eos_index`` is produced or the
        sequence holds ``max_len`` tokens.

        Returns:
            The generated token-id tensor, including the start symbol and,
            when produced, the end symbol.
        """
        src = src.to(device)
        src_mask = src_mask.to(device)
        # Inference only: skip autograd bookkeeping for the whole decode.
        with torch.no_grad():
            # The encoder output does not change between steps, so it is
            # computed and moved to the device once.
            memory = model.encode(src, src_mask).to(device)
            ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(device)
            for _ in range(max_len - 1):
                tgt_mask = (
                    uobj.generate_square_subsequent_mask(ys.size(0)).type(torch.bool)
                ).to(device)
                out = model.decode(ys, memory, tgt_mask)
                out = out.transpose(0, 1)
                # Only the last position is scored for the next token.
                prob = model.generator(out[:, -1])
                _, next_word = torch.max(prob, dim=1)
                next_word = next_word.item()
                ys = torch.cat(
                    [ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0
                )
                if next_word == eos_index:
                    break
        return ys

    def get_bntoen_model(self):
        """Load the Bangla tokenizer, both vocabularies and the model.

        Progress is printed to stdout. Returns a dict with the keys
        ``"bn_tokenizer"``, ``"bn_vocab"``, ``"en_vocab"`` and ``"model"``.
        """
        print("Tokenizer Loading ...... : ", end="", flush=True)
        bn_tokenizer = self.load_tokenizer(tokenizer_path=self.bn_tokenizer)
        print("Done")

        print("Vocab Loading ...... : ", end="", flush=True)
        bn_vocab, en_vocab = self.get_vocab(
            BN_VOCAL_PATH=self.bn_vocab, EN_VOCAL_PATH=self.en_vocab
        )
        print("Done")

        print("Model Loading ...... : ", end="", flush=True)
        model = self.load_model(
            model_path=self.model,
            SRC_VOCAB_SIZE=len(bn_vocab),
            TGT_VOCAB_SIZE=len(en_vocab),
        )
        print("Done")

        return {
            "bn_tokenizer": bn_tokenizer,
            "bn_vocab": bn_vocab,
            "en_vocab": en_vocab,
            "model": model,
        }

    def translate(self, text, models):
        """Translate ``text`` using the dict returned by ``get_bntoen_model``.

        The text is split into SentencePiece pieces, mapped to ids through the
        source vocabulary, decoded greedily, mapped back through the target
        vocabulary and re-joined into words on the ``▁`` piece marker.

        Raises:
            KeyError: if a source piece is missing from the source vocabulary
                mapping.
        """
        model = models["model"]
        src_vocab = models["bn_vocab"]
        tgt_vocab = models["en_vocab"]
        src_tokenizer = models["bn_tokenizer"]

        # As in the original, both special-token ids come from the source
        # vocabulary and are reused for decoding.
        BOS_IDX = src_vocab[self.BOS_TOKEN]
        EOS_IDX = src_vocab[self.EOS_TOKEN]

        # Fetch each lookup table once instead of once per token.
        stoi = src_vocab.get_stoi()
        itos = tgt_vocab.get_itos()

        pieces = src_tokenizer.encode(text, out_type=str)
        tokens = [BOS_IDX] + [stoi[tok] for tok in pieces] + [EOS_IDX]
        num_tokens = len(tokens)
        src = torch.LongTensor(tokens).reshape(num_tokens, 1)
        src_mask = torch.zeros(num_tokens, num_tokens).type(torch.bool)

        tgt_tokens = self.greedy_decode(
            model,
            src,
            src_mask,
            max_len=num_tokens + 5,
            start_symbol=BOS_IDX,
            eos_index=EOS_IDX,
        ).flatten()

        p_text = (
            " ".join(itos[tok] for tok in tgt_tokens.tolist())
            .replace(self.BOS_TOKEN, "")
            .replace(self.EOS_TOKEN, "")
        )
        # Drop the spaces between pieces, then turn the "▁" markers back into
        # single spaces between words.
        pts = " ".join(p_text.replace(" ", "").split("▁"))
        return pts.strip()


# if __name__ == "__main__":
#     print(torch.cuda.get_device_name(0))  # disabled debug aid; raises when CUDA is unavailable

# Sample input kept for reference; it is replaced by the widget value below.
text = "এই উপজেলায় ১টি সরকারি কলেজ রয়েছে"

# Streamlit re-runs this script on every interaction. Without caching, the
# tokenizer, vocabularies and model would be reloaded each time. Older
# Streamlit releases expose the same feature as ``experimental_singleton``;
# if neither exists, loading falls back to the uncached behaviour.
_cache_resource = (
    getattr(st, "cache_resource", None)
    or getattr(st, "experimental_singleton", None)
    or (lambda func: func)
)


@_cache_resource
def _load_translator():
    """Create the translator and load its models once per server process."""
    translator = Bn2EnTranslation()
    return translator, translator.get_bntoen_model()


obj, models = _load_translator()

text = st.text_area("Enter some text:এই উপজেলায় ১টি সরকারি কলেজ রয়েছে")
if text:
    pre = obj.translate(text, models)
    print(f"Input : {text}")
    print(f"Prediction : {pre}")
    # Show the result in the page as well; print() only reaches the server
    # console, so the user would otherwise never see the translation.
    st.write(pre)