Spaces:
Sleeping
Sleeping
from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
import torch | |
class IntentClassifier:
    """Map a free-text query to one of a fixed set of intent labels.

    Wraps a Hugging Face sequence-classification checkpoint and relabels its
    output class indices through ``self.intents``.

    NOTE(review): the default checkpoint,
    ``distilbert-base-uncased-finetuned-sst-2-english``, is a binary SST-2
    *sentiment* model. Its two classes are merely reinterpreted here as
    "database_query" / "product_description". Predictions are therefore not
    intent-meaningful unless a checkpoint fine-tuned for these intents is
    passed via ``model_name``.
    """

    def __init__(
        self,
        model_name="distilbert-base-uncased-finetuned-sst-2-english",
        intents=None,
    ):
        """Load the tokenizer and the classification model.

        Args:
            model_name: Hugging Face model id or local path to load. Defaults
                to the checkpoint this class has always used.
            intents: Mapping from class index to intent label. Defaults to
                ``{0: "database_query", 1: "product_description"}``. Its length
                is passed as ``num_labels``, so keys are expected to be
                ``0 .. len(intents) - 1``.
        """
        self.model_name = model_name
        # Copy a caller-supplied mapping so later external mutation can't
        # silently change the label set out from under the loaded model.
        self.intents = (
            dict(intents)
            if intents is not None
            else {0: "database_query", 1: "product_description"}
        )
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name, num_labels=len(self.intents)
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

    def classify(self, query):
        """Classify ``query`` and return ``(intent_label, confidence)``.

        Args:
            query: The text to classify. If a batch (list of strings) is
                passed, only the first entry's prediction is returned.

        Returns:
            A tuple of the predicted intent label (a value of
            ``self.intents``) and the softmax probability of that class as a
            Python float.

        Raises:
            KeyError: if the model predicts a class index missing from
                ``self.intents``.
        """
        inputs = self.tokenizer(
            query, return_tensors="pt", truncation=True, padding=True
        )
        # Inference only: don't build an autograd graph we never backprop.
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Report the first row only (the single query in the common case).
        probabilities = torch.nn.functional.softmax(outputs.logits[0], dim=-1)
        # Argmax over one row's class scores, so the index always falls within
        # the label set even when the input was batched.
        predicted_class = int(torch.argmax(probabilities).item())
        return self.intents[predicted_class], probabilities[predicted_class].item()