|
import streamlit as st |
|
import numpy as np |
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer |
|
from normalizer import normalize |
|
import torch |
|
|
|
|
|
# Page chrome must be configured before any other Streamlit UI call.
_PAGE_SETTINGS = {
    "page_title": "NMT",
    "page_icon": ":robot_face:",
    "initial_sidebar_state": "auto",
}
st.set_page_config(**_PAGE_SETTINGS)
|
|
|
|
|
# Inject custom CSS so every st.text_area widget renders with a visible
# border, rounded corners, and inner padding. unsafe_allow_html is required
# for the raw <style> block to pass through Streamlit's markdown sanitizer.
_TEXTAREA_CSS = """
    <style>
    .stTextArea textarea {
        border: 2px solid #ccc; /* Customize the border color and width */
        border-radius: 5px; /* Optional: rounded corners */
        padding: 10px; /* Optional: padding inside the text area */
    }
    </style>
    """
st.markdown(_TEXTAREA_CSS, unsafe_allow_html=True)
|
|
|
|
|
@st.cache_resource
def get_model():
    """Download and cache the BanglaT5 tokenizer and seq2seq model.

    Cached once per server process via st.cache_resource, so the
    HuggingFace downloads happen only on the first run.

    Returns:
        (tokenizer, model) on success, or (None, None) after reporting
        the failure to the UI with st.error.
    """
    try:
        loaded_tokenizer = AutoTokenizer.from_pretrained(
            "kazalbrur/BanglaEnglishTokenizerBanglaT5", use_fast=True
        )
        loaded_model = AutoModelForSeq2SeqLM.from_pretrained(
            "kazalbrur/BanglaEnglishTranslationBanglaT5"
        )
    except Exception as e:
        # Broad catch is deliberate: any load failure (network, disk, hub)
        # is surfaced in the UI rather than crashing the app.
        st.error(f"Error loading model: {str(e)}")
        return None, None
    return loaded_tokenizer, loaded_model
|
|
|
|
|
tokenizer, model = get_model()

# Explicit None checks: HF tokenizer/model objects define dunders such as
# __len__, so bare truthiness (`if tokenizer and model`) is fragile.
if tokenizer is not None and model is not None:

    st.markdown("<h1 style='color:black;'>Translate</h1>", unsafe_allow_html=True)

    st.subheader("Source Text")

    # max_chars bounds the input so generation latency/memory stay predictable.
    user_input = st.text_area("", "", height=200, max_chars=200)

    submit_button = st.button("Translate")

    st.markdown("<hr style='border:1px solid #ccc;'>", unsafe_allow_html=True)

    if user_input and submit_button:
        try:
            # Normalize Bangla text before tokenization (unicode/punctuation cleanup).
            normalized_input = normalize(user_input)
            input_ids = tokenizer(
                normalized_input,
                padding=True,
                truncation=True,
                max_length=128,
                return_tensors="pt",
            ).input_ids
            # Inference-only path: disable autograd so no computation graph is
            # built, reducing memory use and speeding up generation.
            with torch.inference_mode():
                generated_tokens = model.generate(input_ids, max_new_tokens=128)
            decoded_tokens = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

            st.subheader("Translation")
            st.markdown(f"<div style='background-color: #E8F4FE; padding: 10px; border-radius: 5px;'>{decoded_tokens}</div>", unsafe_allow_html=True)
        except torch.cuda.OutOfMemoryError:
            st.error("Out of memory error! Please try smaller input or refresh the page.")
        except Exception as e:
            st.error(f"An error occurred during translation: {str(e)}")
else:
    st.error("Model could not be loaded. Please check the model path and try again.")
|
|