good_acc_v3 / app.py
nebiyu29's picture
reverted to the version before the DataFrame conversion
609abfc verified
raw
history blame
4.6 kB
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import re
import pandas as pd
# Load the model and tokenizer once per server process. Streamlit re-executes
# this whole script on every widget interaction, so without caching the
# weights would be re-read from disk on each rerun.
MODEL_NAME = "nebiyu29/fintunned-v2-roberta_GA"


@st.cache_resource
def _load_model(model_name=MODEL_NAME):
    """Load and return ``(tokenizer, model, device)`` for *model_name*.

    The model is moved to CUDA when available (CPU otherwise) and put in
    evaluation mode. The result is cached by Streamlit across reruns.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    mdl = AutoModelForSequenceClassification.from_pretrained(model_name)
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    mdl = mdl.to(dev)
    mdl.eval()  # inference only; explicit so dropout is guaranteed off
    return tok, mdl, dev


tokenizer, model, device = _load_model()
# Split a text into chunks small enough to be fed to the model one at a time.
def split_text(text, max_tokens=510):
    """Split *text* into segments of at most *max_tokens* sub-word tokens.

    Every character that is not an ASCII letter or whitespace is removed
    first (digits, punctuation and accented letters included). The cleaned
    text is then tokenized with the module-level ``tokenizer`` and cut into
    consecutive chunks, each converted back into a string.

    Args:
        text: Input text; non-string values are converted with ``str()``.
        max_tokens: Maximum number of tokens per segment. The default of 510
            leaves room for the special tokens added when a segment is
            re-encoded. NOTE(review): assumes a 512-position model with two
            special tokens (RoBERTa-style <s>/</s>); confirm against
            ``tokenizer.model_max_length``.

    Returns:
        A list of segment strings; empty when the cleaned text has no tokens.

    Raises:
        ValueError: If *max_tokens* is smaller than 1.
    """
    if max_tokens < 1:
        raise ValueError(f"max_tokens must be >= 1, got {max_tokens}")
    # Convert before the regex: re.sub raises TypeError on non-str input.
    text = re.sub(r'[^a-zA-Z\s]', '', str(text))
    tokens = tokenizer.tokenize(text)
    # Cut by position rather than by comparing token values, so a token that
    # merely repeats the final token cannot end a segment early.
    return [
        tokenizer.convert_tokens_to_string(tokens[start:start + max_tokens])
        for start in range(0, len(tokens), max_tokens)
    ]
def classify(text):
    """Classify *text* and return its three most likely labels.

    The text is split with ``split_text``. Each segment is encoded on its own
    and passed through the module-level ``model``. The logits are averaged
    across segments and softmaxed. The top entries among the first
    ``len(labels)`` outputs are then reported.

    Args:
        text: The text to classify.

    Returns:
        A dict with keys:
            ``"sequence"``: the input text, unchanged.
            ``"top_labels"``: up to three label names, most likely first.
            ``"top_probabilities"``: the matching probabilities as floats.
        Both lists are empty when the cleaned text yields no segments
        (e.g. empty or letter-free input).
    """
    # NOTE(review): output index i is reported as labels[i]; this assumes the
    # fine-tuned head was trained with exactly this label order. Confirm
    # against model.config.id2label.
    labels = ["depression", "anxiety", "bipolar disorder", "schizophrenia", "PTSD", "OCD", "ADHD", "autism", "eating disorder", "personality disorder", "phobia"]
    results = {
        "sequence": text,
        "top_labels": [],
        "top_probabilities": [],
    }

    segments = split_text(text)
    if not segments:
        # Nothing survived cleaning; stacking zero logit tensors would raise.
        return results

    segment_logits = []
    for segment in segments:
        # Encode only the segment itself. truncation=True is a safety net in
        # case re-tokenizing the segment string yields more tokens than the
        # tokenizer's maximum length.
        inputs = tokenizer(segment, truncation=True, return_tensors="pt")
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)
        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)
        # Batch of one: keep the single row of class logits.
        segment_logits.append(outputs.logits[0])

    # Average logits over segments, then softmax over all model outputs.
    avg_logits = torch.stack(segment_logits).mean(dim=0)
    probabilities = torch.softmax(avg_logits, dim=-1)[:len(labels)]

    k = min(3, probabilities.numel())
    top = torch.topk(probabilities, k)
    results["top_labels"] = [labels[i] for i in top.indices.tolist()]
    results["top_probabilities"] = top.values.tolist()
    return results
# Streamlit app
st.title("Text Classification.")
st.write("Enter some text, and the model will classify it.")
text_input = st.text_input("Text Input")

# Only run the model once the user has entered something: st.text_input
# returns an empty string on the first render, and there is nothing to
# classify then.
if text_input.strip():
    predictions = classify(text_input)
    if predictions["top_labels"]:
        labels_str = ",".join(predictions["top_labels"])
        probs_str = ",".join(f"{p:.2f}" for p in predictions["top_probabilities"])
        st.write(f"Label: {labels_str}")
        st.write(f"Probability: {probs_str}")
    else:
        st.write("No alphabetic text found to classify.")