kazalbrur committed on
Commit
a09a909
1 Parent(s): a790d32
Files changed (1) hide show
  1. app.py +43 -22
app.py CHANGED
@@ -2,11 +2,12 @@ import streamlit as st
2
  import numpy as np
3
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
4
  from normalizer import normalize
 
5
 
6
  # Set the page configuration
7
  st.set_page_config(
8
- page_title="Bengali to English Translator App",
9
- page_icon=":shield:",
10
  initial_sidebar_state="auto"
11
  )
12
 
@@ -14,30 +15,50 @@ st.set_page_config(
14
  with open("assets/style.css") as f:
15
  st.markdown("<style>{}</style>".format(f.read()), unsafe_allow_html=True)
16
 
17
- # Function to load the pre-trained model
 
18
  def get_model():
19
- tokenizer = AutoTokenizer.from_pretrained("kazalbrur/BanglaEnglishTokenizerBanglaT5", use_fast=True)
20
- model = AutoModelForSeq2SeqLM.from_pretrained("kazalbrur/BanglaEnglishTranslationBanglaT5")
21
- return tokenizer, model
 
 
 
 
22
 
23
  # Load the tokenizer and model
24
  tokenizer, model = get_model()
25
 
26
- # Add a header to the Streamlit app with custom CSS for black font color
27
- st.markdown("<h1 style='color:black;'>Bengali to English Translator</h1>", unsafe_allow_html=True)
 
28
 
29
- # Add placeholder text with custom CSS styling
30
- st.markdown("<span style='color:black'>Enter your Bengali text here</span>", unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
31
 
32
- # Text area for user input with label and height set to 250
33
- user_input = st.text_area("Enter your Bengali text here", "", height=250, label_visibility="collapsed")
34
-
35
- # Button for submitting the input
36
- submit_button = st.button("Translate")
37
-
38
- # Perform prediction when user input is provided and the submit button is clicked
39
- if user_input and submit_button:
40
- input_ids = tokenizer(normalize(user_input), padding=True, truncation=True, max_length=128, return_tensors="pt").input_ids
41
- generated_tokens = model.generate(input_ids, max_new_tokens=128)
42
- decoded_tokens = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
43
- st.write(f"<span style='color:black'>English Translation: {decoded_tokens}</span>", unsafe_allow_html=True)
 
 
 
 
 
 
2
  import numpy as np
3
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
4
  from normalizer import normalize
5
+ import torch
6
 
7
# Configure the Streamlit page: browser-tab title, tab icon, and sidebar state.
st.set_page_config(
    page_title="NMT",  # short app title shown in the browser tab
    page_icon=":robot_face:",  # emoji shortcode used as the tab icon
    initial_sidebar_state="auto",
)
13
 
 
15
# Inject the project's custom stylesheet into the page.
# encoding is pinned to UTF-8 so the CSS is decoded identically on every
# platform (the default locale encoding varies, e.g. cp1252 on Windows).
with open("assets/style.css", encoding="utf-8") as f:
    st.markdown("<style>{}</style>".format(f.read()), unsafe_allow_html=True)
17
 
18
# Cached loader for the pre-trained translation pipeline.
@st.cache_resource
def get_model():
    """Load and cache the BanglaT5 tokenizer and seq2seq model.

    Returns:
        (tokenizer, model) on success, or (None, None) after surfacing the
        load error to the user via st.error.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            "kazalbrur/BanglaEnglishTokenizerBanglaT5", use_fast=True
        )
        model = AutoModelForSeq2SeqLM.from_pretrained(
            "kazalbrur/BanglaEnglishTranslationBanglaT5"
        )
    except Exception as e:
        # Best-effort: report the failure in the UI instead of crashing the app.
        st.error(f"Error loading model: {str(e)}")
        return None, None
    else:
        return tokenizer, model
28
 
29
# Load the (cached) tokenizer and model; both are None if loading failed.
tokenizer, model = get_model()

if tokenizer and model:
    # Page header with custom CSS for black font color.
    st.markdown("<h1 style='color:black;'>Translate</h1>", unsafe_allow_html=True)

    # Input section.
    st.subheader("Source Text")

    # Text area for the Bengali source text. A non-empty label is supplied and
    # collapsed (instead of "") so Streamlit does not emit the empty-label
    # accessibility warning, while the rendered UI stays the same.
    user_input = st.text_area(
        "Source Text", "", height=200, max_chars=200, label_visibility="collapsed"
    )

    # Button for submitting the input.
    submit_button = st.button("Translate")

    # Divider between the input and output sections.
    st.markdown("<hr style='border:1px solid #ccc;'>", unsafe_allow_html=True)

    # Translate only when there is input and the button was clicked.
    if user_input and submit_button:
        try:
            normalized_input = normalize(user_input)
            input_ids = tokenizer(
                normalized_input,
                padding=True,
                truncation=True,
                max_length=128,
                return_tensors="pt",
            ).input_ids
            # Inference only: no_grad avoids building autograd state,
            # reducing memory use during generation.
            with torch.no_grad():
                generated_tokens = model.generate(input_ids, max_new_tokens=128)
            decoded_tokens = tokenizer.batch_decode(
                generated_tokens, skip_special_tokens=True
            )[0]

            # Show the output in a similar box style.
            st.subheader("Translation")
            st.markdown(
                f"<div style='background-color: #E8F4FE; padding: 10px; border-radius: 5px;'>{decoded_tokens}</div>",
                unsafe_allow_html=True,
            )
        except torch.cuda.OutOfMemoryError:
            st.error("Out of memory error! Please try smaller input or refresh the page.")
        except Exception as e:
            st.error(f"An error occurred during translation: {str(e)}")
else:
    st.error("Model could not be loaded. Please check the model path and try again.")