updated
app.py CHANGED
@@ -2,11 +2,12 @@ import streamlit as st
 import numpy as np
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 from normalizer import normalize
+import torch
 
 # Set the page configuration
 st.set_page_config(
-    page_title="
-    page_icon=":
+    page_title="NMT",  # Updated title as seen in the image
+    page_icon=":robot_face:",  # Use an emoji as the icon, similar to the robot face in the image
     initial_sidebar_state="auto"
 )
 
@@ -14,30 +15,50 @@ st.set_page_config(
 with open("assets/style.css") as f:
     st.markdown("<style>{}</style>".format(f.read()), unsafe_allow_html=True)
 
-# Function to load the pre-trained model
+# Function to load the pre-trained model with caching
+@st.cache_resource
 def get_model():
-
-
-
+    try:
+        tokenizer = AutoTokenizer.from_pretrained("kazalbrur/BanglaEnglishTokenizerBanglaT5", use_fast=True)
+        model = AutoModelForSeq2SeqLM.from_pretrained("kazalbrur/BanglaEnglishTranslationBanglaT5")
+        return tokenizer, model
+    except Exception as e:
+        st.error(f"Error loading model: {str(e)}")
+        return None, None
 
 # Load the tokenizer and model
 tokenizer, model = get_model()
 
-
-
+if tokenizer and model:
+    # Add a header with custom CSS for black font color
+    st.markdown("<h1 style='color:black;'>Translate</h1>", unsafe_allow_html=True)
 
-# Add
-st.
+    # Add a subheader for "Source Text"
+    st.subheader("Source Text")
+
+    # Text area for user input with height set to 200
+    user_input = st.text_area("", "", height=200, max_chars=200)  # no label text, to match the image
+
+    # Button for submitting the input
+    submit_button = st.button("Translate")
+
+    # Divider between the input and output sections
+    st.markdown("<hr style='border:1px solid #ccc;'>", unsafe_allow_html=True)
 
-#
-user_input
-
-
-
-
-
-
-
-
-
+    # Perform prediction when user input is provided and the submit button is clicked
+    if user_input and submit_button:
+        try:
+            normalized_input = normalize(user_input)
+            input_ids = tokenizer(normalized_input, padding=True, truncation=True, max_length=128, return_tensors="pt").input_ids
+            generated_tokens = model.generate(input_ids, max_new_tokens=128)
+            decoded_tokens = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
+
+            # Show the output in a similar box style
+            st.subheader("Translation")
+            st.markdown(f"<div style='background-color: #E8F4FE; padding: 10px; border-radius: 5px;'>{decoded_tokens}</div>", unsafe_allow_html=True)
+        except torch.cuda.OutOfMemoryError:
+            st.error("Out of memory error! Please try smaller input or refresh the page.")
+        except Exception as e:
+            st.error(f"An error occurred during translation: {str(e)}")
+else:
+    st.error("Model could not be loaded. Please check the model path and try again.")
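For reference, the inference path this commit wires into Streamlit can be exercised on its own. Below is a minimal standalone sketch, assuming the two Hub checkpoints named in the diff are publicly downloadable and that normalize comes from the csebuetnlp normalizer package (as the app.py import suggests); the sample sentence is only an illustration:

# Standalone sketch of the same normalize -> tokenize -> generate -> decode path.
# Assumes: pip install transformers torch normalizer
from normalizer import normalize  # assumed to be the csebuetnlp normalizer app.py imports
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Same checkpoints the diff loads inside get_model()
tokenizer = AutoTokenizer.from_pretrained("kazalbrur/BanglaEnglishTokenizerBanglaT5", use_fast=True)
model = AutoModelForSeq2SeqLM.from_pretrained("kazalbrur/BanglaEnglishTranslationBanglaT5")

text = "আমি বাংলায় গান গাই।"  # hypothetical Bangla input for illustration
input_ids = tokenizer(normalize(text), padding=True, truncation=True,
                      max_length=128, return_tensors="pt").input_ids
generated = model.generate(input_ids, max_new_tokens=128)
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])

The @st.cache_resource decorator the commit adds keeps this load from repeating on every Streamlit rerun, and the torch.cuda.OutOfMemoryError branch in the new error handling requires a reasonably recent PyTorch (the exception class was added around 1.13). The app itself runs with streamlit run app.py.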