import argparse
import os
from dataclasses import dataclass

import chromadb
import yaml
from langchain.chains.llm import LLMChain
from langchain.vectorstores.chroma import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain_core.prompts import PromptTemplate

CONFIG_PATH = os.path.join("config", "default_config.yaml")
CHROMA_PATH = "chroma"
MODEL_CACHE = "model_cache"

PROMPT_TEMPLATE = """
Answer the question based only on the following context:

{context}

---

Answer the question based on the above context: {question}
"""


def main():
    # Create CLI.
    parser = argparse.ArgumentParser()
    parser.add_argument("query_text", type=str, help="The query text.")
    args = parser.parse_args()
    query_text = args.query_text

    # Prepare the DB.
    hf_embed_func = HuggingFaceEmbeddings(
        model_name="all-MiniLM-L6-v2",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": False},
        cache_folder=MODEL_CACHE,
    )
    db = Chroma(
        persist_directory=CHROMA_PATH,
        embedding_function=hf_embed_func,
        collection_name="jscholar_rag",
    )
    client = chromadb.PersistentClient(path=CHROMA_PATH)
    collection = client.get_collection(name="jscholar_rag")
    print(f"Total Embeddings: {collection.count()}")
    print(collection.peek())

    # Search the DB.
    results = db.similarity_search_with_relevance_scores(query_text, k=5)
    # results = db.similarity_search(query_text)
    if len(results) == 0 or results[0][1] < 0.1:
        print("Unable to find matching results.")
        return

    context_text = "\n\n---\n\n".join([doc.page_content for doc, _score in results])
    # Build the prompt from the shared template. The declared input variables
    # must match the {context} and {question} placeholders in PROMPT_TEMPLATE.
    prompt = PromptTemplate(
        input_variables=["context", "question"], template=PROMPT_TEMPLATE
    )
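    # NOTE: HuggingFaceEndpoint calls the hosted Hugging Face Inference API, so
    # a valid token is expected (typically via the HUGGINGFACEHUB_API_TOKEN
    # environment variable or the huggingfacehub_api_token argument).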
    llm = HuggingFaceEndpoint(
        repo_id="HuggingFaceH4/zephyr-7b-beta",
        task="text-generation",
        top_k=30,
        temperature=0.1,
        repetition_penalty=1.03,
        max_new_tokens=512,
    )
    chat_model = LLMChain(prompt=prompt, llm=llm)
    response_text = chat_model.invoke({"question": query_text, "context": context_text})

    sources = [doc.metadata.get("source", None) for doc, _score in results]
    formatted_response = f"{response_text.get('text')}"
    formatted_sources = f"Citations: {sources}"
    print(formatted_response)
    print(formatted_sources)


def load_config():
    # Load the YAML settings file referenced by CONFIG_PATH.
    with open(CONFIG_PATH, "r") as file:
        loaded_data = yaml.safe_load(file)
    return loaded_data


if __name__ == "__main__":
    main()
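
# Example invocation (a sketch; the script name below is illustrative and this
# assumes the Chroma index under ./chroma was already populated by a separate
# ingestion step, with a Hugging Face token available in the environment):
#
#   python query_data.py "What does the collection say about topic X?"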