"""Route natural-language queries to a SQL engine or a RAG engine by intent.

A Hugging Face sequence-classification model scores each query; the
highest-probability label index is mapped through ``intents`` to a route
name, and ``handle_query`` dispatches the query to the matching backend.
"""

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Load pre-trained model and tokenizer from Hugging Face.
# Example, change to other open-source models if necessary.
# NOTE(review): this checkpoint is a binary *sentiment* classifier (SST-2,
# negative/positive). Mapping its two labels onto the intents below gives
# arbitrary routing; swap in a checkpoint fine-tuned on intent labels
# before relying on the classification.
MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Define the intents: classifier label index -> route name.
intents = {0: "database_query", 1: "product_description"}


# Function to classify query intent
def classify_intent(query):
    """Classify a single query and return ``(intent_name, confidence)``.

    Args:
        query: One natural-language query string.

    Returns:
        A tuple of the intent name (looked up in ``intents``) for the
        highest-scoring label, and that label's softmax probability as a
        Python float.

    Raises:
        KeyError: if the model predicts a label index missing from
            ``intents``.
    """
    inputs = tokenizer(query, return_tensors="pt", truncation=True, padding=True)
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
    # Reduce over the label axis of the first (only) example so the result is
    # a label index rather than a position in the flattened tensor.
    predicted_class = torch.argmax(probabilities[0], dim=-1).item()
    return intents[predicted_class], probabilities[0][predicted_class].item()


# Example usage
query_1 = "Fetch all products with the keyword 'T-shirt' from the database."
query_2 = "Can you tell me about the description of this Shopify store?"

intent_1, confidence_1 = classify_intent(query_1)
intent_2, confidence_2 = classify_intent(query_2)

print(f"Query 1: '{query_1}'\nIntent: {intent_1} with confidence {confidence_1}\n")
print(f"Query 2: '{query_2}'\nIntent: {intent_2} with confidence {confidence_2}\n")


# Further routing based on classified intent
def handle_query(query):
    """Classify ``query`` and dispatch it to the matching backend.

    Returns:
        The result of ``execute_database_query`` for ``"database_query"``,
        of ``execute_rag_query`` for ``"product_description"``, otherwise
        the string ``"Intent not recognized."``.
    """
    # Confidence is currently unused for routing decisions.
    intent, _confidence = classify_intent(query)
    if intent == "database_query":
        # Call the natural language to SQL engine
        return execute_database_query(query)
    if intent == "product_description":
        # Call the RAG engine for product description
        return execute_rag_query(query)
    return "Intent not recognized."


# Placeholder functions for database and RAG query handling
def execute_database_query(query):
    """Placeholder: run ``query`` through the NL-to-SQL engine."""
    # Integrate with SQL-based natural language query generator
    return "Executing database query..."


def execute_rag_query(query):
    """Placeholder: answer ``query`` via the RAG product-description pipeline."""
    # Integrate with RAG pipeline to retrieve product descriptions
    return "Executing RAG product description query..."
# Test the function with different queries.
# (Each call must be on its own line; collapsed onto the comment line they
# were silently commented out and never ran.)
print(handle_query(query_1))
print(handle_query(query_2))