# question-generator/server/inference.py
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_google_genai import ChatGoogleGenerativeAI

from llm.gemini import questions_template, format_questions_instructions, questions_parser
from data.load_data import retriever


def get_questions(_dict):
    """Fill the prompt template with the retrieved context and the question,
    call Gemini, and parse the model output into structured questions."""
    question = _dict["question"]
    context = _dict["context"]
    messages = questions_template.format_messages(
        context=context,
        question=question,
        format_questions_instructions=format_questions_instructions,
    )
    chat = ChatGoogleGenerativeAI(model="gemini-pro")
    response = chat.invoke(messages)
    return questions_parser.parse(response.content)


def format_docs(docs):
    """Concatenate retrieved documents into a single context string."""
    return "\n\n".join(doc.page_content for doc in docs)


# RAG chain: retrieve relevant documents, format them into a context string,
# and pass the context and question through to the question generator.
rag_chain = {
    "context": retriever | RunnableLambda(format_docs),
    "question": RunnablePassthrough(),
} | RunnableLambda(get_questions)
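

# Minimal usage sketch: invoke the chain end to end with a topic string.
# Assumes GOOGLE_API_KEY is set in the environment and that data.load_data
# built `retriever` successfully; the input topic below is a hypothetical
# example, not one from this project.
if __name__ == "__main__":
    generated = rag_chain.invoke("What is photosynthesis?")
    print(generated)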