# shopify_1/rag_system.py
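"""RAG system over an apparel catalog: FAISS retrieval plus extractive QA.

Assumed dependencies, inferred from the imports below: pandas, transformers,
faiss-cpu, sentence-transformers, langchain, langchain-community, and
langchain-huggingface.
"""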
import os

import pandas as pd
from langchain.docstore.document import Document
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from transformers import pipeline
class RAGSystem:
    def __init__(self, csv_path="apparel.csv"):
        self.setup_system(csv_path)
        # Extractive QA: answers are spans copied out of the retrieved context.
        self.qa_pipeline = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")
    def setup_system(self, csv_path):
        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"CSV file not found at {csv_path}")

        # Read the CSV file; each product title becomes one Document.
        catalog = pd.read_csv(csv_path)
        docs = [
            Document(
                page_content=str(row["Title"]),  # str() guards against NaN/non-string titles
                metadata={"index": idx},
            )
            for idx, row in catalog.iterrows()
        ]

        # Split documents (titles are short, so most pass through unchanged)
        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        split_docs = text_splitter.split_documents(docs)

        # Create embeddings and build the FAISS vector store
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        self.vector_store = FAISS.from_documents(split_docs, embeddings)
        self.retriever = self.vector_store.as_retriever()
    def process_query(self, query):
        # Retrieve documents relevant to the query. invoke() is the current
        # retriever API; get_relevant_documents() is deprecated in recent
        # LangChain releases.
        retrieved_docs = self.retriever.invoke(query)

        # Concatenate the retrieved page_content, truncated to fit the QA
        # model's limited context window.
        retrieved_text = "\n".join(doc.page_content for doc in retrieved_docs)[:1000]

        # Process with the QA pipeline
        qa_input = {
            "question": query,
            "context": retrieved_text,
        }
        response = self.qa_pipeline(qa_input)
        return response["answer"]
    def get_similar_documents(self, query, k=5):
        """Retrieve similar documents without running the QA pipeline."""
        # Query the vector store directly so that k is honored; the default
        # retriever returns a fixed number of results regardless of k.
        docs = self.vector_store.similarity_search(query, k=k)
        return [{"content": doc.page_content, "metadata": doc.metadata} for doc in docs]
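
# Minimal usage sketch (an assumption, not part of the original module):
# expects an apparel.csv with a "Title" column next to this file; the
# embedding and QA models are downloaded on first run.
if __name__ == "__main__":
    rag = RAGSystem("apparel.csv")
    print(rag.process_query("Do you have any blue denim jackets?"))
    for doc in rag.get_similar_documents("denim jacket", k=3):
        print(doc["content"], "->", doc["metadata"])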