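"""Query a local retrieval-augmented generation (RAG) pipeline.

The script embeds a question given on the command line with a
sentence-transformers model, retrieves the most relevant chunks from a
persisted Chroma collection ("jscholar_rag"), and asks a Hugging Face-hosted
LLM to answer using only the retrieved context.
"""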
import argparse
import os
from dataclasses import dataclass

import chromadb
import yaml
from langchain.chains.llm import LLMChain
from langchain.vectorstores.chroma import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain_core.prompts import PromptTemplate

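# Paths are relative to the working directory: CONFIG_PATH is the YAML run
# configuration, CHROMA_PATH the persisted Chroma index, and MODEL_CACHE the
# local cache for downloaded embedding model weights.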
CONFIG_PATH = os.path.join('config', 'default_config.yaml')
CHROMA_PATH = "chroma"
MODEL_CACHE = "model_cache"
PROMPT_TEMPLATE = """

Answer the question based only on the following context:



{context}



---



Answer the question based on the above context: {question}

"""


def main():
    # Create CLI.
    parser = argparse.ArgumentParser()
    parser.add_argument("query_text", type=str, help="The query text.")
    args = parser.parse_args()
    query_text = args.query_text

    # Prepare the DB.
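    # The embedding model must match the one used when the "jscholar_rag"
    # collection was built, otherwise the relevance scores are not meaningful.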
    hf_embed_func = HuggingFaceEmbeddings(
        model_name="all-MiniLM-L6-v2",
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': False},
        cache_folder=MODEL_CACHE
    )
    db = Chroma(persist_directory=CHROMA_PATH, embedding_function=hf_embed_func, collection_name="jscholar_rag")

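    # Sanity check: report how many embeddings the persisted collection holds.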
    client = chromadb.PersistentClient(path=CHROMA_PATH)
    collection = client.get_collection(name="jscholar_rag")
    print(f"Total Embeddings: {collection.count()}")
    print(collection.peek())

    # Search the DB.
    results = db.similarity_search_with_relevance_scores(query_text, k=5)
    # results = db.similarity_search(query_text)
    if len(results) == 0 or results[0][1] < 0.1:
        # Bail out when nothing was retrieved or the best match is barely relevant.
        print("Unable to find matching results.")
        return

    context_text = "\n\n---\n\n".join([doc.page_content for doc, _score in results])
    # prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
    # input_variables must name the template placeholders, not hold their values;
    # the actual context and question are supplied later through chain.invoke().
    prompt = PromptTemplate(
        input_variables=["context", "question"], template=PROMPT_TEMPLATE
    )
    #prompt = prompt_template.format(context=context_text, question=query_text)

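    # NOTE: HuggingFaceEndpoint calls the hosted Hugging Face Inference API, so a
    # valid API token is assumed to be available in the environment (typically
    # via HUGGINGFACEHUB_API_TOKEN).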
    llm = HuggingFaceEndpoint(
        repo_id="HuggingFaceH4/zephyr-7b-beta",
        task="text-generation",
        top_k=30,
        temperature=0.1,
        repetition_penalty=1.03,
        max_new_tokens=512,
    )
    chat_model = LLMChain(prompt=prompt, llm=llm)

    response_text = chat_model.invoke({'question': query_text, 'context': context_text})

    sources = [doc.metadata.get("source", None) for doc, _score in results]
    formatted_response = response_text.get("text", "")
    formatted_sources = f"Citations: {sources}"
    print(formatted_response)
    print(formatted_sources)


def load_config():
    """Load run-time settings from the YAML config file (not yet wired into main())."""
    with open(CONFIG_PATH, 'r') as file:
        loaded_data = yaml.safe_load(file)

    return loaded_data


if __name__ == "__main__":
    main()
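
# Example usage (assuming this script is saved as query_rag.py and the Chroma
# index in ./chroma has already been built with the same embedding model):
#   export HUGGINGFACEHUB_API_TOKEN=hf_...
#   python query_rag.py "What topics does the collection cover?"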