# good_acc/main.py
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load the fine-tuned model and tokenizer directly from the Hugging Face Hub
tokenizer = AutoTokenizer.from_pretrained("nebiyu29/fintunned-v2-roberta_GA")
model = AutoModelForSequenceClassification.from_pretrained("nebiyu29/fintunned-v2-roberta_GA")
def classify_text(text):
    """
    Tokenizes the input text, feeds it to the model, and returns the predicted class label.
    """
    inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model(**inputs)
    logits = outputs.logits                    # raw class scores
    prediction = torch.argmax(logits, dim=-1)  # index of the highest-scoring class
    return model.config.id2label[prediction.item()]  # map index to class label
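
# Example usage (hypothetical input; the actual labels depend on the model's
# id2label mapping, e.g. "positive"/"negative"/"neutral"):
#   classify_text("The quarterly results exceeded expectations.")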
# Build a simple Gradio demo around the classifier.
# Note: gr.Interface does not take a `choices` argument; the output labels
# come from the model's id2label configuration.
interface = gr.Interface(
    fn=classify_text,
    inputs="text",
    outputs="text",
    title="Text Classification Demo",
    description="Enter some text, and the model will classify it.",
)
interface.launch()
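
# Running this script starts a local Gradio server (by default at http://127.0.0.1:7860);
# pass share=True to launch() if a temporary public link is needed.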