# Mitrakara-TriwiraData/RAGModule.py
from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex
from llama_index.core.postprocessor import SimilarityPostprocessor
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.embeddings.huggingface import HuggingFaceEmbedding


# Embedding model builder
def set_embed_model(model_name: str,
                    chunk_size: int = 256,
                    chunk_overlap: int = 25) -> None:
    """Configure global llama_index settings to use a HuggingFace
    embedding model with the given chunking parameters."""
    Settings.llm = None  # retrieval only; no LLM is used for synthesis
    Settings.embed_model = HuggingFaceEmbedding(model_name=model_name)
    Settings.chunk_size = chunk_size
    Settings.chunk_overlap = chunk_overlap


class RAGModule:
    """Retrieval-only RAG helper: indexes local documents and formats
    the retrieved context into a prompt for an instruction-tuned LLM."""

    def __init__(self,
                 llm_model: str = "MarcoAland/llama3.1-rag-indo",
                 embedding_model: str = "MarcoAland/Indo-bge-m3",
                 docs_path: str = "data",
                 top_k: int = 3,
                 similarity_cutoff: float = 0.3):
        # Configure the embedding model (llm_model is kept for reference;
        # this module performs retrieval only and never calls the LLM)
        set_embed_model(model_name=embedding_model)

        # Load documents and build an in-memory vector index
        documents = SimpleDirectoryReader(docs_path).load_data()
        index = VectorStoreIndex.from_documents(documents)
        retriever = VectorIndexRetriever(
            index=index,
            similarity_top_k=top_k,
        )

        self.top_k = top_k
        self.query_engine = RetrieverQueryEngine(
            retriever=retriever,
            node_postprocessors=[SimilarityPostprocessor(similarity_cutoff=similarity_cutoff)],
        )

    def format_context(self, response) -> str:
        """Concatenate the retrieved node texts into a single context block."""
        context = "Context:\n"
        # Iterate over the nodes actually returned: the similarity cutoff
        # may leave fewer than top_k nodes, so indexing up to top_k could fail.
        for node in response.source_nodes:
            context += node.text + "\n\n"
        return context

    def query(self, query: str) -> str:
        """Retrieve context for a query; return an empty string on failure."""
        try:
            response = self.query_engine.query(query)
            return self.format_context(response)
        except Exception:
            # e.g. no documents matched or retrieval failed: proceed without context
            return ""

    def prompt(self, context: str, instruction: str) -> str:
        # "Instruksi" is Indonesian for "instruction"; the literal is kept
        # as-is because it matches the prompt format the model was tuned on.
        return f"{context}\n ### Instruksi:\n {instruction}"

    def main(self, instruction: str) -> str:
        """Retrieve context for the instruction and build the final prompt."""
        context = self.query(query=instruction)
        return self.prompt(context=context, instruction=instruction)
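

# Minimal usage sketch (not part of the original module): it assumes a
# local "data" directory with readable documents sits next to this file;
# the query string below is purely illustrative.
if __name__ == "__main__":
    rag = RAGModule(docs_path="data", top_k=3, similarity_cutoff=0.3)
    final_prompt = rag.main("What services does the company provide?")
    print(final_prompt)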