Spaces:
Runtime error
Runtime error
import torch | |
import gradio as gr | |
from sentence_transformers import SentenceTransformer | |
from safetensors.torch import load_file | |
import torch.nn as nn | |
# Define the model class (same as in the training script)
class Magical1Sun(nn.Module):
    """Text classifier: a SentenceTransformer encoder followed by an MLP head.

    Args:
        num_classes: number of logits produced by the classifier head.
        dropout_rate: dropout probability applied to the sentence embedding
            and between the two linear layers of the head.
    """

    def __init__(self, num_classes, dropout_rate=0.1):
        super().__init__()
        self.sentence_transformer = SentenceTransformer('all-MiniLM-L12-v2')
        self.dropout = nn.Dropout(dropout_rate)
        # The head's first layer expects the encoder's 384-dim sentence embedding.
        self.classifier = nn.Sequential(
            nn.Linear(384, 256),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, num_classes),
        )

    def forward(self, text):
        """Return raw classifier logits (no softmax) for ``text``.

        Args:
            text: a single string (yielding a 1-D logits tensor of length
                ``num_classes``) or a list of strings (one row per string).

        Returns:
            Tensor of logits on the classifier head's device.
        """
        embeddings = self.sentence_transformer.encode(text, convert_to_tensor=True)
        # SentenceTransformer chooses its own device (e.g. CUDA when available),
        # while the classifier head stays wherever this module was placed.
        # Align them so forward() does not fail with a device-mismatch error.
        classifier_device = next(self.classifier.parameters()).device
        embeddings = embeddings.to(classifier_device)
        embeddings = self.dropout(embeddings)
        return self.classifier(embeddings)
# Load the trained model
def load_model(model_path, num_classes=2):
    """Build a Magical1Sun model and load trained weights from a safetensors file.

    Args:
        model_path: path to the ``.safetensors`` checkpoint.
        num_classes: number of output classes the checkpoint was trained with;
            defaults to 2 (the negative/positive setup used by this app).

    Returns:
        The model switched to eval mode (dropout disabled) for inference.

    Raises:
        RuntimeError: from ``load_state_dict`` when the checkpoint's keys or
            tensor shapes do not match the model (strict loading).
    """
    model = Magical1Sun(num_classes=num_classes)
    state_dict = load_file(model_path)
    model.load_state_dict(state_dict)
    model.eval()
    return model
# Prediction function
def predict(text):
    """Classify ``text`` with the module-level ``model``.

    Index 0 of the model output is read as the negative class and index 1 as
    the positive class. Ties go to "Negative".

    Returns:
        dict with keys "Prediction", "Confidence", "Positive Probability" and
        "Negative Probability"; the last three are percentage strings with two
        decimal places.
    """
    with torch.no_grad():
        logits = model(text)
        probs = torch.softmax(logits, dim=0)
        neg = probs[0].item()
        pos = probs[1].item()

    if pos > neg:
        label = "Positive"
    else:
        label = "Negative"

    pct = "{:.2%}".format
    return {
        "Prediction": label,
        "Confidence": pct(max(pos, neg)),
        "Positive Probability": pct(pos),
        "Negative Probability": pct(neg),
    }
# Load the trained weights once at import time; predict() reads this global.
model = load_model('magical_1_sun.safetensors')

# predict() returns a dict, but a gr.Interface with several output components
# expects one return value per component, in component order. These keys fix
# that order and must line up with the `outputs` list below.
_OUTPUT_KEYS = (
    "Prediction",
    "Confidence",
    "Positive Probability",
    "Negative Probability",
)


def _predict_for_interface(text):
    """Return predict()'s results as a tuple ordered like the interface outputs."""
    result = predict(text)
    return tuple(result[key] for key in _OUTPUT_KEYS)


# Create the Gradio interface
iface = gr.Interface(
    fn=_predict_for_interface,
    inputs=gr.Textbox(lines=3, placeholder="Enter text to classify..."),
    outputs=[
        gr.Label(num_top_classes=1, label="Prediction"),
        gr.Label(label="Confidence"),
        gr.Label(label="Positive Probability"),
        gr.Label(label="Negative Probability"),
    ],
    title="Magical-1 Sun Text Classification",
    description="Enter a text to classify it as positive or negative.",
    examples=[
        ["I love this product! It's amazing!"],
        ["This is terrible. Worst purchase ever."],
        ["Great experience overall. Would buy again."],
        ["Never buying again. Complete waste of money."],
        ["Highly recommended! You won't regret it."],
    ],
)

# Launch the app
if __name__ == "__main__":
    iface.launch()