SujonPro24 commited on
Commit
1da409e
·
verified ·
1 Parent(s): 01ed240

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -0
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #import streamlit as st
2
+ #from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
3
+ #import torch
4
+
5
+ ## Load the fine-tuned model and tokenizer
6
+ #model_name = "fine-tuned-model"
7
+ #model = DistilBertForSequenceClassification.from_pretrained(model_name)
8
+ #tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
9
+
10
+ ## Function to classify text
11
+ #def classify_text(text):
12
+ # inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
13
+ # with torch.no_grad():
14
+ # outputs = model(**inputs)
15
+ # logits = outputs.logits
16
+ # predicted_class_id = torch.argmax(logits, dim=1).item()
17
+ # return "spam" if predicted_class_id == 1 else "ham"
18
+
19
+ ## Streamlit app
20
+ #st.title("Text Message Classification")
21
+ #st.write("Enter a text message and see if it's classified as spam or ham.")
22
+
23
+ #user_input = st.text_area("Text Message", "")
24
+ #if st.button("Classify"):
25
+ # if user_input:
26
+ # prediction = classify_text(user_input)
27
+ # st.write(f"The message is classified as: \n **{prediction}**")
28
+ # else:
29
+ # st.write("Please enter a text message.")
30
+ import streamlit as st
31
+ from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
32
+ import torch
33
+
34
+ # Load the fine-tuned model and tokenizer
35
+ model_name = "fine-tuned-model"
36
+ model = DistilBertForSequenceClassification.from_pretrained(model_name)
37
+ tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
38
+
39
+ # Function to classify text
40
+ def classify_text(text):
41
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
42
+ with torch.no_grad():
43
+ outputs = model(**inputs)
44
+ logits = outputs.logits
45
+ predicted_class_id = torch.argmax(logits, dim=1).item()
46
+ return "spam" if predicted_class_id == 1 else "ham"
47
+
48
+ # Streamlit app
49
+ st.set_page_config(page_title="Text Message Classification", page_icon="📧")
50
+
51
+ # Header
52
+ st.title("📧 Text Message Classification")
53
+
54
+ # Text input area
55
+ #st.subheader("Enter a Text Message:")
56
+ user_input = st.text_area("Type your message here...", height=50)
57
+
58
+ # Classify button and result display
59
+ if st.button("Classify"):
60
+ if user_input:
61
+ prediction = classify_text(user_input)
62
+ if prediction == "ham":
63
+ st.success(f"The message is classified as: **{prediction}**")
64
+ else:
65
+ st.error(f"The message is classified as: **{prediction}**")
66
+ else:
67
+ st.warning("Please enter a text message.")
68
+
69
+ # Footer
70
+ st.markdown("""
71
+ ---
72
+ Built with ❤️ using [Streamlit](https://streamlit.io/) and [Transformers](https://huggingface.co/transformers/).
73
+ """)