Spaces:
Runtime error
Runtime error
Commit
·
310bd34
1
Parent(s):
228c8c1
final
Browse files- app.py +66 -59
- requirements.txt +205 -4
app.py
CHANGED
@@ -1,38 +1,85 @@
|
|
1 |
-
#
|
2 |
# AI MAKERSPACE MIDTERM PROJECT: META RAG CHATBOT
|
3 |
# Date: 2024-5-2
|
4 |
# Authors: MikeC
|
5 |
|
6 |
# Basic Imports & Setup
|
7 |
import os
|
8 |
-
|
9 |
from openai import AsyncOpenAI
|
10 |
|
11 |
# Using Chainlit for our UI
|
12 |
-
import chainlit as cl
|
13 |
-
from chainlit.prompt import Prompt, PromptMessage
|
14 |
-
from chainlit.playground.providers import ChatOpenAI
|
15 |
|
16 |
# Getting the API key from the .env file
|
17 |
from dotenv import load_dotenv
|
18 |
load_dotenv()
|
19 |
|
20 |
-
# RAG
|
|
|
|
|
21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
|
|
23 |
|
|
|
|
|
|
|
|
|
|
|
24 |
|
|
|
25 |
|
26 |
-
|
27 |
-
|
28 |
-
""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
"""
|
33 |
|
34 |
-
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
async def start_chat():
|
37 |
settings = {
|
38 |
"model": "gpt-3.5-turbo",
|
@@ -42,54 +89,14 @@ async def start_chat():
|
|
42 |
"frequency_penalty": 0,
|
43 |
"presence_penalty": 0,
|
44 |
}
|
45 |
-
|
46 |
cl.user_session.set("settings", settings)
|
47 |
|
48 |
-
|
49 |
-
@cl.on_message # marks a function that should be run each time the chatbot receives a message from a user
|
50 |
async def main(message: cl.Message):
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
print(message.content)
|
56 |
-
|
57 |
-
prompt = Prompt(
|
58 |
-
provider=ChatOpenAI.id,
|
59 |
-
messages=[
|
60 |
-
PromptMessage(
|
61 |
-
role="system",
|
62 |
-
template=system_template,
|
63 |
-
formatted=system_template,
|
64 |
-
),
|
65 |
-
PromptMessage(
|
66 |
-
role="user",
|
67 |
-
template=user_template,
|
68 |
-
formatted=user_template.format(input=message.content),
|
69 |
-
),
|
70 |
-
],
|
71 |
-
inputs={"input": message.content},
|
72 |
-
settings=settings,
|
73 |
-
)
|
74 |
-
|
75 |
-
print([m.to_openai() for m in prompt.messages])
|
76 |
-
|
77 |
-
msg = cl.Message(content="")
|
78 |
-
|
79 |
-
|
80 |
-
# Question and Answer Chatbot
|
81 |
-
# Call OpenAI
|
82 |
-
async for stream_resp in await client.chat.completions.create(
|
83 |
-
messages=[m.to_openai() for m in prompt.messages], stream=True, **settings
|
84 |
-
):
|
85 |
-
token = stream_resp.choices[0].delta.content
|
86 |
-
if not token:
|
87 |
-
token = ""
|
88 |
-
await msg.stream_token(token)
|
89 |
-
|
90 |
-
# Update the prompt object with the completion
|
91 |
-
prompt.completion = msg.content
|
92 |
-
msg.prompt = prompt
|
93 |
|
94 |
-
|
95 |
await msg.send()
|
|
|
|
|
1 |
# AI MAKERSPACE MIDTERM PROJECT: META RAG CHATBOT
|
2 |
# Date: 2024-5-2
|
3 |
# Authors: MikeC
|
4 |
|
5 |
# Basic Imports & Setup
|
6 |
import os
|
|
|
7 |
from openai import AsyncOpenAI
|
8 |
|
9 |
# Using Chainlit for our UI
|
10 |
+
import chainlit as cl
|
11 |
+
from chainlit.prompt import Prompt, PromptMessage
|
12 |
+
from chainlit.playground.providers import ChatOpenAI
|
13 |
|
14 |
# Getting the API key from the .env file
|
15 |
from dotenv import load_dotenv
|
16 |
load_dotenv()
|
17 |
|
18 |
+
# RAG pipeline imports and setup code
|
19 |
+
from langchain.document_loaders import PyMuPDFLoader
|
20 |
+
docs = PyMuPDFLoader("https://d18rn0p25nwr6d.cloudfront.net/CIK-0001326801/c7318154-f6ae-4866-89fa-f0c589f2ee3d.pdf").load()
|
21 |
|
22 |
+
import tiktoken
|
23 |
+
def tiktoken_len(text):
    """Return how many gpt-3.5-turbo tokens ``text`` encodes to.

    Passed to the text splitter as its ``length_function`` so that chunk
    sizes are measured in model tokens instead of characters.
    """
    encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
    return len(encoder.encode(text))
|
28 |
|
29 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
30 |
|
31 |
+
text_splitter = RecursiveCharacterTextSplitter(
|
32 |
+
chunk_size = 200,
|
33 |
+
chunk_overlap = 0,
|
34 |
+
length_function = tiktoken_len,
|
35 |
+
)
|
36 |
|
37 |
+
split_chunks = text_splitter.split_documents(docs)
|
38 |
|
39 |
+
from langchain_openai.embeddings import OpenAIEmbeddings
|
40 |
+
|
41 |
+
embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
|
42 |
+
|
43 |
+
from langchain_community.vectorstores import Qdrant
|
44 |
+
|
45 |
+
qdrant_vectorstore = Qdrant.from_documents(
|
46 |
+
split_chunks,
|
47 |
+
embedding_model,
|
48 |
+
location=":memory:",
|
49 |
+
collection_name="MetaFin",
|
50 |
+
)
|
51 |
|
52 |
+
qdrant_retriever = qdrant_vectorstore.as_retriever()
|
53 |
+
|
54 |
+
from langchain_openai import ChatOpenAI
|
55 |
+
openai_chat_model = ChatOpenAI(model="gpt-3.5-turbo")
|
56 |
+
|
57 |
+
from langchain_core.prompts import ChatPromptTemplate
|
58 |
+
RAG_PROMPT = """
|
59 |
+
CONTEXT:
|
60 |
+
{context}
|
61 |
+
|
62 |
+
QUERY:
|
63 |
+
{question}
|
64 |
+
|
65 |
+
Use the provided context to answer the user's query. You are a professional financial expert. You always review the provided financial information. You provide correct, substantiated answers. You may not answer the user's query unless there is a specific context in the following text. If asked about the Board of Directors, then add Mark Zuckerberg as the "Board Chair".
|
66 |
+
If you do not know the answer, or cannot answer, please respond with "Insufficient data for further analysis, please try again". >>
|
67 |
"""
|
68 |
|
69 |
+
rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
|
70 |
+
|
71 |
+
from operator import itemgetter
|
72 |
+
from langchain.schema.output_parser import StrOutputParser
|
73 |
+
from langchain.schema.runnable import RunnablePassthrough
|
74 |
+
|
75 |
+
# LCEL pipeline invoked with {"question": <str>}; it yields a dict with
# "response" (the chat model's message) and "context" (the retriever output).
retrieval_augmented_qa_chain = (
    # Stage 1: look up context for the question and keep the question itself.
    {
        "context": itemgetter("question") | qdrant_retriever,
        "question": itemgetter("question"),
    }
    # Stage 2: re-publishes "context" unchanged. Keeping a Runnable in this
    # position also means the dict literals on either side are composed as
    # chain steps by `|` rather than merged as plain Python dicts.
    | RunnablePassthrough.assign(context=itemgetter("context"))
    # Stage 3: fill the prompt and call the model, passing the context through
    # alongside the answer.
    | {
        "response": rag_prompt | openai_chat_model,
        "context": itemgetter("context"),
    }
)
|
80 |
+
|
81 |
+
# Chainlit App
|
82 |
+
@cl.on_chat_start
|
83 |
async def start_chat():
|
84 |
settings = {
|
85 |
"model": "gpt-3.5-turbo",
|
|
|
89 |
"frequency_penalty": 0,
|
90 |
"presence_penalty": 0,
|
91 |
}
|
|
|
92 |
cl.user_session.set("settings", settings)
|
93 |
|
94 |
+
@cl.on_message
async def main(message: cl.Message):
    """Answer one user message with the RAG chain and send the reply to the UI.

    Args:
        message: Incoming Chainlit message; its ``content`` is used as the
            question handed to the chain.
    """
    # The chain's ``invoke`` is synchronous (retrieval followed by the model
    # call). Calling it directly inside this coroutine would block the event
    # loop for the whole request, so it is run in a worker thread through
    # Chainlit's ``make_async`` wrapper and awaited instead.
    run_chain = cl.make_async(retrieval_augmented_qa_chain.invoke)
    result = await run_chain({"question": message.content})

    # "response" holds the chat model's message object; ``.content`` is its text.
    answer_text = result["response"].content

    await cl.Message(content=answer_text).send()
|
requirements.txt
CHANGED
@@ -1,5 +1,206 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
cohere==4.37
|
3 |
-
|
4 |
-
|
5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
aiofiles==23.2.1
|
2 |
+
aiohttp==3.9.5
|
3 |
+
aiosignal==1.3.1
|
4 |
+
altair==5.3.0
|
5 |
+
annotated-types==0.6.0
|
6 |
+
anyio==3.7.1
|
7 |
+
appdirs==1.4.4
|
8 |
+
argon2-cffi==23.1.0
|
9 |
+
argon2-cffi-bindings==21.2.0
|
10 |
+
arrow==1.3.0
|
11 |
+
async-lru==2.0.4
|
12 |
+
asyncer==0.0.2
|
13 |
+
attrs==23.2.0
|
14 |
+
Babel==2.14.0
|
15 |
+
backoff==2.2.1
|
16 |
+
beautifulsoup4==4.12.3
|
17 |
+
bidict==0.23.1
|
18 |
+
bleach==6.1.0
|
19 |
+
blinker==1.8.1
|
20 |
+
cachetools==5.3.3
|
21 |
+
certifi==2024.2.2
|
22 |
+
cffi==1.16.0
|
23 |
+
chainlit
|
24 |
+
charset-normalizer==3.3.2
|
25 |
+
click==8.1.7
|
26 |
cohere==4.37
|
27 |
+
contourpy==1.2.1
|
28 |
+
curl_cffi==0.6.2
|
29 |
+
cycler==0.12.1
|
30 |
+
dataclasses-json==0.5.14
|
31 |
+
datasets==2.19.0
|
32 |
+
defusedxml==0.7.1
|
33 |
+
Deprecated==1.2.14
|
34 |
+
dill==0.3.8
|
35 |
+
dirtyjson==1.0.8
|
36 |
+
distro==1.9.0
|
37 |
+
docker==7.0.0
|
38 |
+
docker-pycreds==0.4.0
|
39 |
+
duckduckgo_search==5.3.0
|
40 |
+
fastapi==0.100.1
|
41 |
+
fastapi-socketio==0.0.10
|
42 |
+
fastavro==1.9.4
|
43 |
+
fastjsonschema==2.19.1
|
44 |
+
filelock==3.13.4
|
45 |
+
filetype==1.2.0
|
46 |
+
fonttools==4.51.0
|
47 |
+
fqdn==1.5.1
|
48 |
+
frozenlist==1.4.1
|
49 |
+
fsspec==2024.3.1
|
50 |
+
gitdb==4.0.11
|
51 |
+
GitPython==3.1.43
|
52 |
+
googleapis-common-protos==1.63.0
|
53 |
+
grandalf==0.8
|
54 |
+
greenlet==3.0.3
|
55 |
+
grpcio==1.62.2
|
56 |
+
grpcio-tools==1.62.2
|
57 |
+
h11==0.14.0
|
58 |
+
h2==4.1.0
|
59 |
+
hpack==4.0.0
|
60 |
+
httpcore==0.17.3
|
61 |
+
httpx
|
62 |
+
huggingface-hub==0.22.2
|
63 |
+
hyperframe==6.0.1
|
64 |
+
idna==3.6
|
65 |
+
importlib-metadata==6.11.0
|
66 |
+
install==1.3.5
|
67 |
+
ipywidgets==8.1.2
|
68 |
+
isoduration==20.11.0
|
69 |
+
Jinja2==3.1.3
|
70 |
+
joblib==1.4.0
|
71 |
+
json5==0.9.25
|
72 |
+
jsonpatch==1.33
|
73 |
+
jsonpointer==2.4
|
74 |
+
jsonschema==4.21.1
|
75 |
+
jsonschema-specifications==2023.12.1
|
76 |
+
jupyter==1.0.0
|
77 |
+
jupyter-console==6.6.3
|
78 |
+
jupyter-events==0.10.0
|
79 |
+
jupyter-lsp==2.2.5
|
80 |
+
jupyter_server==2.14.0
|
81 |
+
jupyter_server_terminals==0.5.3
|
82 |
+
jupyterlab
|
83 |
+
jupyterlab_pygments==0.3.0
|
84 |
+
jupyterlab_server==2.26.0
|
85 |
+
jupyterlab_widgets==3.0.10
|
86 |
+
kiwisolver==1.4.5
|
87 |
+
langchain==0.1.17
|
88 |
+
langchain-community==0.0.36
|
89 |
+
langchain-core==0.1.50
|
90 |
+
langchain-openai==0.1.6
|
91 |
+
langchain-text-splitters==0.0.1
|
92 |
+
langchainhub==0.1.15
|
93 |
+
langsmith==0.1.48
|
94 |
+
Lazify==0.4.0
|
95 |
+
markdown-it-py==3.0.0
|
96 |
+
MarkupSafe==2.1.5
|
97 |
+
marshmallow==3.21.1
|
98 |
+
matplotlib==3.8.4
|
99 |
+
mdurl==0.1.2
|
100 |
+
mistune==3.0.2
|
101 |
+
multidict==6.0.5
|
102 |
+
multiprocess==0.70.16
|
103 |
+
mypy-extensions==1.0.0
|
104 |
+
nbclient==0.10.0
|
105 |
+
nbconvert==7.16.3
|
106 |
+
nbformat==5.10.4
|
107 |
+
networkx
|
108 |
+
nltk==3.8.1
|
109 |
+
notebook==7.1.2
|
110 |
+
notebook_shim==0.2.4
|
111 |
+
numpy==1.26.4
|
112 |
+
openai==1.25.1
|
113 |
+
opentelemetry-api==1.24.0
|
114 |
+
opentelemetry-exporter-otlp==1.24.0
|
115 |
+
opentelemetry-exporter-otlp-proto-common==1.24.0
|
116 |
+
opentelemetry-exporter-otlp-proto-grpc==1.24.0
|
117 |
+
opentelemetry-exporter-otlp-proto-http==1.24.0
|
118 |
+
opentelemetry-instrumentation==0.45b0
|
119 |
+
opentelemetry-proto==1.24.0
|
120 |
+
opentelemetry-sdk==1.24.0
|
121 |
+
opentelemetry-semantic-conventions==0.45b0
|
122 |
+
orjson==3.10.1
|
123 |
+
overrides==7.7.0
|
124 |
+
packaging==23.2
|
125 |
+
pandas==2.2.2
|
126 |
+
pandocfilters==1.5.1
|
127 |
+
pillow==10.3.0
|
128 |
+
plotly==5.22.0
|
129 |
+
portalocker==2.8.2
|
130 |
+
prometheus_client==0.20.0
|
131 |
+
protobuf==4.25.3
|
132 |
+
pyarrow==16.0.0
|
133 |
+
pyarrow-hotfix==0.6
|
134 |
+
pycparser==2.22
|
135 |
+
pydantic==2.6.4
|
136 |
+
pydantic_core==2.16.3
|
137 |
+
pydeck==0.9.0
|
138 |
+
PyJWT==2.8.0
|
139 |
+
PyMuPDF==1.24.2
|
140 |
+
PyMuPDFb==1.24.1
|
141 |
+
pyparsing==3.1.2
|
142 |
+
pypdf==4.2.0
|
143 |
+
pysbd==0.3.4
|
144 |
+
python-dotenv==1.0.0
|
145 |
+
python-engineio==4.9.0
|
146 |
+
python-graphql-client==0.4.3
|
147 |
+
python-json-logger==2.0.7
|
148 |
+
python-magic==0.4.27
|
149 |
+
python-multipart==0.0.6
|
150 |
+
python-socketio==5.11.2
|
151 |
+
pytz==2024.1
|
152 |
+
PyYAML==6.0.1
|
153 |
+
qdrant-client==1.9.1
|
154 |
+
qtconsole==5.5.1
|
155 |
+
QtPy==2.4.1
|
156 |
+
ragas==0.1.7
|
157 |
+
referencing==0.34.0
|
158 |
+
regex==2024.4.16
|
159 |
+
requests==2.31.0
|
160 |
+
rfc3339-validator==0.1.4
|
161 |
+
rfc3986-validator==0.1.1
|
162 |
+
rich==13.7.1
|
163 |
+
rpds-py==0.18.0
|
164 |
+
scikit-learn==1.4.2
|
165 |
+
scipy==1.13.0
|
166 |
+
Send2Trash==1.8.3
|
167 |
+
sentry-sdk==1.45.0
|
168 |
+
setproctitle==1.3.3
|
169 |
+
simple-websocket==1.0.0
|
170 |
+
smmap==5.0.1
|
171 |
+
sniffio==1.3.1
|
172 |
+
soupsieve==2.5
|
173 |
+
SQLAlchemy==2.0.29
|
174 |
+
starlette==0.27.0
|
175 |
+
streamlit==1.33.0
|
176 |
+
striprtf==0.0.26
|
177 |
+
syncer==2.0.3
|
178 |
+
tenacity==8.2.3
|
179 |
+
terminado==0.18.1
|
180 |
+
threadpoolctl==3.4.0
|
181 |
+
tiktoken==0.6.0
|
182 |
+
tinycss2==1.2.1
|
183 |
+
toml==0.10.2
|
184 |
+
tomli==2.0.1
|
185 |
+
toolz==0.12.1
|
186 |
+
tqdm==4.66.2
|
187 |
+
types-python-dateutil==2.9.0.20240316
|
188 |
+
types-requests==2.31.0.20240406
|
189 |
+
typing-inspect==0.9.0
|
190 |
+
tzdata==2024.1
|
191 |
+
uptrace==1.24.0
|
192 |
+
uri-template==1.3.0
|
193 |
+
urllib3==2.2.1
|
194 |
+
uvicorn==0.23.2
|
195 |
+
wandb==0.16.6
|
196 |
+
watchfiles==0.20.0
|
197 |
+
webcolors==1.13
|
198 |
+
webencodings==0.5.1
|
199 |
+
websocket-client==1.7.0
|
200 |
+
websockets==12.0
|
201 |
+
widgetsnbextension==4.0.10
|
202 |
+
wikipedia==1.4.0
|
203 |
+
wrapt==1.16.0
|
204 |
+
wsproto==1.2.0
|
205 |
+
xxhash==3.4.1
|
206 |
+
yarl==1.9.4
|