"""Implements the graph to handle workflows for the Sajal assistant""" | |
from typing import Dict, TypedDict | |
from chains.intent_detection import IntentDetection | |
from chains.smalltalk import Smalltalk | |
from chains.document_grader import DocumentGrader | |
from chains.rephrase_question import RephraseQuestion | |
from chains.qa_all_data import QAAllData | |
from chains.rag import RAG | |
from retriever import Retriever | |
from langgraph.graph import END, StateGraph | |


class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        keys: A dictionary where each key is a string.
    """

    keys: Dict[str, Any]
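
# Keys stored in state["keys"] by the nodes below (collected from this module,
# not an exhaustive contract): "message", "history", "intent",
# "standalone_question", "documents", "run_with_all_data", and "response".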


class AssistantGraph:
    """Implements the graph to handle workflows for the Sajal assistant"""

    def __init__(self, llm, vector_db_path, source_data_path):
        self.intent_detector = IntentDetection(llm)
        self.smalltalk = Smalltalk(llm)
        self.document_grader = DocumentGrader()
        self.rephrase_question_chain = RephraseQuestion(llm)
        self.retriever = Retriever(vector_db_path=vector_db_path)
        self.qa_all_data = QAAllData(llm=llm, source_data_path=source_data_path)
        self.rag = RAG(llm)
        self.app = self.compile_graph()

    def run(self, inputs):
        return self.app.invoke(inputs)

    # define graph nodes and edges and compile graph
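    # Routing summary (derived from the edges defined below):
    #   detect_intent -> chat -> END                          (smalltalk)
    #   detect_intent -> rephrase_question -> retrieve -> grade_documents
    #     -> generate_answer_with_retrieved_documents -> END  (relevant docs found)
    #     -> generate_answer_using_all_data -> END            (no relevant docs)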
    def compile_graph(self):
        workflow = StateGraph(GraphState)

        ### define the nodes
        workflow.add_node("detect_intent", self.detect_intent)
        workflow.add_node("chat", self.chat)
        workflow.add_node("rephrase_question", self.rephrase_question)
        workflow.add_node("retrieve", self.retrieve)
        workflow.add_node("grade_documents", self.grade_documents)
        workflow.add_node("generate_answer_with_retrieved_documents", self.generate_answer_with_retrieved_documents)
        workflow.add_node("generate_answer_using_all_data", self.generate_answer_using_all_data)

        ### build the graph
        workflow.set_entry_point("detect_intent")
        workflow.add_conditional_edges(
            "detect_intent",
            self.decide_to_rag,
            {
                "rag": "rephrase_question",
                "chat": "chat",
            },
        )
        workflow.add_edge("rephrase_question", "retrieve")
        workflow.add_edge("retrieve", "grade_documents")
        workflow.add_conditional_edges(
            "grade_documents",
            self.decide_to_use_all_data,
            {
                "rag": "generate_answer_with_retrieved_documents",
                "generate_answer_using_all_data": "generate_answer_using_all_data",
            },
        )
        workflow.add_edge("generate_answer_with_retrieved_documents", END)
        workflow.add_edge("generate_answer_using_all_data", END)
        workflow.add_edge("chat", END)

        ### compile the graph
        app = workflow.compile()
        return app
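
    # Note: recent langgraph releases can render the compiled topology, e.g.
    # app.get_graph().draw_mermaid(); treat that call as an assumption about
    # the installed langgraph version rather than a guaranteed API.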

    # define the nodes
    def detect_intent(self, state):
        """
        Detects the intent of the user's message.

        Args:
            state (dict): The current graph state

        Returns:
            dict: Updated graph state with a new key, intent, that contains the detected intent
        """
        state = state["keys"]
        message = state["message"]
        history = state["history"]
        intent = self.intent_detector.run(message=message, history=history)
        return {"keys": {"message": message, "intent": intent, "history": history}}

    def chat(self, state):
        """
        Chat with the user.

        Args:
            state (dict): The current graph state

        Returns:
            dict: Updated graph state after adding the response
        """
        state = state["keys"]
        message = state["message"]
        history = state["history"]
        response = self.smalltalk.run(message=message, history=history)
        return {"keys": {"message": message, "history": history, "response": response}}

    def grade_documents(self, state):
        """
        Determines whether the retrieved documents are relevant to the question.

        Args:
            state (dict): The current graph state

        Returns:
            dict: Updated graph state with the documents key filtered down to relevant
                documents, and a run_with_all_data flag set when none survive the filter
        """
        print("---CHECK RELEVANCE---")
        state = state["keys"]
        question = state["standalone_question"]
        documents = state["documents"]

        # Score each document and keep only those graded as relevant
        filtered_docs = []
        all_data = False  # Default: do not opt to use all data for generation
        for d in documents:
            score = self.document_grader.run(question=question, context=d.page_content)
            grade = score[0].binary_score
            if grade == "yes":
                print("---GRADE: FOUND RELEVANT DOCUMENT---")
                filtered_docs.append(d)
        if not filtered_docs:
            all_data = True  # No relevant documents: opt to use all data for generation
        return {
            "keys": {
                "documents": filtered_docs,
                "standalone_question": question,
                "run_with_all_data": all_data,
            }
        }

    def rephrase_question(self, state):
        """
        Rephrase the question so it can stand alone, independent of the chat history.

        Args:
            state (dict): The current graph state

        Returns:
            dict: Updated graph state after adding the standalone question
        """
        state = state["keys"]
        question = state["message"]
        chat_history = state["history"]
        result = self.rephrase_question_chain.run(message=question, history=chat_history)
        return {"keys": {"message": question, "history": chat_history, "standalone_question": result}}

    def retrieve(self, state):
        """
        Retrieve documents relevant to the standalone question.

        Args:
            state (dict): The current graph state

        Returns:
            dict: Updated graph state with a new key, documents, that contains the retrieved documents
        """
        state = state["keys"]
        question = state["standalone_question"]
        chat_history = state["history"]
        documents = self.retriever.run(query=question)
        return {
            "keys": {
                "message": state["message"],
                "history": chat_history,
                "standalone_question": question,
                "documents": documents,
            }
        }

    def generate_answer_using_all_data(self, state):
        """
        Generate an answer using all of the source data.

        Args:
            state (dict): The current graph state

        Returns:
            dict: Updated graph state after adding the response
        """
        state = state["keys"]
        question = state["standalone_question"]
        response = self.qa_all_data.run(question=question)
        return {"keys": {"message": question, "response": response}}

    def generate_answer_with_retrieved_documents(self, state):
        """
        Generate an answer using the retrieved documents.

        Args:
            state (dict): The current graph state

        Returns:
            dict: Updated graph state after adding the response
        """
        state = state["keys"]
        question = state["standalone_question"]
        documents = state["documents"]
        response = self.rag.run(question=question, documents=documents)
        return {"keys": {"message": question, "response": response}}

    # define the edges
    def decide_to_rag(self, state):
        """
        Decides whether to route the message through the RAG pipeline or plain chat.

        Args:
            state (dict): The current graph state

        Returns:
            str: Next node to call
        """
        state = state["keys"]
        intent = state["intent"]
        if intent == "sajal_question":
            return "rag"
        return "chat"

    def decide_to_use_all_data(self, state):
        """
        Determines whether to use all data for generation or not.

        Args:
            state (dict): The current state of the agent, including all keys.

        Returns:
            str: Next node to call
        """
        state = state["keys"]
        run_with_all_data = state["run_with_all_data"]
        if run_with_all_data:
            return "generate_answer_using_all_data"
        return "rag"