|
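"""RAG chatbot over ./data: documents are parsed with LlamaParse, indexed into a
persisted LlamaIndex vector store, answered with a Groq-hosted Llama 3 model, and
served through a Gradio chat UI."""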
import chromadb |
|
from llama_index.core.base.embeddings.base import similarity |
|
|
|
from llama_index.llms.groq import Groq |
|
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings, DocumentSummaryIndex |
|
from llama_index.core import StorageContext, get_response_synthesizer |
|
from llama_index.core.retrievers import VectorIndexRetriever |
|
from llama_index.core.query_engine import RetrieverQueryEngine |
|
from llama_index.vector_stores.chroma import ChromaVectorStore |
|
from llama_index.embeddings.huggingface import HuggingFaceEmbedding |
|
from llama_index.core import load_index_from_storage |
|
import os |
|
from dotenv import load_dotenv |
|
from llama_index.core.callbacks import CallbackManager, LlamaDebugHandler, CBEventType |
|
from llama_index.core.node_parser import SentenceSplitter |
|
from llama_index.core.postprocessor import SimilarityPostprocessor |
|
import time |
|
import gradio as gr |
|
from llama_index.core.memory import ChatMemoryBuffer |
|
from llama_parse import LlamaParse |
|
from llama_index.core import PromptTemplate |
|
from llama_index.core.llms import ChatMessage, MessageRole |
|
from llama_index.core.chat_engine import CondenseQuestionChatEngine |
|
|
|
|
|
|
|
# Read API keys from a local .env file
load_dotenv()
|
GROQ_API_KEY = os.getenv('GROQ_API_KEY') |
|
LLAMAINDEX_API_KEY = os.getenv('LLAMAINDEX_API_KEY') |
|
|
|
|
|
# Debug callbacks: print a trace of retrieval and LLM events after each query
llama_debug = LlamaDebugHandler(print_trace_on_end=True)
|
callback_manager = CallbackManager([llama_debug]) |
|
Settings.callback_manager = callback_manager |
|
|
|
|
|
# Groq-hosted Llama 3 8B; pass the key explicitly so the variable loaded above is used
llm = Groq(model="llama3-8b-8192", api_key=GROQ_API_KEY)
|
Settings.llm = llm |
|
|
|
|
|
# Local HuggingFace model used for both indexing and query embeddings
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
|
Settings.embed_model = embed_model |
|
|
|
|
|
# Sentence-aware chunking: 2048-token chunks with a 50-token overlap
splitter = SentenceSplitter(chunk_size=2048, chunk_overlap=50)
|
Settings.transformations = [splitter] |
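
# Note: these Settings act as global defaults; the embed_model/transformations
# passed explicitly to from_documents() below are therefore redundant but harmless.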
|
|
|
|
|
# LlamaParse converts source PDFs to markdown before indexing
parser = LlamaParse(
|
api_key=LLAMAINDEX_API_KEY, |
|
result_type="markdown", |
|
verbose=True, |
|
) |
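
# LlamaParse calls the Llama Cloud service, so the key above is expected to be a
# Llama Cloud API key.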
|
|
|
|
|
# Build the vector index once and persist it; later runs just reload it from disk
if os.path.exists("./vectordb"):
    print("Existing index found - loading from ./vectordb")
    storage_context = StorageContext.from_defaults(persist_dir="./vectordb")
    index = load_index_from_storage(storage_context)
|
else: |
|
filename_fn = lambda filename: {"file_name": filename} |
|
required_exts = [".pdf",".docx"] |
|
file_extractor = {".pdf": parser} |
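    # Only .pdf files are routed through LlamaParse; .docx files fall back to the
    # default reader.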
|
reader = SimpleDirectoryReader( |
|
input_dir="./data", |
|
file_extractor=file_extractor, |
|
required_exts=required_exts, |
|
recursive=True, |
|
file_metadata=filename_fn |
|
) |
|
documents = reader.load_data() |
|
    print(f"Creating index with {len(documents)} documents")
|
index = VectorStoreIndex.from_documents(documents, embed_model=embed_model, transformations=[splitter]) |
|
index.storage_context.persist(persist_dir="./vectordb") |
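
# chromadb / ChromaVectorStore are imported above but not wired in. Below is a minimal,
# disabled sketch of a Chroma-backed index; the path "./chroma_db" and the collection
# name "rag_docs" are placeholders, not values used elsewhere in this script.
"""
chroma_client = chromadb.PersistentClient(path="./chroma_db")
chroma_collection = chroma_client.get_or_create_collection("rag_docs")
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex.from_documents(
    documents,
    storage_context=storage_context,
    embed_model=embed_model,
    transformations=[splitter],
)
# On later runs the index can be reopened with
# VectorStoreIndex.from_vector_store(vector_store) instead of re-embedding.
"""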
|
|
|
""" |
|
# Optional: build a DocumentSummaryIndex over the same documents (disabled)
|
if os.path.exists("./docsummarydb"): |
|
print("Index Exists!") |
|
storage_context = StorageContext.from_defaults(persist_dir="./docsummarydb") |
|
doc_index = load_index_from_storage(storage_context) |
|
else: |
|
filename_fn = lambda filename: {"file_name": filename} |
|
required_exts = [".pdf",".docx"] |
|
reader = SimpleDirectoryReader( |
|
input_dir="./data", |
|
required_exts=required_exts, |
|
recursive=True, |
|
file_metadata=filename_fn |
|
) |
|
documents = reader.load_data() |
|
    print(f"Creating index with {len(documents)} documents")
|
|
|
response_synthesizer = get_response_synthesizer( |
|
response_mode="tree_summarize", use_async=True |
|
) |
|
doc_index = DocumentSummaryIndex.from_documents( |
|
documents, |
|
llm = llm, |
|
transformations = [splitter], |
|
response_synthesizer = response_synthesizer, |
|
show_progress = True |
|
) |
|
doc_index.storage_context.persist(persist_dir="./docsummarydb") |
|
""" |
|
""" |
|
retriever = DocumentSummaryIndexEmbeddingRetriever( |
|
doc_index, |
|
similarity_top_k=5, |
|
) |
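
# Re-enabling this retriever also needs:
# from llama_index.core.indices.document_summary import DocumentSummaryIndexEmbeddingRetriever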
|
""" |
|
|
|
|
|
retriever = VectorIndexRetriever(
    index=index,
    similarity_top_k=10,
)
|
|
|
|
|
# Build the response synthesizer once and reuse it in the query engine below
response_synthesizer = get_response_synthesizer(response_mode="tree_summarize", verbose=True)
|
|
|
|
|
""" |
|
# set up prompt template |
|
qa_prompt_tmpl = ( |
|
"Context information from multiple sources is below.\n" |
|
"---------------------\n" |
|
"{context_str}\n" |
|
"---------------------\n" |
|
"Given the information from multiple sources and not prior knowledge, " |
|
"answer the query.\n" |
|
"Query: {query_str}\n" |
|
"Answer: " |
|
) |
|
qa_prompt = PromptTemplate(qa_prompt_tmpl) |
|
""" |
|
|
|
query_engine = RetrieverQueryEngine(
    retriever=retriever,
    node_postprocessors=[SimilarityPostprocessor(similarity_cutoff=0.53)],
    response_synthesizer=response_synthesizer,
)
|
print(query_engine.get_prompts()) |
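
# The keys printed above can be used to swap in a custom template, e.g. the disabled
# qa_prompt further up. The exact key name below is a guess based on the tree_summarize
# mode - check it against the get_prompts() output before using it:
# query_engine.update_prompts({"response_synthesizer:summary_template": qa_prompt})

# Optional smoke test before starting the UI (the question is only illustrative):
# print(query_engine.query("What documents are in the knowledge base?"))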
|
|
|
|
|
|
|
|
|
|
|
# Keep roughly the last 10k tokens of conversation for question condensation
memory = ChatMemoryBuffer.from_defaults(token_limit=10000)
|
|
|
custom_prompt = PromptTemplate( |
|
"""\ |
|
Given a conversation (between Human and Assistant) and a follow up message from Human, \ |
|
rewrite the message to be a standalone question that captures all relevant context \ |
|
from the conversation. If you are unsure, ask for more information. |
|
<Chat History> |
|
{chat_history} |
|
<Follow Up Message> |
|
{question} |
|
<Standalone question> |
|
""" |
|
) |
|
|
|
|
|
custom_chat_history = [ |
|
ChatMessage( |
|
role=MessageRole.USER, |
|
content="Hello assistant.", |
|
), |
|
ChatMessage(role=MessageRole.ASSISTANT, content="Hello user."), |
|
] |
|
|
|
chat_engine = CondenseQuestionChatEngine.from_defaults( |
|
query_engine=query_engine, |
|
condense_question_prompt=custom_prompt, |
|
chat_history=custom_chat_history, |
|
verbose=True, |
|
memory=memory |
|
) |
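
# Example of calling the chat engine directly, outside the UI (question is illustrative):
# print(chat_engine.chat("What is this document collection about?"))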
|
|
|
|
|
|
|
with gr.Blocks() as demo: |
|
|
chatbot = gr.Chatbot() |
|
    msg = gr.Textbox(label="Press ⏎ to send", placeholder="Ask me something")
    clear = gr.Button("Clear chat")
|
|
|
def user(user_message, history): |
|
return "", history + [[user_message, None]] |
|
|
|
    def bot(history):
        user_message = history[-1][0]

        # Route the question through the chat engine (not the bare query engine)
        # so the condense prompt and chat memory defined above are actually used.
        bot_message = chat_engine.chat(
            user_message
            + " Let's think step by step to get the correct answer."
            + " If you cannot provide an answer, say you don't know."
        )

        # Stream the answer into the chat window character by character.
        history[-1][1] = ""
        for character in bot_message.response:
            history[-1][1] += character
            time.sleep(0.01)
            yield history
|
|
|
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then( |
|
bot, chatbot, chatbot |
|
) |
|
    # Clearing resets the chat engine's memory too; reset() returns None, which empties the chatbot
    clear.click(lambda: chat_engine.reset(), None, chatbot, queue=False)
|
|
|
demo.launch(share=False) |
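
# Requires GROQ_API_KEY and LLAMAINDEX_API_KEY in .env and source documents under
# ./data; the vector index is cached in ./vectordb after the first run.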