Spaces:
Sleeping
Sleeping
tony-42069
commited on
Commit
·
d16e9aa
1
Parent(s):
836ede6
Add source code and test files
Browse files- api/__init__.py +8 -0
- api/function_app.py +71 -0
- api/requirements.txt +7 -0
- src/__init__.py +1 -0
- src/pdf_processor.py +112 -0
- src/rag_engine.py +131 -0
- tests/__init__.py +1 -0
- tests/test_pdf_processor.py +73 -0
- tests/test_rag_engine.py +112 -0
api/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import azure.functions as func
|
2 |
+
import streamlit as st
|
3 |
+
|
4 |
+
def main(req: func.HttpRequest) -> func.HttpResponse:
|
5 |
+
return func.HttpResponse(
|
6 |
+
"This is the API endpoint for the CRE Knowledge Assistant",
|
7 |
+
status_code=200
|
8 |
+
)
|
api/function_app.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import azure.functions as func
|
2 |
+
import logging
|
3 |
+
import json
|
4 |
+
from io import BytesIO
|
5 |
+
|
6 |
+
# Add the project root to Python path
|
7 |
+
import sys
|
8 |
+
import os
|
9 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
10 |
+
|
11 |
+
from app.config import validate_config
|
12 |
+
from app.logging import setup_logging
|
13 |
+
from src.pdf_processor import PDFProcessor
|
14 |
+
from src.rag_engine import RAGEngine
|
15 |
+
|
16 |
+
# Initialize components
|
17 |
+
setup_logging()
|
18 |
+
logger = logging.getLogger('app')
|
19 |
+
pdf_processor = PDFProcessor()
|
20 |
+
rag_engine = RAGEngine()
|
21 |
+
|
22 |
+
def process_pdf(req: func.HttpRequest) -> func.HttpResponse:
|
23 |
+
try:
|
24 |
+
# Get the PDF file from the request
|
25 |
+
pdf_file = req.files['file']
|
26 |
+
pdf_bytes = pdf_file.read()
|
27 |
+
|
28 |
+
# Process the PDF
|
29 |
+
pdf_processor.process(BytesIO(pdf_bytes))
|
30 |
+
|
31 |
+
return func.HttpResponse(
|
32 |
+
json.dumps({"message": "PDF processed successfully"}),
|
33 |
+
mimetype="application/json",
|
34 |
+
status_code=200
|
35 |
+
)
|
36 |
+
except Exception as e:
|
37 |
+
logger.error(f"Error processing PDF: {str(e)}")
|
38 |
+
return func.HttpResponse(
|
39 |
+
json.dumps({"error": str(e)}),
|
40 |
+
mimetype="application/json",
|
41 |
+
status_code=500
|
42 |
+
)
|
43 |
+
|
44 |
+
def query(req: func.HttpRequest) -> func.HttpResponse:
|
45 |
+
try:
|
46 |
+
# Get the query from request body
|
47 |
+
req_body = req.get_json()
|
48 |
+
user_query = req_body.get('query')
|
49 |
+
|
50 |
+
if not user_query:
|
51 |
+
return func.HttpResponse(
|
52 |
+
json.dumps({"error": "No query provided"}),
|
53 |
+
mimetype="application/json",
|
54 |
+
status_code=400
|
55 |
+
)
|
56 |
+
|
57 |
+
# Process query through RAG engine
|
58 |
+
answer = rag_engine.process_query(user_query)
|
59 |
+
|
60 |
+
return func.HttpResponse(
|
61 |
+
json.dumps({"answer": answer}),
|
62 |
+
mimetype="application/json",
|
63 |
+
status_code=200
|
64 |
+
)
|
65 |
+
except Exception as e:
|
66 |
+
logger.error(f"Error processing query: {str(e)}")
|
67 |
+
return func.HttpResponse(
|
68 |
+
json.dumps({"error": str(e)}),
|
69 |
+
mimetype="application/json",
|
70 |
+
status_code=500
|
71 |
+
)
|
api/requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
azure-functions==1.15.0
|
2 |
+
openai==1.6.1
|
3 |
+
python-dotenv==1.0.0
|
4 |
+
azure-cognitiveservices-language-textanalytics==0.2.0
|
5 |
+
PyPDF2==3.0.1
|
6 |
+
langchain==0.0.352
|
7 |
+
azure-storage-blob==12.19.0
|
src/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
src/pdf_processor.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
PDF processing module for extracting and chunking text from PDF documents.
|
3 |
+
"""
|
4 |
+
import logging
|
5 |
+
from typing import List, Tuple
|
6 |
+
import PyPDF2
|
7 |
+
from io import BytesIO
|
8 |
+
|
9 |
+
from app.config import MAX_CHUNK_SIZE, OVERLAP_SIZE
|
10 |
+
|
11 |
+
logger = logging.getLogger('pdf')
|
12 |
+
|
13 |
+
class PDFProcessor:
|
14 |
+
"""Handles PDF document processing and text chunking."""
|
15 |
+
|
16 |
+
@staticmethod
|
17 |
+
def extract_text(pdf_file: BytesIO) -> str:
|
18 |
+
"""Extract text content from a PDF file."""
|
19 |
+
try:
|
20 |
+
pdf_reader = PyPDF2.PdfReader(pdf_file)
|
21 |
+
text = ""
|
22 |
+
|
23 |
+
for page in pdf_reader.pages:
|
24 |
+
text += page.extract_text() + "\n"
|
25 |
+
|
26 |
+
logger.info(f"Successfully extracted text from PDF ({len(text)} characters)")
|
27 |
+
return text
|
28 |
+
|
29 |
+
except Exception as e:
|
30 |
+
logger.error(f"Error extracting text from PDF: {str(e)}")
|
31 |
+
raise
|
32 |
+
|
33 |
+
@staticmethod
|
34 |
+
def create_chunks(text: str, chunk_size: int = MAX_CHUNK_SIZE,
|
35 |
+
overlap: int = OVERLAP_SIZE) -> List[Tuple[str, dict]]:
|
36 |
+
"""Split text into overlapping chunks with metadata."""
|
37 |
+
try:
|
38 |
+
chunks = []
|
39 |
+
start = 0
|
40 |
+
|
41 |
+
while start < len(text):
|
42 |
+
# Find the end of the chunk
|
43 |
+
end = start + chunk_size
|
44 |
+
|
45 |
+
# If we're not at the end of the text, try to find a good break point
|
46 |
+
if end < len(text):
|
47 |
+
# Try to find the last period or newline in the chunk
|
48 |
+
last_period = text.rfind('.', start, end)
|
49 |
+
last_newline = text.rfind('\n', start, end)
|
50 |
+
break_point = max(last_period, last_newline)
|
51 |
+
|
52 |
+
if break_point > start:
|
53 |
+
end = break_point + 1
|
54 |
+
|
55 |
+
# Create chunk with metadata
|
56 |
+
chunk_text = text[start:end].strip()
|
57 |
+
if chunk_text: # Only add non-empty chunks
|
58 |
+
metadata = {
|
59 |
+
"start_char": start,
|
60 |
+
"end_char": end,
|
61 |
+
"chunk_size": len(chunk_text)
|
62 |
+
}
|
63 |
+
chunks.append((chunk_text, metadata))
|
64 |
+
|
65 |
+
# Move the start position, accounting for overlap
|
66 |
+
start = end - overlap if end < len(text) else len(text)
|
67 |
+
|
68 |
+
logger.info(f"Created {len(chunks)} chunks from text")
|
69 |
+
return chunks
|
70 |
+
|
71 |
+
except Exception as e:
|
72 |
+
logger.error(f"Error creating chunks: {str(e)}")
|
73 |
+
raise
|
74 |
+
|
75 |
+
@staticmethod
|
76 |
+
def clean_text(text: str) -> str:
|
77 |
+
"""Clean and normalize extracted text."""
|
78 |
+
try:
|
79 |
+
# Remove extra whitespace
|
80 |
+
text = ' '.join(text.split())
|
81 |
+
|
82 |
+
# Remove special characters that might cause issues
|
83 |
+
text = text.replace('\x00', '')
|
84 |
+
|
85 |
+
# Normalize newlines
|
86 |
+
text = text.replace('\r\n', '\n')
|
87 |
+
|
88 |
+
logger.info("Text cleaned successfully")
|
89 |
+
return text
|
90 |
+
|
91 |
+
except Exception as e:
|
92 |
+
logger.error(f"Error cleaning text: {str(e)}")
|
93 |
+
raise
|
94 |
+
|
95 |
+
def process_pdf(self, pdf_file: BytesIO) -> List[Tuple[str, dict]]:
|
96 |
+
"""Process PDF file and return chunks with metadata."""
|
97 |
+
try:
|
98 |
+
# Extract text from PDF
|
99 |
+
raw_text = self.extract_text(pdf_file)
|
100 |
+
|
101 |
+
# Clean the extracted text
|
102 |
+
cleaned_text = self.clean_text(raw_text)
|
103 |
+
|
104 |
+
# Create chunks
|
105 |
+
chunks = self.create_chunks(cleaned_text)
|
106 |
+
|
107 |
+
logger.info(f"PDF processed successfully: {len(chunks)} chunks created")
|
108 |
+
return chunks
|
109 |
+
|
110 |
+
except Exception as e:
|
111 |
+
logger.error(f"Error processing PDF: {str(e)}")
|
112 |
+
raise
|
src/rag_engine.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
RAG (Retrieval Augmented Generation) engine for the CRE Chatbot.
|
3 |
+
"""
|
4 |
+
import logging
|
5 |
+
import os
|
6 |
+
from typing import List, Dict, Any, Optional
|
7 |
+
|
8 |
+
import chromadb
|
9 |
+
from chromadb.config import Settings
|
10 |
+
from openai import AzureOpenAI
|
11 |
+
from app.config import (
|
12 |
+
AZURE_OPENAI_ENDPOINT,
|
13 |
+
AZURE_OPENAI_API_KEY, # Added this line
|
14 |
+
TEMPERATURE,
|
15 |
+
MAX_TOKENS,
|
16 |
+
AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME
|
17 |
+
)
|
18 |
+
|
19 |
+
logger = logging.getLogger('rag')
|
20 |
+
|
21 |
+
class RAGEngine:
|
22 |
+
"""Handles document retrieval and question answering using Azure OpenAI."""
|
23 |
+
|
24 |
+
def __init__(self, deployment_name: str):
|
25 |
+
"""Initialize the RAG engine with Azure OpenAI client."""
|
26 |
+
self.client = AzureOpenAI(
|
27 |
+
api_key=AZURE_OPENAI_API_KEY,
|
28 |
+
api_version="2023-12-01-preview",
|
29 |
+
azure_endpoint=AZURE_OPENAI_ENDPOINT
|
30 |
+
)
|
31 |
+
self.deployment_name = deployment_name
|
32 |
+
self.embedding_deployment_name = AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME
|
33 |
+
|
34 |
+
# Initialize ChromaDB with simple in-memory settings
|
35 |
+
self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
|
36 |
+
self.collection = None
|
37 |
+
self.initialize_vector_store("cre_docs")
|
38 |
+
logger.info("RAG Engine initialized with Azure OpenAI")
|
39 |
+
|
40 |
+
def create_embeddings(self, texts: List[str]) -> List[List[float]]:
|
41 |
+
"""Create embeddings for the given texts using Azure OpenAI."""
|
42 |
+
try:
|
43 |
+
response = self.client.embeddings.create(
|
44 |
+
input=texts,
|
45 |
+
model=self.embedding_deployment_name
|
46 |
+
)
|
47 |
+
return [item.embedding for item in response.data]
|
48 |
+
except Exception as e:
|
49 |
+
logger.error(f"Error creating embeddings: {str(e)}")
|
50 |
+
raise
|
51 |
+
|
52 |
+
def initialize_vector_store(self, collection_name: str):
|
53 |
+
"""Initialize or get the vector store collection."""
|
54 |
+
try:
|
55 |
+
self.collection = self.chroma_client.get_or_create_collection(
|
56 |
+
name=collection_name,
|
57 |
+
metadata={"hnsw:space": "cosine"}
|
58 |
+
)
|
59 |
+
logger.info(f"Vector store initialized with collection: {collection_name}")
|
60 |
+
except Exception as e:
|
61 |
+
logger.error(f"Error initializing vector store: {str(e)}")
|
62 |
+
raise
|
63 |
+
|
64 |
+
def add_documents(self, texts: List[str], metadata: Optional[List[Dict[str, Any]]] = None):
|
65 |
+
"""Add documents to the vector store."""
|
66 |
+
try:
|
67 |
+
if not self.collection:
|
68 |
+
raise ValueError("Vector store collection not initialized")
|
69 |
+
|
70 |
+
embeddings = self.create_embeddings(texts)
|
71 |
+
# Use timestamp + index as ID to ensure uniqueness
|
72 |
+
import time
|
73 |
+
timestamp = int(time.time())
|
74 |
+
ids = [f"{timestamp}_{i}" for i in range(len(texts))]
|
75 |
+
|
76 |
+
self.collection.add(
|
77 |
+
embeddings=embeddings,
|
78 |
+
documents=texts,
|
79 |
+
ids=ids,
|
80 |
+
metadatas=metadata if metadata else [{}] * len(texts)
|
81 |
+
)
|
82 |
+
logger.info(f"Added {len(texts)} documents to vector store")
|
83 |
+
except Exception as e:
|
84 |
+
logger.error(f"Error adding documents: {str(e)}")
|
85 |
+
raise
|
86 |
+
|
87 |
+
def query(self, question: str, k: int = 3) -> Dict[str, Any]:
|
88 |
+
"""Query the vector store and generate an answer."""
|
89 |
+
try:
|
90 |
+
# Create embedding for the question
|
91 |
+
question_embedding = self.create_embeddings([question])[0]
|
92 |
+
|
93 |
+
# Query vector store
|
94 |
+
results = self.collection.query(
|
95 |
+
query_embeddings=[question_embedding],
|
96 |
+
n_results=k
|
97 |
+
)
|
98 |
+
|
99 |
+
# Prepare context from retrieved documents
|
100 |
+
context = "\n".join(results['documents'][0])
|
101 |
+
|
102 |
+
# Generate answer using Azure OpenAI
|
103 |
+
messages = [
|
104 |
+
{"role": "system", "content": "You are a helpful assistant that answers questions about commercial real estate concepts. Use the provided context to answer questions accurately and concisely."},
|
105 |
+
{"role": "user", "content": f"Context: {context}\n\nQuestion: {question}"}
|
106 |
+
]
|
107 |
+
|
108 |
+
response = self.client.chat.completions.create(
|
109 |
+
model=self.deployment_name,
|
110 |
+
messages=messages,
|
111 |
+
temperature=TEMPERATURE,
|
112 |
+
max_tokens=MAX_TOKENS
|
113 |
+
)
|
114 |
+
|
115 |
+
answer = response.choices[0].message.content
|
116 |
+
|
117 |
+
return {
|
118 |
+
"answer": answer,
|
119 |
+
"context": context,
|
120 |
+
"source_documents": results['documents'][0]
|
121 |
+
}
|
122 |
+
|
123 |
+
except Exception as e:
|
124 |
+
logger.error(f"Error querying RAG engine: {str(e)}")
|
125 |
+
raise
|
126 |
+
|
127 |
+
def clear(self):
|
128 |
+
"""Clear the vector store collection."""
|
129 |
+
if self.collection:
|
130 |
+
self.collection.delete()
|
131 |
+
logger.info("Vector store collection cleared")
|
tests/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
tests/test_pdf_processor.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Tests for the PDF processor module.
|
3 |
+
"""
|
4 |
+
import pytest
|
5 |
+
from io import BytesIO
|
6 |
+
from src.pdf_processor import PDFProcessor
|
7 |
+
|
8 |
+
def test_clean_text():
|
9 |
+
"""Test text cleaning functionality."""
|
10 |
+
processor = PDFProcessor()
|
11 |
+
|
12 |
+
# Test removing extra whitespace
|
13 |
+
text = "This has extra spaces"
|
14 |
+
assert processor.clean_text(text) == "This has extra spaces"
|
15 |
+
|
16 |
+
# Test normalizing newlines
|
17 |
+
text = "Line1\r\nLine2\r\nLine3"
|
18 |
+
assert processor.clean_text(text) == "Line1 Line2 Line3"
|
19 |
+
|
20 |
+
# Test removing null characters
|
21 |
+
text = "Text with\x00null\x00chars"
|
22 |
+
assert processor.clean_text(text) == "Text with null chars"
|
23 |
+
|
24 |
+
def test_create_chunks():
|
25 |
+
"""Test text chunking functionality."""
|
26 |
+
processor = PDFProcessor()
|
27 |
+
|
28 |
+
# Test basic chunking
|
29 |
+
text = "This is a test. This is another test. And a final test."
|
30 |
+
chunks = processor.create_chunks(text, chunk_size=20, overlap=5)
|
31 |
+
|
32 |
+
assert len(chunks) > 0
|
33 |
+
assert all(isinstance(chunk, tuple) for chunk in chunks)
|
34 |
+
assert all(len(chunk) == 2 for chunk in chunks) # (text, metadata)
|
35 |
+
assert all(isinstance(chunk[1], dict) for chunk in chunks) # metadata is dict
|
36 |
+
|
37 |
+
def test_chunk_metadata():
|
38 |
+
"""Test chunk metadata creation."""
|
39 |
+
processor = PDFProcessor()
|
40 |
+
|
41 |
+
text = "Short test text."
|
42 |
+
chunks = processor.create_chunks(text, chunk_size=20, overlap=5)
|
43 |
+
|
44 |
+
assert len(chunks) == 1
|
45 |
+
chunk_text, metadata = chunks[0]
|
46 |
+
|
47 |
+
assert "start_char" in metadata
|
48 |
+
assert "end_char" in metadata
|
49 |
+
assert "chunk_size" in metadata
|
50 |
+
assert metadata["chunk_size"] == len(chunk_text)
|
51 |
+
|
52 |
+
def test_empty_text():
|
53 |
+
"""Test handling of empty text."""
|
54 |
+
processor = PDFProcessor()
|
55 |
+
|
56 |
+
chunks = processor.create_chunks("")
|
57 |
+
assert len(chunks) == 0
|
58 |
+
|
59 |
+
def test_chunk_overlap():
|
60 |
+
"""Test chunk overlap functionality."""
|
61 |
+
processor = PDFProcessor()
|
62 |
+
|
63 |
+
text = "This is a long text that should be split into multiple chunks with overlap."
|
64 |
+
chunks = processor.create_chunks(text, chunk_size=20, overlap=5)
|
65 |
+
|
66 |
+
# Check that chunks overlap
|
67 |
+
if len(chunks) > 1:
|
68 |
+
for i in range(len(chunks) - 1):
|
69 |
+
current_chunk = chunks[i][0]
|
70 |
+
next_chunk = chunks[i + 1][0]
|
71 |
+
|
72 |
+
# There should be some overlap between consecutive chunks
|
73 |
+
assert any(word in next_chunk for word in current_chunk.split()[-3:])
|
tests/test_rag_engine.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Tests for the RAG engine module.
|
3 |
+
"""
|
4 |
+
import pytest
|
5 |
+
from unittest.mock import Mock, patch
|
6 |
+
from src.rag_engine import RAGEngine
|
7 |
+
|
8 |
+
@pytest.fixture
|
9 |
+
def mock_azure_client():
|
10 |
+
"""Create a mock Azure OpenAI client."""
|
11 |
+
with patch('openai.AzureOpenAI') as mock_client:
|
12 |
+
yield mock_client
|
13 |
+
|
14 |
+
@pytest.fixture
|
15 |
+
def mock_chroma_client():
|
16 |
+
"""Create a mock Chroma client."""
|
17 |
+
with patch('chromadb.Client') as mock_client:
|
18 |
+
yield mock_client
|
19 |
+
|
20 |
+
@pytest.fixture
|
21 |
+
def rag_engine(mock_azure_client, mock_chroma_client):
|
22 |
+
"""Create a RAG engine instance with mocked dependencies."""
|
23 |
+
return RAGEngine("test-deployment")
|
24 |
+
|
25 |
+
def test_create_embeddings(rag_engine, mock_azure_client):
|
26 |
+
"""Test embedding creation."""
|
27 |
+
# Setup mock response
|
28 |
+
mock_response = Mock()
|
29 |
+
mock_response.data = [
|
30 |
+
Mock(embedding=[0.1, 0.2, 0.3]),
|
31 |
+
Mock(embedding=[0.4, 0.5, 0.6])
|
32 |
+
]
|
33 |
+
rag_engine.client.embeddings.create.return_value = mock_response
|
34 |
+
|
35 |
+
# Test
|
36 |
+
texts = ["Text 1", "Text 2"]
|
37 |
+
embeddings = rag_engine.create_embeddings(texts)
|
38 |
+
|
39 |
+
# Verify
|
40 |
+
assert len(embeddings) == 2
|
41 |
+
assert all(isinstance(emb, list) for emb in embeddings)
|
42 |
+
assert len(embeddings[0]) == 3 # Embedding dimension
|
43 |
+
|
44 |
+
def test_initialize_vector_store(rag_engine):
|
45 |
+
"""Test vector store initialization."""
|
46 |
+
rag_engine.initialize_vector_store("test_collection")
|
47 |
+
|
48 |
+
# Verify the collection was created
|
49 |
+
assert rag_engine.collection is not None
|
50 |
+
|
51 |
+
def test_add_documents(rag_engine):
|
52 |
+
"""Test adding documents to vector store."""
|
53 |
+
# Setup
|
54 |
+
rag_engine.initialize_vector_store("test_collection")
|
55 |
+
texts = ["Document 1", "Document 2"]
|
56 |
+
metadata = [{"source": "test1"}, {"source": "test2"}]
|
57 |
+
|
58 |
+
# Create mock embeddings
|
59 |
+
with patch.object(rag_engine, 'create_embeddings') as mock_create_embeddings:
|
60 |
+
mock_create_embeddings.return_value = [[0.1, 0.2], [0.3, 0.4]]
|
61 |
+
|
62 |
+
# Test
|
63 |
+
rag_engine.add_documents(texts, metadata)
|
64 |
+
|
65 |
+
# Verify
|
66 |
+
mock_create_embeddings.assert_called_once_with(texts)
|
67 |
+
assert rag_engine.collection.add.called
|
68 |
+
|
69 |
+
def test_query(rag_engine):
|
70 |
+
"""Test querying the RAG engine."""
|
71 |
+
# Setup
|
72 |
+
rag_engine.initialize_vector_store("test_collection")
|
73 |
+
|
74 |
+
# Mock embeddings creation
|
75 |
+
with patch.object(rag_engine, 'create_embeddings') as mock_create_embeddings:
|
76 |
+
mock_create_embeddings.return_value = [[0.1, 0.2]]
|
77 |
+
|
78 |
+
# Mock vector store query
|
79 |
+
mock_results = {
|
80 |
+
'documents': [["Relevant document 1", "Relevant document 2"]],
|
81 |
+
'distances': [[0.1, 0.2]]
|
82 |
+
}
|
83 |
+
rag_engine.collection.query.return_value = mock_results
|
84 |
+
|
85 |
+
# Mock chat completion
|
86 |
+
mock_response = Mock()
|
87 |
+
mock_response.choices = [Mock(message=Mock(content="Test answer"))]
|
88 |
+
rag_engine.client.chat.completions.create.return_value = mock_response
|
89 |
+
|
90 |
+
# Test
|
91 |
+
result = rag_engine.query("Test question")
|
92 |
+
|
93 |
+
# Verify
|
94 |
+
assert isinstance(result, dict)
|
95 |
+
assert "answer" in result
|
96 |
+
assert "context" in result
|
97 |
+
assert "source_documents" in result
|
98 |
+
assert result["answer"] == "Test answer"
|
99 |
+
|
100 |
+
def test_error_handling(rag_engine):
|
101 |
+
"""Test error handling in RAG engine."""
|
102 |
+
# Test error in embeddings creation
|
103 |
+
rag_engine.client.embeddings.create.side_effect = Exception("API Error")
|
104 |
+
|
105 |
+
with pytest.raises(Exception):
|
106 |
+
rag_engine.create_embeddings(["Test"])
|
107 |
+
|
108 |
+
# Test error in vector store initialization
|
109 |
+
rag_engine.chroma_client.get_or_create_collection.side_effect = Exception("DB Error")
|
110 |
+
|
111 |
+
with pytest.raises(Exception):
|
112 |
+
rag_engine.initialize_vector_store("test")
|