File size: 4,134 Bytes
fbfbbd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
from typing import List, Dict
from dotenv import load_dotenv
import chromadb
from langchain.embeddings import AzureOpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.chat_models import AzureChatOpenAI
from langchain.chains import RetrievalQA
import time

# Load environment variables
load_dotenv()

class RAGEngine:
    """Retrieval-augmented generation over PDF-style chunks using Azure OpenAI + Chroma.

    Lifecycle: construct (validates env vars, probes the embeddings endpoint),
    then call initialize_vector_store() with chunks, then answer_question().
    """

    def __init__(self):
        """Validate Azure OpenAI configuration and verify connectivity.

        Raises:
            ValueError: If any required environment variable is unset/empty.
            ConnectionError: If the embeddings endpoint cannot be reached
                after all retry attempts.
        """
        # Verify Azure OpenAI settings are set
        required_vars = [
            'AZURE_OPENAI_ENDPOINT',
            'AZURE_OPENAI_KEY',
            'AZURE_OPENAI_DEPLOYMENT_NAME',
            'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME'
        ]

        missing_vars = [var for var in required_vars if not os.getenv(var)]
        if missing_vars:
            raise ValueError(f"Missing required Azure OpenAI settings: {', '.join(missing_vars)}")

        # Hoisted out of the retry loop: these must exist even if every
        # attempt fails partway, and they don't depend on the connection.
        self.vector_store = None
        self.qa_chain = None

        # Initialize with retry mechanism (transient network/auth hiccups)
        max_retries = 3
        for attempt in range(max_retries):
            try:
                self.embeddings = AzureOpenAIEmbeddings(
                    azure_endpoint=os.getenv('AZURE_OPENAI_ENDPOINT'),
                    azure_deployment=os.getenv('AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME'),
                    api_key=os.getenv('AZURE_OPENAI_KEY')
                )
                # Cheap round-trip so bad credentials/endpoint fail fast here
                # instead of later inside the QA chain.
                self.embeddings.embed_query("test")
                break
            except Exception as e:
                if attempt == max_retries - 1:
                    raise ConnectionError(f"Failed to connect to Azure OpenAI API after {max_retries} attempts. Error: {str(e)}")
                time.sleep(2)  # Wait before retrying

    def initialize_vector_store(self, chunks: List[Dict]):
        """
        Initialize the vector store with document chunks.

        Args:
            chunks (List[Dict]): List of dictionaries containing text and metadata.
                Each dict must have a 'text' key (str) and a 'metadata' key (dict).
        """
        texts = [chunk['text'] for chunk in chunks]
        metadatas = [chunk['metadata'] for chunk in chunks]

        # Create vector store (in-memory Chroma collection)
        self.vector_store = Chroma.from_texts(
            texts=texts,
            embedding=self.embeddings,
            metadatas=metadatas
        )

        # Initialize QA chain.
        # BUG FIX: the original passed azure_deployment_name=..., which is not
        # an AzureChatOpenAI keyword; the correct keyword is azure_deployment —
        # the same one the AzureOpenAIEmbeddings call in __init__ uses.
        llm = AzureChatOpenAI(
            temperature=0,
            model_name="gpt-3.5-turbo",
            azure_deployment=os.getenv('AZURE_OPENAI_DEPLOYMENT_NAME'),
            azure_endpoint=os.getenv('AZURE_OPENAI_ENDPOINT'),
            api_key=os.getenv('AZURE_OPENAI_KEY')
        )
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=self.vector_store.as_retriever(
                search_kwargs={"k": 3}
            )
        )

    def answer_question(self, question: str) -> Dict:
        """
        Answer a question using the RAG system.

        Args:
            question (str): User's question

        Returns:
            Dict: {'answer': str, 'sources': list of {'page', 'text'} dicts}

        Raises:
            ValueError: If initialize_vector_store() has not been called yet.
        """
        if not self.qa_chain:
            raise ValueError("Vector store not initialized. Please process documents first.")

        # Create a prompt that emphasizes definition extraction
        prompt = f"""
        Question: {question}
        Please provide a clear and concise answer based on the provided context. 
        If the question asks for a definition or explanation of a concept, 
        make sure to provide that specifically. Include relevant examples or 
        additional context only if they help clarify the concept.
        """

        # Get answer from QA chain
        result = self.qa_chain({"query": prompt})

        # Get source documents. NOTE(review): this re-queries with the raw
        # question (k=2) while the chain retrieved with the wrapped prompt
        # (k=3), so cited sources may differ from the chain's actual context.
        # ROBUSTNESS FIX: .get('page') instead of ['page'] — a chunk whose
        # metadata lacks 'page' no longer raises KeyError here.
        source_docs = self.vector_store.similarity_search(question, k=2)
        sources = [
            {
                'page': doc.metadata.get('page'),
                'text': doc.page_content[:200] + "..."  # Preview of source text
            }
            for doc in source_docs
        ]

        return {
            'answer': result['result'],
            'sources': sources
        }