Spaces:
Runtime error
Runtime error
import os | |
import pandas as pd | |
from langchain_huggingface import HuggingFaceEmbeddings | |
from langchain_community.vectorstores import FAISS | |
from langchain.text_splitter import CharacterTextSplitter | |
from langchain.docstore.document import Document | |
from transformers import pipeline | |
from langchain.prompts import PromptTemplate | |
class RAGSystem:
    """Retrieval-augmented question answering over the ``Title`` column of a CSV.

    Titles are embedded with a sentence-transformers model and stored in a
    FAISS index. A query retrieves the most similar titles, and an extractive
    question-answering pipeline picks an answer span out of them.
    """

    def __init__(self, csv_path="apparel.csv"):
        """Build the vector index from *csv_path* and load the QA model.

        Raises:
            FileNotFoundError: if *csv_path* does not exist.
            KeyError: if the CSV has no ``Title`` column.
            ValueError: if the CSV yields no indexable titles.
        """
        self.setup_system(csv_path)
        self.qa_pipeline = pipeline(
            "question-answering", model="distilbert-base-cased-distilled-squad"
        )

    def setup_system(self, csv_path):
        """Index the non-missing ``Title`` values of *csv_path* in FAISS.

        Sets ``self.vector_store`` and ``self.retriever``. Each document keeps
        its DataFrame index label under ``metadata['index']``.
        """
        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"CSV file not found at {csv_path}")

        frame = pd.read_csv(csv_path)

        # Drop missing titles: str(NaN) would otherwise index the literal
        # text "nan" as a document and pollute retrieval results.
        docs = [
            Document(page_content=str(title), metadata={'index': idx})
            for idx, title in frame['Title'].dropna().items()
        ]

        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        split_docs = text_splitter.split_documents(docs)
        if not split_docs:
            # FAISS.from_documents fails with an opaque IndexError on an
            # empty list; report the actual problem instead.
            raise ValueError(f"No non-empty 'Title' values found in {csv_path}")

        embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
        self.vector_store = FAISS.from_documents(split_docs, embeddings)
        self.retriever = self.vector_store.as_retriever()

    def process_query(self, query, max_context_chars=1000):
        """Answer *query* using the titles retrieved for it as context.

        The retrieved titles are joined with newlines and truncated to
        *max_context_chars* characters before being passed to the QA
        pipeline.

        Returns:
            The ``'answer'`` field of the QA pipeline's response.
        """
        # ``invoke`` is the supported retriever API; ``get_relevant_documents``
        # is deprecated in langchain-core and scheduled for removal.
        retrieved_docs = self.retriever.invoke(query)
        context = "\n".join(doc.page_content for doc in retrieved_docs)
        qa_input = {
            "question": query,
            "context": context[:max_context_chars],
        }
        response = self.qa_pipeline(qa_input)
        return response['answer']

    def get_similar_documents(self, query, k=5):
        """Return up to *k* documents similar to *query*, without running QA.

        Each result is a dict ``{'content': ..., 'metadata': ...}``.
        Returns an empty list when *k* is not positive.
        """
        if k <= 0:
            return []
        # Query the store directly with k: the default retriever only fetches
        # 4 results, so slicing its output silently ignored any k > 4.
        docs = self.vector_store.similarity_search(query, k=k)
        return [{'content': doc.page_content, 'metadata': doc.metadata} for doc in docs]