Spaces:
Paused
Paused
#####################################################
### DOCUMENT PROCESSOR [ENGINE]
#####################################################
# Jonathan Wang
# ABOUT:
# This project creates an app to chat with PDFs.
# This is the ENGINE
# which defines how LLMs handle processing.
#####################################################
## TODO Board:
#####################################################
## IMPORTS
from __future__ import annotations | |
import gc | |
from typing import TYPE_CHECKING, Callable, List, Optional, cast | |
from llama_index.core.query_engine import CustomQueryEngine | |
from llama_index.core.schema import NodeWithScore, QueryBundle | |
from llama_index.core.settings import ( | |
Settings, | |
) | |
from torch.cuda import empty_cache | |
if TYPE_CHECKING: | |
from llama_index.core.base.response.schema import Response | |
from llama_index.core.callbacks import CallbackManager | |
from llama_index.core.postprocessor.types import BaseNodePostprocessor | |
from llama_index.core.response_synthesizers import ( | |
BaseSynthesizer, | |
) | |
from llama_index.core.retrievers import BaseRetriever | |
# Own Modules | |
##################################################### | |
## CODE | |
class RAGQueryEngine(CustomQueryEngine):
    """Custom RAG query engine: retrieve nodes, postprocess them, synthesize.

    The engine is configured through three class-level fields that are
    supplied as keyword arguments at construction time.
    """

    # NOTE(review): BaseRetriever / BaseSynthesizer / BaseNodePostprocessor
    # appear to be imported only under ``TYPE_CHECKING`` in this module. If the
    # base class resolves field annotations at runtime (e.g. it is a pydantic
    # model), these names may need real runtime imports -- confirm.

    # Retriever used to fetch candidate nodes for a query.
    retriever: BaseRetriever
    # Synthesizer that turns the query plus nodes into a response.
    response_synthesizer: BaseSynthesizer
    # Postprocessors applied, in order, to the retrieved nodes (may be None).
    node_postprocessors: Optional[List[BaseNodePostprocessor]] = []

    @classmethod
    def class_name(cls) -> str:
        """Return the identifier string for this class."""
        return "RAGQueryEngine"

    # taken from Llamaindex CustomEngine:
    # https://github.com/run-llama/llama_index/blob/main/llama-index-core/llama_index/core/query_engine/retriever_query_engine.py#L134
    def _apply_node_postprocessors(
        self, nodes: list[NodeWithScore], query_bundle: QueryBundle
    ) -> list[NodeWithScore]:
        """Run every configured postprocessor over ``nodes`` in order.

        Args:
            nodes: Nodes returned by the retriever.
            query_bundle: The query the nodes were retrieved for.

        Returns:
            The nodes after all postprocessors have been applied; the input
            is returned unchanged when no postprocessors are configured.
        """
        # The field is Optional, so treat None the same as an empty list.
        for node_postprocessor in self.node_postprocessors or []:
            nodes = node_postprocessor.postprocess_nodes(
                nodes, query_bundle=query_bundle
            )
        return nodes

    def retrieve(self, query_bundle: QueryBundle) -> list[NodeWithScore]:
        """Retrieve nodes for ``query_bundle`` and postprocess them."""
        nodes = self.retriever.retrieve(query_bundle)
        return self._apply_node_postprocessors(nodes, query_bundle=query_bundle)

    async def aretrieve(self, query_bundle: QueryBundle) -> list[NodeWithScore]:
        """Async variant of ``retrieve``; postprocessing itself runs synchronously."""
        nodes = await self.retriever.aretrieve(query_bundle)
        return self._apply_node_postprocessors(nodes, query_bundle=query_bundle)

    def custom_query(self, query_str: str) -> Response:
        """Answer ``query_str`` using retrieval followed by synthesis.

        Args:
            query_str: The raw user query.

        Returns:
            The object produced by the response synthesizer.
        """
        # Convert query string into query bundle
        query_bundle = QueryBundle(query_str=query_str)
        nodes = self.retrieve(query_bundle)  # also does the postprocessing.
        response_obj = self.response_synthesizer.synthesize(query_bundle, nodes)

        # Free cached CUDA memory and collect garbage after each query.
        empty_cache()
        gc.collect()

        # ``Response`` is imported under TYPE_CHECKING only, so it does not
        # exist at runtime; the string form keeps the cast purely static and
        # avoids a NameError here.
        return cast("Response", response_obj)
# @st.cache_resource # none of these can be hashable or cached :( | |
def get_engine(
    retriever: BaseRetriever,
    response_synthesizer: BaseSynthesizer,
    node_postprocessors: list[BaseNodePostprocessor] | None = None,
    callback_manager: CallbackManager | None = None,
) -> RAGQueryEngine:
    """Assemble a ``RAGQueryEngine`` from its components.

    Args:
        retriever: Retriever handed to the engine.
        response_synthesizer: Synthesizer handed to the engine.
        node_postprocessors: Optional postprocessors, passed through as given
            (including ``None``).
        callback_manager: Callback manager for the engine; when omitted (or
            falsy) ``Settings.callback_manager`` is used instead.

    Returns:
        The constructed ``RAGQueryEngine``.
    """
    # Only touch the global Settings when the caller supplied nothing usable.
    effective_callbacks = (
        callback_manager if callback_manager else Settings.callback_manager
    )
    engine_kwargs = {
        "retriever": retriever,
        "response_synthesizer": response_synthesizer,
        "node_postprocessors": node_postprocessors,
        "callback_manager": effective_callbacks,
    }
    return RAGQueryEngine(**engine_kwargs)