nebiyu29 commited on
Commit
bd1016c
1 Parent(s): 0516563

initial commit

Browse files
Files changed (1) hide show
  1. app.py +110 -0
app.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
+ import torch
4
+ import re
5
+ # Load the model and tokenizer
6
+ tokenizer = AutoTokenizer.from_pretrained("nebiyu29/fintunned-v2-roberta_GA")
7
+ model = AutoModelForSequenceClassification.from_pretrained("nebiyu29/fintunned-v2-roberta_GA")
8
+
9
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+ model = model.to(device)
11
+
12
+
13
+
14
+ # Define a function to split a text into segments of 512 tokens
15
+ def split_text(text):
16
+ text=re.sub(r'[^a-zA-Z\s]','',text)
17
+ text=str(text)
18
+ # Tokenize the text
19
+ tokens = tokenizer.tokenize(text)
20
+ # Initialize an empty list for segments
21
+ segments = []
22
+ # Initialize an empty list for current segment
23
+ current_segment = []
24
+ # Initialize a counter for tokens
25
+ token_count = 0
26
+ # Loop through the tokens
27
+ for token in tokens:
28
+ # Add the token to the current segment
29
+ current_segment.append(token)
30
+ # Increment the token count
31
+ token_count += 1
32
+ # If the token count reaches 512 or the end of the text, add the current segment to the segments list
33
+ if token_count == 512 or token == tokens[-1]:
34
+ # Convert the current segment to a string and add it to the segments list
35
+ segments.append(tokenizer.convert_tokens_to_string(current_segment))
36
+ # Reset the current segment and the token count
37
+ current_segment = []
38
+ token_count = 0
39
+ # Return the segments list
40
+ return segments
41
+
42
+ def classify(text):
43
+ # Define the labels
44
+ labels = ["depression", "anxiety", "bipolar disorder", "schizophrenia", "PTSD", "OCD", "ADHD", "autism", "eating disorder", "personality disorder", "phobia"]
45
+ #labels=list(model.config.id2label)
46
+ # Encode the labels
47
+ label_encodings = tokenizer(labels, padding=True, return_tensors="pt")
48
+ # Split the text into segments
49
+ segments = split_text(text)
50
+ # Initialize an empty list for logits
51
+ logits_list = []
52
+
53
+ # Loop through the segments
54
+ for segment in segments:
55
+ # Encode the segment and the labels
56
+ inputs = tokenizer([segment] + labels, padding=True, return_tensors="pt")
57
+ # Get the input ids and attention mask
58
+ input_ids = inputs["input_ids"]
59
+ attention_mask = inputs["attention_mask"]
60
+ # Move the input ids and attention mask to the device
61
+ input_ids = input_ids.to(device)
62
+ attention_mask = attention_mask.to(device)
63
+ # Get the model outputs for each segment
64
+
65
+ with torch.no_grad():
66
+ outputs = model(
67
+ input_ids,
68
+ attention_mask=attention_mask,
69
+ )
70
+ # Get the logits for each segment and append them to the logits list
71
+ logits = outputs.logits
72
+ logits_list.append(logits)
73
+ # Average the logits across the segments
74
+ avg_logits = torch.mean(torch.stack(logits_list), dim=0)
75
+ # Apply softmax to convert logits to probabilities
76
+ probabilities = torch.softmax(avg_logits, dim=1)
77
+ # Get the probabilities for each label
78
+ label_probabilities = probabilities[:, :len(labels)].tolist()
79
+
80
+ # Get the top 3 most likely labels and their probabilities
81
+ # Get the top 3 most likely labels and their probabilities
82
+ top_labels = []
83
+ top_probabilities = []
84
+ label_probabilities = label_probabilities[0] # Extract the list of probabilities for the first (and only) example
85
+ for _ in range(3):
86
+ max_prob_index = label_probabilities.index(max(label_probabilities))
87
+ top_labels.append(labels[max_prob_index])
88
+ top_probabilities.append(max(label_probabilities))
89
+ label_probabilities[max_prob_index] = 0 # Set the max probability to 0 to get the next highest probability
90
+
91
+ # Create a dictionary to store the results
92
+ results = {
93
+ "sequence": text,
94
+ "top_labels": top_labels,
95
+ "top_probabilities": top_probabilities
96
+ }
97
+
98
+ return results
99
+
100
+ # Streamlit app
101
+ st.title("Text Classification Demo")
102
+ st.write("Enter some text, and the model will classify it.")
103
+
104
+ text_input = st.text_input("Text Input")
105
+ if st.button("Classify"):
106
+ predictions = classify(text_input)
107
+ for prediction in predictions:
108
+ # st.write(f"Segment Text: {prediction['segment_text']}")
109
+ st.write(f"Label: {prediction['top_labels']}")
110
+ st.write(f"Probability: {prediction['top_probabilities']}")