# https://python.langchain.com/docs/modules/chains/how_to/custom_chain
# Including reformulation of the question in the chain
import json

from langchain import PromptTemplate, LLMChain
from langchain.chains import QAWithSourcesChain
from langchain.chains import TransformChain, SequentialChain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain

from anyqa.prompts import answer_prompt, reformulation_prompt
from anyqa.custom_retrieval_chain import CustomRetrievalQAWithSourcesChain


def load_qa_chain_with_docs(llm):
    """Build a QA-with-sources chain that answers from caller-supplied documents.

    Use this when the documents have already been retrieved: the chain reads
    them from the ``docs`` input key and is created with
    ``return_source_documents=True``.

    Expected call shape:

    ```
    output = chain({
        "question": query,
        "audience": "experts scientists",
        "docs": docs,
        "language": "English",
    })
    ```
    """
    return QAWithSourcesChain(
        combine_documents_chain=load_combine_documents_chain(llm),
        input_docs_key="docs",
        return_source_documents=True,
    )


def load_combine_documents_chain(llm):
    """Return a "stuff"-type QA-with-sources chain driven by ``answer_prompt``.

    The prompt is declared with the variables ``summaries``, ``question``,
    ``audience`` and ``language``.
    """
    answer_template = PromptTemplate(
        input_variables=["summaries", "question", "audience", "language"],
        template=answer_prompt,
    )
    return load_qa_with_sources_chain(
        llm,
        chain_type="stuff",
        prompt=answer_template,
    )


def load_qa_chain_with_text(llm):
    """Return a plain ``LLMChain`` that fills ``answer_prompt`` and runs it on ``llm``.

    The prompt variables are ``question``, ``audience``, ``language`` and
    ``summaries``; all of them must be supplied by the caller as chain inputs.
    """
    template_variables = ["question", "audience", "language", "summaries"]
    answer_template = PromptTemplate(
        template=answer_prompt,
        input_variables=template_variables,
    )
    return LLMChain(prompt=answer_template, llm=llm)


def load_qa_chain(retriever, llm_reformulation, llm_answer):
    """Assemble the full pipeline: question reformulation, then retrieval QA.

    Args:
        retriever: Retriever handed to the answering step.
        llm_reformulation: Model used by the reformulation step.
        llm_answer: Model used by the answering step.

    Returns:
        A verbose ``SequentialChain`` taking ``query`` and ``audience`` and
        exposing ``answer``, ``question``, ``language`` and
        ``source_documents``.
    """
    # Order matters: the reformulation step runs before the answering step.
    steps = [
        load_reformulation_chain(llm_reformulation),
        load_qa_chain_with_retriever(retriever, llm_answer),
    ]
    return SequentialChain(
        chains=steps,
        input_variables=["query", "audience"],
        output_variables=["answer", "question", "language", "source_documents"],
        return_all=True,
        verbose=True,
    )


def load_reformulation_chain(llm):
    """Build the chain that reformulates the user query and detects its language.

    The chain runs ``reformulation_prompt`` on ``llm`` (raw reply stored under
    the ``json`` key), then parses that reply into two outputs.

    Args:
        llm: Model used to run the reformulation prompt.

    Returns:
        A ``SequentialChain`` taking ``query`` and producing ``question`` and
        ``language``. If the model reply is not a JSON object, or a key is
        missing, ``question`` falls back to the original query and
        ``language`` to ``"English"``.
    """
    prompt = PromptTemplate(
        template=reformulation_prompt,
        input_variables=["query"],
    )
    llm_chain = LLMChain(llm=llm, prompt=prompt, output_key="json")

    def parse_output(output):
        """Turn the raw model reply into the ``question``/``language`` pair."""
        query = output["query"]
        # Debug trace of the raw inputs, kept from the original implementation.
        print("output", output)
        try:
            parsed = json.loads(output["json"])
        except (TypeError, ValueError):
            # json.JSONDecodeError is a ValueError: the model did not return
            # valid JSON, so fall back to the defaults below instead of
            # crashing the whole pipeline.
            parsed = None
        if not isinstance(parsed, dict):
            # Valid JSON that is not an object (list, string, number, null)
            # carries no usable keys either.
            parsed = {}
        return {
            "question": parsed.get("question", query),
            "language": parsed.get("language", "English"),
        }

    transform_chain = TransformChain(
        # "query" is declared because parse_output reads it for the fallback.
        input_variables=["json", "query"],
        output_variables=["question", "language"],
        transform=parse_output,
    )

    return SequentialChain(
        chains=[llm_chain, transform_chain],
        input_variables=["query"],
        output_variables=["question", "language"],
    )


def load_qa_chain_with_retriever(retriever, llm):
    """Build the retrieval-backed QA-with-sources chain.

    Wraps the combine-documents chain for ``llm`` in a verbose
    ``CustomRetrievalQAWithSourcesChain`` created with
    ``return_source_documents=True`` and a fixed ``fallback_answer`` message
    (presumably used when nothing relevant is retrieved; confirm in
    ``CustomRetrievalQAWithSourcesChain``).
    """
    # This could be improved by providing a document prompt to avoid modifying page_content in the docs
    # See here https://github.com/langchain-ai/langchain/issues/3523
    no_result_message = "**⚠️ No relevant passages found in the sources, you may want to ask a more specific question.**"
    return CustomRetrievalQAWithSourcesChain(
        combine_documents_chain=load_combine_documents_chain(llm),
        retriever=retriever,
        fallback_answer=no_result_message,
        return_source_documents=True,
        verbose=True,
    )