"""CLI script: retrieve relevant chunks from a Chroma vector store and answer a
query with a Hugging Face-hosted LLM (retrieval-augmented generation)."""

import argparse
import os

import chromadb
import yaml
from langchain.chains.llm import LLMChain
from langchain.vectorstores.chroma import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain_core.prompts import PromptTemplate

CONFIG_PATH = os.path.join('config', 'default_config.yaml')
CHROMA_PATH = "chroma"
MODEL_CACHE = "model_cache"

PROMPT_TEMPLATE = """
Answer the question based only on the following context:

{context}

---

Answer the question based on the above context: {question}
"""


def main():
    # Create CLI.
    parser = argparse.ArgumentParser()
    parser.add_argument("query_text", type=str, help="The query text.")
    args = parser.parse_args()
    query_text = args.query_text

    # Prepare the DB: embed the query with the same model used at index time.
    hf_embed_func = HuggingFaceEmbeddings(
        model_name="all-MiniLM-L6-v2",
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': False},
        cache_folder=MODEL_CACHE,
    )
    db = Chroma(
        persist_directory=CHROMA_PATH,
        embedding_function=hf_embed_func,
        collection_name="jscholar_rag",
    )

    # Sanity check: report how many embeddings are stored in the collection.
    client = chromadb.PersistentClient(path=CHROMA_PATH)
    collection = client.get_collection(name="jscholar_rag")
    print(f"Total Embeddings: {collection.count()}")
    print(collection.peek())

    # Search the DB for the top-k chunks most relevant to the query.
    results = db.similarity_search_with_relevance_scores(query_text, k=5)
    # results = db.similarity_search(query_text)
    if len(results) == 0 or results[0][1] < 0.1:
        print("Unable to find matching results.")
        return

    # Build the prompt from the retrieved chunks.
    context_text = "\n\n---\n\n".join([doc.page_content for doc, _score in results])
    # prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
    prompt = PromptTemplate(
        input_variables=["context", "question"],
        template=PROMPT_TEMPLATE,
    )
    # prompt = prompt_template.format(context=context_text, question=query_text)

    # Send the prompt and retrieved context to the hosted LLM.
    llm = HuggingFaceEndpoint(
        repo_id="HuggingFaceH4/zephyr-7b-beta",
        task="text-generation",
        top_k=30,
        temperature=0.1,
        repetition_penalty=1.03,
        max_new_tokens=512,
    )
    chat_model = LLMChain(prompt=prompt, llm=llm)
    response_text = chat_model.invoke({'question': query_text, 'context': context_text})

    # Print the answer followed by the source documents it was drawn from.
    sources = [doc.metadata.get("source", None) for doc, _score in results]
    formatted_response = f"{response_text.get('text')}"
    formatted_sources = f"Citations: {sources}"
    print(formatted_response)
    print(formatted_sources)


def load_config():
    # Currently unused: loads settings from config/default_config.yaml.
    with open(CONFIG_PATH, 'r') as file:
        loaded_data = yaml.safe_load(file)
    return loaded_data


if __name__ == "__main__":
    main()
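
# Example invocation (a sketch, not part of the original script). It assumes the
# Chroma index under ./chroma was already built by a separate ingestion step, and
# that HuggingFaceEndpoint can find a Hugging Face API token (typically via the
# HUGGINGFACEHUB_API_TOKEN environment variable). The filename query_rag.py is
# hypothetical; use whatever name this file is saved under.
#
#   python query_rag.py "What are the main findings of the indexed documents?"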