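"""RAG document Q&A over local PDFs/DOCX.

Parses documents with LlamaParse, enriches each chunk with LLM-generated
summaries and anticipated questions, embeds them with a BGE model, persists
the vector index to ./vectordb, and serves a Gradio app that answers
questions via a Groq-hosted Llama 3 model, citing the source file and page.
"""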
import os
import time

from dotenv import load_dotenv
import gradio as gr
from llama_index.core import (
    VectorStoreIndex,
    SimpleDirectoryReader,
    Settings,
    StorageContext,
    load_index_from_storage,
)
from llama_index.core.callbacks import CallbackManager, LlamaDebugHandler
from llama_index.core.extractors import (
    SummaryExtractor,
    QuestionsAnsweredExtractor,
)
from llama_index.core.indices.query.query_transform.base import HyDEQueryTransform
from llama_index.core.ingestion import IngestionPipeline
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.postprocessor import SimilarityPostprocessor
from llama_index.core.query_engine import TransformQueryEngine
from llama_index.core.schema import MetadataMode
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.groq import Groq
from llama_index.llms.openai import OpenAI
from llama_parse import LlamaParse

load_dotenv()
# OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
GROQ_API_KEY = os.getenv('GROQ_API_KEY')
LLAMAINDEX_API_KEY = os.getenv('LLAMAINDEX_API_KEY')

# Groq-hosted LLM used to answer queries ("llama3-8b-8192" is a cheaper alternative)
llm = Groq(model="llama3-70b-8192")
Settings.llm = llm

# set up a callback manager with a debug handler that prints an event trace after each operation
llama_debug = LlamaDebugHandler(print_trace_on_end=True)
callback_manager = CallbackManager([llama_debug])
Settings.callback_manager = callback_manager

# embedding model used to turn document chunks into vectors for indexing
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
Settings.embed_model = embed_model

# splitter used at index time: 1024-token chunks with 20-token overlap
splitter = SentenceSplitter(chunk_size=1024, chunk_overlap=20)
Settings.transformations = [splitter]

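# reuse the persisted index if it exists; otherwise parse, enrich, and index the documents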
if os.path.exists("./vectordb"):
    storage_context = StorageContext.from_defaults(persist_dir="./vectordb")
    index = load_index_from_storage(storage_context)
else:
    parser = LlamaParse(
        api_key=LLAMAINDEX_API_KEY, 
        result_type="markdown",  # "markdown" and "text" are available
        verbose=True,
    )
    # attach each file's name as metadata so answers can cite their source
    filename_fn = lambda filename: {"file_name": filename}
    required_exts = [".pdf", ".docx"]
    file_extractor = {".pdf": parser}  # route PDFs through LlamaParse; other types use default readers
    reader = SimpleDirectoryReader(
        "./data", 
        file_extractor=file_extractor,
        required_exts=required_exts,
        recursive=True,
        file_metadata=filename_fn
    )
    documents = reader.load_data()
    # prepend the metadata to the text so file names are embedded with the content
    for doc in documents:
        doc.text = str(doc.metadata) + ' ' + doc.text
    print(f"creating index from {len(documents)} documents")
    # index = VectorStoreIndex.from_documents(documents, embed_model=embed_model, text_splitter=splitter)
    extractor_llm = Groq(model="llama3-70b-8192", temperature=0.1, max_tokens=512)
    # extractor_llm = OpenAI(temperature=0.1, model="gpt-3.5-turbo", max_tokens=512)

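    # split into 512-token nodes, then enrich each node's metadata with an
    # LLM-generated summary of itself and its neighbors plus the questions it
    # can answer; both are embedded alongside the text to improve retrieval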
    node_parser = SentenceSplitter(chunk_size=512, chunk_overlap=20)
    extractors = [
        SummaryExtractor(summaries=["prev", "self", "next"], llm=extractor_llm),
        QuestionsAnsweredExtractor(
            questions=3, llm=extractor_llm, metadata_mode=MetadataMode.EMBED
        ),
    ]
    nodes = node_parser.get_nodes_from_documents(documents)
    nodes_extract_ls = []
    print('extracting from:', len(nodes), 'nodes.')
    batch_size = 5
    pipeline = IngestionPipeline(transformations=[node_parser, *extractors])
    for i in range(0, len(nodes), batch_size):
        print(i)
        nodes_batch_raw = nodes[i:i + batch_size]
        for attempt in range(5):  # bounded retries in case the API rate limit is hit
            try:
                nodes_batch = pipeline.run(nodes=nodes_batch_raw, in_place=False, show_progress=True)
                nodes_extract_ls.append(nodes_batch)
                break
            except Exception:
                time.sleep(30)  # back off before retrying the same batch
    # flatten the per-batch results into a single list of nodes
    nodes_extract = [x for xs in nodes_extract_ls for x in xs]
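    # build the vector index from the enriched nodes and persist it for reuse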
    index = VectorStoreIndex(nodes_extract)

    index.storage_context.persist(persist_dir="./vectordb")


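# answer queries using the top-5 most similar chunks as context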
query_engine = index.as_query_engine(
    similarity_top_k=5,
    #node_postprocessors=[SimilarityPostprocessor(similarity_cutoff=0.7)],
    verbose=True,
)


# # HyDE: generate a hypothetical answer document from the query and match it
# # against the index with doc-to-doc similarity
# hyde = HyDEQueryTransform(include_original=True)
# hyde_query_engine = TransformQueryEngine(query_engine, query_transform=hyde)

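# minimal Gradio UI: a question textbox in, the cited answer out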
def retrieve(question):
    qns_w_source = (
        "Answer the following question: " + question
        + " Then provide the page and file name of the source document."
    )
    response = query_engine.query(qns_w_source)
    # sources = response.get_formatted_sources(length=5000)
    return str(response)  # + "\n" + str(sources)

demo = gr.Interface(fn=retrieve, inputs="textbox", outputs="textbox")

if __name__ == "__main__":
    demo.launch(share=True)