# 240422_n_11 / app.py
import gradio as gr
import torch
from transformers import BertTokenizer, BertForSequenceClassification
# Load pre-trained model tokenizer (vocabulary)
tokenizer = BertTokenizer.from_pretrained('ProsusAI/finbert')
# Load pre-trained model
model = BertForSequenceClassification.from_pretrained('ProsusAI/finbert')
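# Note: from_pretrained downloads and caches the weights on first run and
# returns the model in eval mode, so no explicit model.eval() call is needed.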
def analyze_sentiment(sec_text):
    # Tokenize the text (truncate to the model's 512-token limit)
    tokens = tokenizer(sec_text, add_special_tokens=True, truncation=True, return_tensors="pt")
    # Make prediction without tracking gradients
    with torch.no_grad():
        outputs = model(**tokens)
        predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
    # Map the predicted class index to its label via the model config
    # (avoids hard-coding a label order that may not match the checkpoint)
    sentiment = model.config.id2label[int(torch.argmax(predictions))].capitalize()
    # Return the sentiment analysis result
    return f"{sentiment} Sentiment"
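# Quick local check (hypothetical example sentence; the exact label depends on
# the model output, so the result shown is not guaranteed):
#   print(analyze_sentiment("The company reported better-than-expected earnings."))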
# Define the Gradio interface
gr_interface = gr.Interface(
    fn=analyze_sentiment,
    inputs=gr.Textbox(lines=1, placeholder="..."),
    outputs="text",
    title="Sentiment Analysis"
)
# Launch the interface
gr_interface.launch()
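# Note: on a hosted Space the bare launch() above is sufficient; when running
# locally, gr_interface.launch(share=True) can be used instead to get a
# temporary public link.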