dmedhi committed on
Commit 6917098
1 Parent(s): f0497c6

chat application

Files changed (4)
  1. .gitignore +165 -0
  2. app.py +157 -0
  3. datastore.py +107 -0
  4. embeddings.py +12 -0
.gitignore ADDED
@@ -0,0 +1,165 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+
+ models
+ chroma
+ *.ipynb
app.py ADDED
@@ -0,0 +1,157 @@
+ import datetime
+ import os
+ import uuid
+
+ import fitz
+ import streamlit as st
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+ from llama_cpp import Llama
+
+ from datastore import ChromaStore
+ from embeddings import Embedding
+
+ #### state
+ if "chat_history" not in st.session_state:
+     st.session_state.chat_history = []
+ if "document_submitted" not in st.session_state:
+     st.session_state.document_submitted = False
+
+
+ def phi3(input: str, relevant_chunks: list):
+     llm = Llama(
+         model_path=os.path.join(
+             os.getcwd(),
+             "models",
+             "Phi-3.1-mini-4k-instruct-Q4_K_M.gguf",
+         ),
+         n_ctx=2000,
+         n_threads=1,  # number of CPU threads to use
+         n_gpu_layers=0,  # number of layers to offload to GPU
+     )
+
+     prompt = f"""CONTENT: {relevant_chunks}\n\nQUESTION: {input}\n\nFrom the given CONTENT, Please answer the QUESTION."""
+
+     output = llm(
+         f"<|user|>\n{prompt}<|end|>\n<|assistant|>",
+         max_tokens=2000,
+         stop=["<|end|>"],
+         echo=True,
+     )
+
+     cleaned_output = output["choices"][0]["text"].split("<|assistant|>", 1)[-1].strip()
+     return cleaned_output
+
+
+ def generate_unique_id():
+     unique_id = uuid.uuid4()
+     current_time = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")
+     combined_id = f"{unique_id}-{current_time}"
+     return combined_id
+
+
+ def add_to_vectorstore(content: str, chunk_size: int = 500, chunk_overlap: int = 20):
+     chromastore = ChromaStore(collection_name="pdf_store")
+
+     # delete the collection if it already exists
+     if "pdf_store" in chromastore.list_collections():
+         chromastore.delete("pdf_store")
+         st.toast("Old database cleaned!")
+     collection = chromastore.create()
+     # split the content into chunks
+     text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=chunk_size,
+         chunk_overlap=chunk_overlap,
+         length_function=len,
+         is_separator_regex=False,
+     )
+     chunks = text_splitter.split_text(content)
+
+     # generate embeddings and ids
+     embeddings, ids = [], []
+     for chunk in chunks:
+         embeddings.append(Embedding.encode_text(chunk).tolist())
+         ids.append(generate_unique_id())
+
+     # add to vectorstore
+     chromastore.add(
+         collection=collection,
+         embeddings=embeddings,
+         documents=chunks,
+         ids=ids,
+     )
+
+
+ def similarity_search(query: str):
+     chromastore = ChromaStore(collection_name="pdf_store")
+     collection = chromastore.create()
+     query_embedding = Embedding.encode_text(query).tolist()
+     return chromastore.query(collection=collection, query_embedding=query_embedding)
+
+
+ def main():
+     st.set_page_config(page_icon="🤖", page_title="Phi 3 RAG", layout="wide")
+     st.markdown(
+         """<h1 style="text-align:center;">Phi 3 RAG</h1>""", unsafe_allow_html=True
+     )
+     st.markdown(
+         """<h3 style="text-align:center;">Conversational RAG application that utilizes a local stack, <a href="https://huggingface.co/bartowski/Phi-3-medium-4k-instruct-GGUF">Phi-3 mini 4k instruct GGUF</a> and <a href="https://docs.trychroma.com/getting-started">ChromaDB</a></h3>""",
+         unsafe_allow_html=True,
+     )
+     layout = st.columns(2)
+
+     with layout[0]:
+         with st.container(border=True, height=550):
+             uploaded_file = st.file_uploader(
+                 label="Upload document to search",
+                 type="PDF",
+                 accept_multiple_files=False,
+             )
+             submit = st.button("submit")
+
+             chunk_size = st.slider(
+                 label="Chunk size", min_value=100, max_value=2000, step=100
+             )
+             chunk_overlap = st.slider(
+                 label="Chunk overlap", min_value=10, max_value=500, step=10
+             )
+             if uploaded_file is not None and submit:
+                 # extract text from the uploaded PDF
+                 doc = fitz.open(stream=uploaded_file.read(), filetype="pdf")
+                 text = ""
+                 for page in doc:
+                     text += page.get_text()
+                 doc.close()
+
+                 # add to vectorstore
+                 add_to_vectorstore(text, chunk_size, chunk_overlap)
+                 st.session_state.document_submitted = True
+                 st.toast("Document successfully added to vectorstore", icon="✅")
+
+     # chats
+     with layout[1]:
+         with st.container(border=True, height=550):
+             if st.session_state.document_submitted:
+                 user_input = st.chat_input("Ask me!")
+                 if user_input is not None:
+                     st.session_state.chat_history.append(
+                         {"role": "user", "content": str(user_input)}
+                     )
+
+                     with st.spinner("Thinking..."):
+                         # search the vector store
+                         relevant_chunks = similarity_search(user_input)
+                         response = phi3(
+                             input=user_input, relevant_chunks=relevant_chunks
+                         )
+                         st.session_state.chat_history.append(
+                             {"role": "assistant", "content": str(response)}
+                         )
+
+                 # display messages
+                 for message in reversed(st.session_state.chat_history):
+                     with st.chat_message(message["role"]):
+                         st.markdown(message["content"])
+
+
+ if __name__ == "__main__":
+     main()
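
For reference, a minimal sketch of how the retrieval and generation path above can be exercised outside the Streamlit UI. It assumes a document has already been ingested via add_to_vectorstore, that the GGUF model file exists under models/, and that importing app.py in "bare" mode is acceptable (recent Streamlit versions only warn about the module-level session-state setup when no script run context exists); the question string is purely illustrative.

from app import similarity_search, phi3

question = "What is the document about?"  # hypothetical query
chunks = similarity_search(question)      # top-k chunks from the persisted Chroma store
answer = phi3(input=question, relevant_chunks=chunks)
print(answer)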
datastore.py ADDED
@@ -0,0 +1,107 @@
+ __import__("pysqlite3")
+ import sys
+
+ sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
+
+
+ import uuid
+ from collections import defaultdict
+ from typing import Any, List
+
+ import chromadb
+ import numpy as np
+ from chromadb import Collection
+
+ from embeddings import Embedding
+
+
+ class ChromaStore:
+     def __init__(
+         self,
+         collection_name: str,
+         storage_path: str = "./chroma",
+         database: str = "database",
+         metadata: dict = {"hnsw:space": "cosine"},
+     ) -> None:
+         """Initialize Chromadb.
+         - collection_name (str): name of the collection.
+         - metadata (dict): available options for 'hnsw:space' are 'l2', 'ip' or 'cosine'.
+         """
+
+         self.collection_name = collection_name
+         self.metadata = metadata
+         self.storage_path = storage_path
+         self.database = database
+
+         self.client = chromadb.PersistentClient(path=self.storage_path)
+
+     def _health_check(self) -> bool:
+         return isinstance(self.client.heartbeat(), int)
+
+     def create(self):
+         collection = self.client.get_or_create_collection(
+             name=self.collection_name, metadata=self.metadata,
+         )
+         return collection
+
+     def add(
+         self,
+         collection: Collection,
+         embeddings: List[List[float]],
+         documents: List[str],
+         ids: List[str],
+     ):
+         """Add embeddings and documents to a collection.
+
+         Args:
+             - collection: created collection.
+             - embeddings: list of embeddings.
+             - documents: text documents.
+             - ids: list of ids."""
+         try:
+             collection.add(
+                 embeddings=embeddings,
+                 ids=ids,
+                 documents=documents,
+             )
+         except Exception as e:
+             raise Exception(f"Failed to add documents to Chroma store. {e}")
+
+     def query(
+         self,
+         collection: Collection,
+         query_embedding: List[float],
+         top_k: int = 3,
+     ) -> list:
+         """Retrieve relevant document chunks from the Chroma database.
+
+         Args:
+             - collection: created collection.
+             - query_embedding: query text embedding.
+             - top_k (int): top k chunks to retrieve.
+
+         Returns:
+             - list of relevant chunks.
+         """
+         result = collection.query(query_embeddings=query_embedding, n_results=top_k)
+         relevant_chunks = [chunk for chunk in result["documents"][0]]
+         return relevant_chunks
+         # scores = [round(score, 3) for score in result["distances"][0]]
+         # return list(zip(relevant_chunks, scores))
+
+     def delete(self, collection_name: str):
+         try:
+             self.client.delete_collection(collection_name)
+             return True
+         except Exception as e:
+             raise Exception(f"Failed to delete collection. {e}")
+
+     def list_collections(self):
+         return self.client.list_collections()
+
+     @staticmethod
+     def collection_info(collection: Collection):
+         info = defaultdict(str)
+         info["count"] = collection.count()
+         info["top_10_items"] = collection.peek()
+         return info
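
A minimal usage sketch of ChromaStore together with the Embedding helper, assuming pysqlite3-binary is installed (datastore.py swaps it in for sqlite3), the sentence-transformers model can be downloaded on first use, and ./chroma is writable; the collection name, texts, and id are illustrative only.

from datastore import ChromaStore
from embeddings import Embedding

store = ChromaStore(collection_name="demo_store")
collection = store.create()
store.add(
    collection=collection,
    embeddings=[Embedding.encode_text("Chroma stores embeddings locally.").tolist()],
    documents=["Chroma stores embeddings locally."],
    ids=["demo-1"],
)
print(store.query(
    collection=collection,
    query_embedding=Embedding.encode_text("Where are embeddings stored?").tolist(),
    top_k=1,
))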
embeddings.py ADDED
@@ -0,0 +1,12 @@
+ from typing import List
+
+ import torch
+ from sentence_transformers import SentenceTransformer
+
+
+ class Embedding:
+     model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
+
+     @classmethod
+     def encode_text(cls, text: str):
+         return cls.model.encode(text)
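
As a quick sanity check, a sketch of what the Embedding helper returns, assuming the all-MiniLM-L6-v2 model above (384-dimensional vectors); the sample sentence is illustrative only.

from embeddings import Embedding

vector = Embedding.encode_text("a sample sentence")
print(vector.shape)             # (384,) for all-MiniLM-L6-v2
chroma_ready = vector.tolist()  # plain Python list, as app.py builds before adding to Chroma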