# shopify_1 / intent_classifier.py
# Commit 8cd9024 (verified) by nileshhanotia: "Create intent_classifier.py" (730 bytes)
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
class IntentClassifier:
    """Map a UI query-type selection to a fixed intent name.

    No model is consulted: the selection alone determines the intent, so a
    recognised selection is always reported with confidence 1.0 and anything
    else with ``(None, 0.0)``.
    """

    # Dropdown value -> display label used as the key into ``self.intents``.
    # ``self.intents`` is keyed by display label, not by the dropdown value
    # callers pass, so the value must be translated before the lookup.
    _QUERY_TYPE_TO_LABEL = {
        "sql_query": "SQL Query",
        "ask_question": "ask a question",
    }

    def __init__(self):
        # Display label -> intent name.
        self.intents = {
            "SQL Query": "database_query",
            "ask a question": "product_description",
        }

    def classify(self, queryType):
        """Return the intent and confidence for a dropdown selection.

        Args:
            queryType: The dropdown value; ``"sql_query"`` and
                ``"ask_question"`` are recognised.

        Returns:
            ``(intent_name, 1.0)`` for a recognised value, e.g.
            ``("database_query", 1.0)`` for ``"sql_query"``; otherwise
            ``(None, 0.0)``.
        """
        # Non-strings can never match a dropdown value; checking first also
        # keeps unhashable inputs from raising inside the dict lookup.
        if not isinstance(queryType, str):
            return None, 0.0
        label = self._QUERY_TYPE_TO_LABEL.get(queryType)
        if label is None:
            return None, 0.0  # No intent matched
        return self.intents[label], 1.0