import chromadb
from llama_index.core.base.embeddings.base import similarity
#from llama_index.llms.ollama import Ollama
from llama_index.llms.groq import Groq
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings, DocumentSummaryIndex
from llama_index.core import StorageContext, get_response_synthesizer
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core import load_index_from_storage
import os
from dotenv import load_dotenv
from llama_index.core.callbacks import CallbackManager, LlamaDebugHandler, CBEventType
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.postprocessor import SimilarityPostprocessor
import time
import gradio as gr
from llama_index.core.memory import ChatMemoryBuffer
from llama_parse import LlamaParse
from llama_index.core import PromptTemplate
from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.core.chat_engine import CondenseQuestionChatEngine
load_dotenv()
GROQ_API_KEY = os.getenv('GROQ_API_KEY')
LLAMAINDEX_API_KEY = os.getenv('LLAMAINDEX_API_KEY')
# set up callback manager
llama_debug = LlamaDebugHandler(print_trace_on_end=True)
callback_manager = CallbackManager([llama_debug])
Settings.callback_manager = callback_manager
# set up LLM
llm = Groq(model="llama3-70b-8192")  # alternative: "llama3-8b-8192"
Settings.llm = llm
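# Groq() falls back to the GROQ_API_KEY environment variable when no api_key is passed explicitly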
# set up embedding model
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
Settings.embed_model = embed_model
# create splitter
splitter = SentenceSplitter(chunk_size=2048, chunk_overlap=50)
Settings.transformations = [splitter]
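# the Settings assignments above act as global defaults for every index and query engine built below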
# create parser
parser = LlamaParse(
    api_key=LLAMAINDEX_API_KEY,
    result_type="markdown",  # "markdown" and "text" are available
    verbose=True,
)
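# LlamaParse parses PDFs through the LlamaCloud service and returns markdown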
# create index
if os.path.exists("./vectordb"):
    print("Index Exists!")
    storage_context = StorageContext.from_defaults(persist_dir="./vectordb")
    index = load_index_from_storage(storage_context)
else:
    filename_fn = lambda filename: {"file_name": filename}
    required_exts = [".pdf", ".docx"]
    file_extractor = {".pdf": parser}
    reader = SimpleDirectoryReader(
        input_dir="./data",
        file_extractor=file_extractor,
        required_exts=required_exts,
        recursive=True,
        file_metadata=filename_fn,
    )
    documents = reader.load_data()
    print(f"creating index with {len(documents)} documents")
    index = VectorStoreIndex.from_documents(documents, embed_model=embed_model, transformations=[splitter])
    index.storage_context.persist(persist_dir="./vectordb")
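# persisting the storage context lets later runs reload the index instead of re-parsing and re-embedding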
""" | |
#create document summary index | |
if os.path.exists("./docsummarydb"): | |
print("Index Exists!") | |
storage_context = StorageContext.from_defaults(persist_dir="./docsummarydb") | |
doc_index = load_index_from_storage(storage_context) | |
else: | |
filename_fn = lambda filename: {"file_name": filename} | |
required_exts = [".pdf",".docx"] | |
reader = SimpleDirectoryReader( | |
input_dir="./data", | |
required_exts=required_exts, | |
recursive=True, | |
file_metadata=filename_fn | |
) | |
documents = reader.load_data() | |
print("index creating with `%d` documents", len(documents)) | |
response_synthesizer = get_response_synthesizer( | |
response_mode="tree_summarize", use_async=True | |
) | |
doc_index = DocumentSummaryIndex.from_documents( | |
documents, | |
llm = llm, | |
transformations = [splitter], | |
response_synthesizer = response_synthesizer, | |
show_progress = True | |
) | |
doc_index.storage_context.persist(persist_dir="./docsummarydb") | |
""" | |
""" | |
retriever = DocumentSummaryIndexEmbeddingRetriever( | |
doc_index, | |
similarity_top_k=5, | |
) | |
""" | |
# set up retriever
retriever = VectorIndexRetriever(
    index=index,
    similarity_top_k=10,
    #vector_store_query_mode="mmr",
    #vector_store_kwargs={"mmr_threshold": 0.4}
)
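# retrieve the 10 most similar chunks; the MMR re-ranking options are left disabled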
# set up response synthesizer
response_synthesizer = get_response_synthesizer(response_mode="tree_summarize", verbose=True)
### customising prompts worsened the result ###
""" | |
# set up prompt template | |
qa_prompt_tmpl = ( | |
"Context information from multiple sources is below.\n" | |
"---------------------\n" | |
"{context_str}\n" | |
"---------------------\n" | |
"Given the information from multiple sources and not prior knowledge, " | |
"answer the query.\n" | |
"Query: {query_str}\n" | |
"Answer: " | |
) | |
qa_prompt = PromptTemplate(qa_prompt_tmpl) | |
""" | |
# set up query engine
query_engine = RetrieverQueryEngine(
    retriever=retriever,
    node_postprocessors=[SimilarityPostprocessor(similarity_cutoff=0.53)],
    response_synthesizer=response_synthesizer,
)
print(query_engine.get_prompts())
#response = query_engine.query("What happens if the distributor wants its own warehouse for pizzahood?")
#print(response)
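# query_engine: retrieve the top-10 chunks, drop those below the 0.53 similarity cutoff, then tree_summarize the rest into one answer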
memory = ChatMemoryBuffer.from_defaults(token_limit=10000)
custom_prompt = PromptTemplate(
    """\
Given a conversation (between Human and Assistant) and a follow up message from Human, \
rewrite the message to be a standalone question that captures all relevant context \
from the conversation. If you are unsure, ask for more information.
<Chat History>
{chat_history}
<Follow Up Message>
{question}
<Standalone question>
"""
)
# list of `ChatMessage` objects
custom_chat_history = [
    ChatMessage(
        role=MessageRole.USER,
        content="Hello assistant.",
    ),
    ChatMessage(role=MessageRole.ASSISTANT, content="Hello user."),
]
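# the condense-question engine rewrites each follow-up into a standalone question (via custom_prompt) before querying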
chat_engine = CondenseQuestionChatEngine.from_defaults(
    query_engine=query_engine,
    condense_question_prompt=custom_prompt,
    chat_history=custom_chat_history,
    verbose=True,
    memory=memory,
)
# gradio with streaming support
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="⏎ for sending",
                     placeholder="Ask me something")
    clear = gr.Button("Delete")

    def user(user_message, history):
        return "", history + [[user_message, None]]

    def bot(history):
        user_message = history[-1][0]
        #bot_message = chat_engine.chat(user_message)
        bot_message = query_engine.query(
            user_message
            + " Let's think step by step to get the correct answer. If you cannot provide an answer, say you don't know."
        )
        history[-1][1] = ""
        for character in bot_message.response:
            history[-1][1] += character
            time.sleep(0.01)
            yield history

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

# demo.queue()
demo.launch(share=False)