# hello-sajal / src/graph.py
# sajal2692's picture
# commit: "add hfspace files" (274be20)
"""Implements the graph to handle workflows for the Sajal assistant"""
from typing import Any, Dict, TypedDict

from langgraph.graph import END, StateGraph

from chains.document_grader import DocumentGrader
from chains.intent_detection import IntentDetection
from chains.qa_all_data import QAAllData
from chains.rag import RAG
from chains.rephrase_question import RephraseQuestion
from chains.smalltalk import Smalltalk
from retriever import Retriever
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        keys: A dictionary where each key is a string.
    """

    # Fix: the original annotated values as the builtin function `any`;
    # the correct annotation for "any value" is typing.Any.
    keys: Dict[str, Any]
class AssistantGraph:
    """Implements the graph to handle workflows for the Sajal assistant.

    Wires intent detection, small talk, question rephrasing, retrieval,
    document grading, and answer generation into a LangGraph ``StateGraph``
    whose state is a single ``keys`` dict (see ``GraphState``).
    """

    def __init__(self, llm, vector_db_path, source_data_path):
        """Build all chains, the retriever, and the compiled graph.

        Args:
            llm: Language model shared by the chains.
            vector_db_path: Path to the vector database backing the retriever.
            source_data_path: Path to the raw source data used when answering
                with all data instead of retrieved documents.
        """
        self.intent_detector = IntentDetection(llm)
        self.smalltalk = Smalltalk(llm)
        self.document_grader = DocumentGrader()
        self.rephrase_question_chain = RephraseQuestion(llm)
        self.retriever = Retriever(vector_db_path=vector_db_path)
        self.qa_all_data = QAAllData(llm=llm, source_data_path=source_data_path)
        self.rag = RAG(llm)
        self.app = self.compile_graph()

    def run(self, inputs):
        """Invoke the compiled graph with ``inputs`` and return the final state."""
        return self.app.invoke(inputs)

    # define graph nodes and edges and compile graph
    def compile_graph(self):
        """Assemble the workflow graph and return the compiled app.

        Flow: detect_intent → (chat | rephrase_question → retrieve →
        grade_documents → one of the two answer-generation nodes) → END.

        Returns:
            The compiled LangGraph application.
        """
        workflow = StateGraph(GraphState)

        ### define the nodes
        workflow.add_node("detect_intent", self.detect_intent)
        workflow.add_node("chat", self.chat)
        workflow.add_node("rephrase_question", self.rephrase_question)
        workflow.add_node("retrieve", self.retrieve)
        workflow.add_node("grade_documents", self.grade_documents)
        workflow.add_node("generate_answer_with_retrieved_documents", self.generate_answer_with_retrieved_documents)
        workflow.add_node("generate_answer_using_all_data", self.generate_answer_using_all_data)

        ### build the graph
        workflow.set_entry_point("detect_intent")
        workflow.add_conditional_edges(
            "detect_intent",
            self.decide_to_rag,
            {
                "rag": "rephrase_question",
                "chat": "chat",
            }
        )
        workflow.add_edge("rephrase_question", "retrieve")
        workflow.add_edge("retrieve", "grade_documents")
        workflow.add_conditional_edges(
            "grade_documents",
            self.decide_to_use_all_data,
            {
                "rag": "generate_answer_with_retrieved_documents",
                "generate_answer_using_all_data": "generate_answer_using_all_data",
            }
        )
        workflow.add_edge("generate_answer_with_retrieved_documents", END)
        workflow.add_edge("generate_answer_using_all_data", END)
        workflow.add_edge("chat", END)

        ### compile the graph
        app = workflow.compile()
        return app

    # define the nodes
    def detect_intent(self, state):
        """
        Detects the intent of a user's message.

        Args:
            state (dict): The current graph state

        Returns:
            dict: New key added to state, intent, that contains the detected intent
        """
        state = state["keys"]
        message = state["message"]
        history = state["history"]
        intent = self.intent_detector.run(message=message, history=history)
        return {"keys": {"message": message, "intent": intent, "history": history}}

    def chat(self, state):
        """
        Chat with the user (small-talk path).

        Args:
            state (dict): The current graph state

        Returns:
            dict: Updated graph state after adding the response
        """
        state = state["keys"]
        # Fix: renamed local `input` → `message`; the original shadowed the builtin.
        message = state["message"]
        history = state["history"]
        response = self.smalltalk.run(message=message, history=history)
        return {"keys": {"message": message, "history": history, "response": response}}

    def grade_documents(self, state):
        """
        Determines whether the retrieved documents are relevant to the question.

        Args:
            state (dict): The current graph state

        Returns:
            dict: Updates documents key with only the relevant documents and
                sets run_with_all_data when none were relevant
        """
        print("---CHECK RELEVANCE---")
        state = state["keys"]
        question = state["standalone_question"]
        documents = state["documents"]

        # Score each document; keep only the ones the grader marks relevant.
        filtered_docs = []
        all_data = False  # Default do not opt to use all data for generation
        for d in documents:
            score = self.document_grader.run(question=question, context=d.page_content)
            grade = score[0].binary_score
            if grade == "yes":
                print("---GRADE: FOUND RELEVANT DOCUMENT---")
                filtered_docs.append(d)
        if not filtered_docs:
            all_data = True  # Opt to use all data for generation
        # Fix: preserve message/history like every other node instead of
        # silently dropping them from the state.
        return {
            "keys": {
                "message": state["message"],
                "history": state["history"],
                "documents": filtered_docs,
                "standalone_question": question,
                "run_with_all_data": all_data,
            }
        }

    def rephrase_question(self, state):
        """
        Rephrase the question to be a standalone question.

        Args:
            state (dict): The current graph state

        Returns:
            dict: Updated graph state after adding the standalone question
        """
        state = state["keys"]
        question = state["message"]
        chat_history = state["history"]
        result = self.rephrase_question_chain.run(message=question, history=chat_history)
        return {"keys": {"message": question, "history": chat_history, "standalone_question": result}}

    def retrieve(self, state):
        """
        Retrieve documents for the standalone question.

        Args:
            state (dict): The current graph state

        Returns:
            dict: New key added to state, documents, that contains retrieved documents
        """
        state = state["keys"]
        question = state["standalone_question"]
        chat_history = state["history"]
        documents = self.retriever.run(query=question)
        return {"keys": {"message": state["message"], "history": chat_history, "standalone_question": question, "documents": documents}}

    def generate_answer_using_all_data(self, state):
        """
        Generate an answer using all source data (fallback when no retrieved
        document was relevant).

        Args:
            state (dict): The current graph state

        Returns:
            dict: Updated graph state after adding the response
        """
        state = state["keys"]
        question = state["standalone_question"]
        response = self.qa_all_data.run(question=question)
        # NOTE(review): "message" is set to the standalone question here, not the
        # original user message — kept as-is for backward compatibility.
        return {"keys": {"message": question, "response": response}}

    def generate_answer_with_retrieved_documents(self, state):
        """
        Generate an answer using the retrieved documents.

        Args:
            state (dict): The current graph state

        Returns:
            dict: Updated graph state after adding the response
        """
        state = state["keys"]
        question = state["standalone_question"]
        documents = state["documents"]
        response = self.rag.run(question=question, documents=documents)
        # NOTE(review): "message" is set to the standalone question here, not the
        # original user message — kept as-is for backward compatibility.
        return {"keys": {"message": question, "response": response}}

    # define the edges
    def decide_to_rag(self, state):
        """
        Decides whether to use RAG or plain chat.

        Args:
            state (dict): The current graph state

        Returns:
            str: Next node to call ("rag" or "chat")
        """
        state = state["keys"]
        intent = state["intent"]
        if intent == "sajal_question":
            return "rag"
        return "chat"

    def decide_to_use_all_data(self, state):
        """
        Determines whether to use all data for generation or not.

        Args:
            state (dict): The current state of the agent, including all keys.

        Returns:
            str: Next node to call
        """
        state = state["keys"]
        run_with_all_data = state["run_with_all_data"]
        if run_with_all_data:
            return "generate_answer_using_all_data"
        else:
            return "rag"