"""Multi-model customer-support assistant.

Combines Hugging Face pipelines (topic classification, sentiment analysis,
text generation), a torchvision image classifier and a FAISS-backed FAQ
retriever behind a single ``CustomerSupportAssistant.process_query`` call.
"""

import faiss
import numpy as np
import torch
from PIL import Image
from torchvision import models, transforms
from transformers import AutoModel, AutoTokenizer, pipeline


class TextClassifier:
    """Label a piece of text with a Hugging Face text-classification pipeline."""

    def __init__(self, model_name='distilbert-base-uncased'):
        # NOTE(review): 'distilbert-base-uncased' looks like a base checkpoint
        # without a fine-tuned classification head, so labels are presumably
        # generic (e.g. LABEL_0/LABEL_1) — confirm, or pass a fine-tuned model.
        self.classifier = pipeline("text-classification", model=model_name)

    def classify(self, text):
        """Return the label of the top prediction for *text*."""
        return self.classifier(text)[0]['label']


class SentimentAnalyzer:
    """Score the sentiment of text with a Hugging Face sentiment pipeline."""

    def __init__(self, model_name='nlptown/bert-base-multilingual-uncased-sentiment'):
        self.analyzer = pipeline("sentiment-analysis", model=model_name)

    def analyze(self, text):
        """Return the top prediction for *text* (a dict with 'label' and 'score')."""
        return self.analyzer(text)[0]


class ImageRecognizer:
    """Classify an image file with a pretrained torchvision model."""

    def __init__(self, model_name='resnet50'):
        # Resolve the constructor by name so ``model_name`` is actually honoured
        # (the default 'resnet50' still maps to models.resnet50).
        # NOTE: newer torchvision deprecates ``pretrained=True`` in favour of
        # ``weights=...``; kept here for compatibility with older releases.
        self.model = getattr(models, model_name)(pretrained=True)
        self.model.eval()
        # Resize, centre-crop to 224x224, convert to a tensor and normalise with
        # the ImageNet mean/std used by torchvision's pretrained models.
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def recognize(self, image_path):
        """Return the predicted class index (an ``int``) for the image at *image_path*."""
        # The context manager closes the file handle. Converting to RGB ensures
        # greyscale / RGBA / palette images match the 3-channel Normalize above.
        with Image.open(image_path) as image:
            batch = self.transform(image.convert('RGB')).unsqueeze(0)
        with torch.no_grad():
            outputs = self.model(batch)
        _, predicted = torch.max(outputs, 1)
        return predicted.item()


class TextGenerator:
    """Continue a prompt with a Hugging Face text-generation pipeline."""

    def __init__(self, model_name='gpt2'):
        self.generator = pipeline("text-generation", model=model_name)

    def generate(self, prompt, max_new_tokens=100):
        """Return the generated text for *prompt*.

        Args:
            prompt: Text to continue.
            max_new_tokens: Upper bound on tokens produced beyond the prompt.
                ``max_length`` (used previously) counts the prompt as well, so
                a long prompt left little or no room for an actual reply.

        Returns:
            The pipeline's ``generated_text`` for the single returned sequence.
            NOTE(review): with the pipeline defaults this presumably includes
            the prompt itself as a prefix — confirm whether callers want that.
        """
        response = self.generator(prompt, max_new_tokens=max_new_tokens, num_return_sequences=1)
        return response[0]['generated_text']


class FAQRetriever:
    """Nearest-neighbour lookup of FAQ entries by mean-pooled sentence embedding."""

    def __init__(self, model_name='sentence-transformers/all-MiniLM-L6-v2'):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        # Take the width from the model config instead of hard-coding 384
        # (MiniLM-L6's size) so other encoders work without edits.
        dim = self.model.config.hidden_size
        self.index = faiss.IndexFlatL2(dim)
        # Normalised embeddings of every FAQ added so far, in index order.
        self.faq_embeddings = np.empty((0, dim), dtype=np.float32)

    def embed(self, text):
        """Return mean-pooled embeddings of *text* as a float32 ``(n, dim)`` array.

        *text* may be a single string (n == 1) or a list of strings.
        """
        inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True)
        with torch.no_grad():
            hidden = self.model(**inputs).last_hidden_state
        # Exclude padding tokens from the average. For a single string the mask
        # is all ones, so this equals a plain mean over the token axis.
        mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)
        embeddings = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        # FAISS requires C-contiguous float32 input.
        return np.ascontiguousarray(embeddings.cpu().numpy(), dtype=np.float32)

    def add_faqs(self, faqs):
        """Embed *faqs*, L2-normalise them and append them to the index.

        Index rows follow insertion order, so ``retrieve`` returns positions in
        the concatenation of every list passed to this method. An empty input
        is a no-op.
        """
        faqs = list(faqs)
        if not faqs:
            return  # np.concatenate([]) would raise ValueError
        new_embeddings = np.concatenate([self.embed(faq) for faq in faqs])
        # Unit-length vectors make L2 distance rank results like cosine similarity.
        faiss.normalize_L2(new_embeddings)
        self.index.add(new_embeddings)
        self.faq_embeddings = np.concatenate([self.faq_embeddings, new_embeddings])

    def retrieve(self, query, k=5):
        """Return indices of the (up to) *k* FAQs nearest to *query*, best first."""
        # Asking FAISS for more neighbours than it stores pads the result with
        # -1, which would silently index the last FAQ; clamp k instead.
        k = min(k, self.index.ntotal)
        if k <= 0:
            return np.empty(0, dtype=np.int64)
        query_embedding = self.embed(query)
        faiss.normalize_L2(query_embedding)
        _, indices = self.index.search(query_embedding, k)
        return indices[0]


class CustomerSupportAssistant:
    """Route a customer query through every model and generate a reply."""

    def __init__(self):
        self.text_classifier = TextClassifier()
        self.sentiment_analyzer = SentimentAnalyzer()
        self.image_recognizer = ImageRecognizer()
        self.text_generator = TextGenerator()
        self.faq_retriever = FAQRetriever()
        self.faqs = [
            "How to reset my password?",
            "What is the return policy?",
            "How to track my order?",
            "How to contact customer support?",
            "What payment methods are accepted?"
        ]
        self.faq_retriever.add_faqs(self.faqs)

    def process_query(self, text, image_path=None):
        """Return a generated support reply for *text* (and optional image).

        The prompt given to the generator combines the predicted topic, the
        sentiment label and score, the closest FAQ questions, and either the
        predicted image class index or "No image provided.".
        """
        topic = self.text_classifier.classify(text)
        sentiment = self.sentiment_analyzer.analyze(text)
        if image_path:
            image_info = self.image_recognizer.recognize(image_path)
        else:
            image_info = "No image provided."
        faq_ids = self.faq_retriever.retrieve(text)
        faq_responses = [self.faqs[i] for i in faq_ids]
        response_prompt = f"Topic: {topic}, Sentiment: {sentiment['label']} with confidence {sentiment['score']}. FAQs: {faq_responses}. Image info: {image_info}. Generate a response."
        return self.text_generator.generate(response_prompt)


def main():
    """Run one example query through the assistant and print the reply."""
    assistant = CustomerSupportAssistant()
    input_text = "I'm having trouble with my recent order."
    output = assistant.process_query(input_text)
    print(output)


# Guarded so importing this module does not download and run five models.
if __name__ == "__main__":
    main()