Merge pull request #27 from DL4DS/code_restructure
Browse filesThis view is limited to 50 files because it contains too many changes. Β
See raw diff
- .chainlit/translations/en-US.json +0 -231
- .chainlit/translations/pt-BR.json +0 -155
- .gitignore +9 -1
- Dockerfile.dev +6 -2
- README.md +71 -22
- {.chainlit β code/.chainlit}/config.toml +16 -22
- code/__init__.py +1 -0
- chainlit.md β code/chainlit.md +0 -0
- code/main.py +87 -39
- code/modules/chat/__init__.py +0 -0
- code/modules/{chat_model_loader.py β chat/chat_model_loader.py} +2 -3
- code/modules/chat/helpers.py +104 -0
- code/modules/{llm_tutor.py β chat/llm_tutor.py} +94 -60
- code/modules/chat_processor/__init__.py +0 -0
- code/modules/chat_processor/base.py +12 -0
- code/modules/chat_processor/chat_processor.py +29 -0
- code/modules/chat_processor/literal_ai.py +37 -0
- code/modules/config/__init__.py +0 -0
- code/{config.yml β modules/config/config.yml} +27 -8
- code/modules/{constants.py β config/constants.py} +2 -1
- code/modules/dataloader/__init__.py +0 -0
- code/modules/{data_loader.py β dataloader/data_loader.py} +184 -117
- code/modules/dataloader/helpers.py +108 -0
- code/modules/{helpers.py β dataloader/webpage_crawler.py} +3 -225
- code/modules/retriever/__init__.py +0 -0
- code/modules/retriever/base.py +12 -0
- code/modules/retriever/chroma_retriever.py +24 -0
- code/modules/retriever/colbert_retriever.py +10 -0
- code/modules/retriever/faiss_retriever.py +23 -0
- code/modules/retriever/helpers.py +39 -0
- code/modules/retriever/raptor_retriever.py +16 -0
- code/modules/retriever/retriever.py +26 -0
- code/modules/vector_db.py +0 -226
- code/modules/vectorstore/__init__.py +0 -0
- code/modules/vectorstore/base.py +33 -0
- code/modules/vectorstore/chroma.py +41 -0
- code/modules/vectorstore/colbert.py +39 -0
- code/modules/{embedding_model_loader.py β vectorstore/embedding_model_loader.py} +5 -8
- code/modules/vectorstore/faiss.py +45 -0
- code/modules/vectorstore/helpers.py +0 -0
- code/modules/vectorstore/raptor.py +438 -0
- code/modules/vectorstore/store_manager.py +163 -0
- code/modules/vectorstore/vectorstore.py +57 -0
- code/public/acastusphoton-svgrepo-com.svg +2 -0
- code/public/adv-screen-recorder-svgrepo-com.svg +2 -0
- code/public/alarmy-svgrepo-com.svg +2 -0
- public/logo_dark.png β code/public/avatars/ai-tutor.png +0 -0
- code/public/calendar-samsung-17-svgrepo-com.svg +36 -0
- public/logo_light.png β code/public/logo_dark.png +0 -0
- code/public/logo_light.png +0 -0
.chainlit/translations/en-US.json
DELETED
@@ -1,231 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"components": {
|
3 |
-
"atoms": {
|
4 |
-
"buttons": {
|
5 |
-
"userButton": {
|
6 |
-
"menu": {
|
7 |
-
"settings": "Settings",
|
8 |
-
"settingsKey": "S",
|
9 |
-
"APIKeys": "API Keys",
|
10 |
-
"logout": "Logout"
|
11 |
-
}
|
12 |
-
}
|
13 |
-
}
|
14 |
-
},
|
15 |
-
"molecules": {
|
16 |
-
"newChatButton": {
|
17 |
-
"newChat": "New Chat"
|
18 |
-
},
|
19 |
-
"tasklist": {
|
20 |
-
"TaskList": {
|
21 |
-
"title": "\ud83d\uddd2\ufe0f Task List",
|
22 |
-
"loading": "Loading...",
|
23 |
-
"error": "An error occured"
|
24 |
-
}
|
25 |
-
},
|
26 |
-
"attachments": {
|
27 |
-
"cancelUpload": "Cancel upload",
|
28 |
-
"removeAttachment": "Remove attachment"
|
29 |
-
},
|
30 |
-
"newChatDialog": {
|
31 |
-
"createNewChat": "Create new chat?",
|
32 |
-
"clearChat": "This will clear the current messages and start a new chat.",
|
33 |
-
"cancel": "Cancel",
|
34 |
-
"confirm": "Confirm"
|
35 |
-
},
|
36 |
-
"settingsModal": {
|
37 |
-
"settings": "Settings",
|
38 |
-
"expandMessages": "Expand Messages",
|
39 |
-
"hideChainOfThought": "Hide Chain of Thought",
|
40 |
-
"darkMode": "Dark Mode"
|
41 |
-
},
|
42 |
-
"detailsButton": {
|
43 |
-
"using": "Using",
|
44 |
-
"running": "Running",
|
45 |
-
"took_one": "Took {{count}} step",
|
46 |
-
"took_other": "Took {{count}} steps"
|
47 |
-
},
|
48 |
-
"auth": {
|
49 |
-
"authLogin": {
|
50 |
-
"title": "Login to access the app.",
|
51 |
-
"form": {
|
52 |
-
"email": "Email address",
|
53 |
-
"password": "Password",
|
54 |
-
"noAccount": "Don't have an account?",
|
55 |
-
"alreadyHaveAccount": "Already have an account?",
|
56 |
-
"signup": "Sign Up",
|
57 |
-
"signin": "Sign In",
|
58 |
-
"or": "OR",
|
59 |
-
"continue": "Continue",
|
60 |
-
"forgotPassword": "Forgot password?",
|
61 |
-
"passwordMustContain": "Your password must contain:",
|
62 |
-
"emailRequired": "email is a required field",
|
63 |
-
"passwordRequired": "password is a required field"
|
64 |
-
},
|
65 |
-
"error": {
|
66 |
-
"default": "Unable to sign in.",
|
67 |
-
"signin": "Try signing in with a different account.",
|
68 |
-
"oauthsignin": "Try signing in with a different account.",
|
69 |
-
"redirect_uri_mismatch": "The redirect URI is not matching the oauth app configuration.",
|
70 |
-
"oauthcallbackerror": "Try signing in with a different account.",
|
71 |
-
"oauthcreateaccount": "Try signing in with a different account.",
|
72 |
-
"emailcreateaccount": "Try signing in with a different account.",
|
73 |
-
"callback": "Try signing in with a different account.",
|
74 |
-
"oauthaccountnotlinked": "To confirm your identity, sign in with the same account you used originally.",
|
75 |
-
"emailsignin": "The e-mail could not be sent.",
|
76 |
-
"emailverify": "Please verify your email, a new email has been sent.",
|
77 |
-
"credentialssignin": "Sign in failed. Check the details you provided are correct.",
|
78 |
-
"sessionrequired": "Please sign in to access this page."
|
79 |
-
}
|
80 |
-
},
|
81 |
-
"authVerifyEmail": {
|
82 |
-
"almostThere": "You're almost there! We've sent an email to ",
|
83 |
-
"verifyEmailLink": "Please click on the link in that email to complete your signup.",
|
84 |
-
"didNotReceive": "Can't find the email?",
|
85 |
-
"resendEmail": "Resend email",
|
86 |
-
"goBack": "Go Back",
|
87 |
-
"emailSent": "Email sent successfully.",
|
88 |
-
"verifyEmail": "Verify your email address"
|
89 |
-
},
|
90 |
-
"providerButton": {
|
91 |
-
"continue": "Continue with {{provider}}",
|
92 |
-
"signup": "Sign up with {{provider}}"
|
93 |
-
},
|
94 |
-
"authResetPassword": {
|
95 |
-
"newPasswordRequired": "New password is a required field",
|
96 |
-
"passwordsMustMatch": "Passwords must match",
|
97 |
-
"confirmPasswordRequired": "Confirm password is a required field",
|
98 |
-
"newPassword": "New password",
|
99 |
-
"confirmPassword": "Confirm password",
|
100 |
-
"resetPassword": "Reset Password"
|
101 |
-
},
|
102 |
-
"authForgotPassword": {
|
103 |
-
"email": "Email address",
|
104 |
-
"emailRequired": "email is a required field",
|
105 |
-
"emailSent": "Please check the email address {{email}} for instructions to reset your password.",
|
106 |
-
"enterEmail": "Enter your email address and we will send you instructions to reset your password.",
|
107 |
-
"resendEmail": "Resend email",
|
108 |
-
"continue": "Continue",
|
109 |
-
"goBack": "Go Back"
|
110 |
-
}
|
111 |
-
}
|
112 |
-
},
|
113 |
-
"organisms": {
|
114 |
-
"chat": {
|
115 |
-
"history": {
|
116 |
-
"index": {
|
117 |
-
"showHistory": "Show history",
|
118 |
-
"lastInputs": "Last Inputs",
|
119 |
-
"noInputs": "Such empty...",
|
120 |
-
"loading": "Loading..."
|
121 |
-
}
|
122 |
-
},
|
123 |
-
"inputBox": {
|
124 |
-
"input": {
|
125 |
-
"placeholder": "Type your message here..."
|
126 |
-
},
|
127 |
-
"speechButton": {
|
128 |
-
"start": "Start recording",
|
129 |
-
"stop": "Stop recording"
|
130 |
-
},
|
131 |
-
"SubmitButton": {
|
132 |
-
"sendMessage": "Send message",
|
133 |
-
"stopTask": "Stop Task"
|
134 |
-
},
|
135 |
-
"UploadButton": {
|
136 |
-
"attachFiles": "Attach files"
|
137 |
-
},
|
138 |
-
"waterMark": {
|
139 |
-
"text": "Built with"
|
140 |
-
}
|
141 |
-
},
|
142 |
-
"Messages": {
|
143 |
-
"index": {
|
144 |
-
"running": "Running",
|
145 |
-
"executedSuccessfully": "executed successfully",
|
146 |
-
"failed": "failed",
|
147 |
-
"feedbackUpdated": "Feedback updated",
|
148 |
-
"updating": "Updating"
|
149 |
-
}
|
150 |
-
},
|
151 |
-
"dropScreen": {
|
152 |
-
"dropYourFilesHere": "Drop your files here"
|
153 |
-
},
|
154 |
-
"index": {
|
155 |
-
"failedToUpload": "Failed to upload",
|
156 |
-
"cancelledUploadOf": "Cancelled upload of",
|
157 |
-
"couldNotReachServer": "Could not reach the server",
|
158 |
-
"continuingChat": "Continuing previous chat"
|
159 |
-
},
|
160 |
-
"settings": {
|
161 |
-
"settingsPanel": "Settings panel",
|
162 |
-
"reset": "Reset",
|
163 |
-
"cancel": "Cancel",
|
164 |
-
"confirm": "Confirm"
|
165 |
-
}
|
166 |
-
},
|
167 |
-
"threadHistory": {
|
168 |
-
"sidebar": {
|
169 |
-
"filters": {
|
170 |
-
"FeedbackSelect": {
|
171 |
-
"feedbackAll": "Feedback: All",
|
172 |
-
"feedbackPositive": "Feedback: Positive",
|
173 |
-
"feedbackNegative": "Feedback: Negative"
|
174 |
-
},
|
175 |
-
"SearchBar": {
|
176 |
-
"search": "Search"
|
177 |
-
}
|
178 |
-
},
|
179 |
-
"DeleteThreadButton": {
|
180 |
-
"confirmMessage": "This will delete the thread as well as it's messages and elements.",
|
181 |
-
"cancel": "Cancel",
|
182 |
-
"confirm": "Confirm",
|
183 |
-
"deletingChat": "Deleting chat",
|
184 |
-
"chatDeleted": "Chat deleted"
|
185 |
-
},
|
186 |
-
"index": {
|
187 |
-
"pastChats": "Past Chats"
|
188 |
-
},
|
189 |
-
"ThreadList": {
|
190 |
-
"empty": "Empty...",
|
191 |
-
"today": "Today",
|
192 |
-
"yesterday": "Yesterday",
|
193 |
-
"previous7days": "Previous 7 days",
|
194 |
-
"previous30days": "Previous 30 days"
|
195 |
-
},
|
196 |
-
"TriggerButton": {
|
197 |
-
"closeSidebar": "Close sidebar",
|
198 |
-
"openSidebar": "Open sidebar"
|
199 |
-
}
|
200 |
-
},
|
201 |
-
"Thread": {
|
202 |
-
"backToChat": "Go back to chat",
|
203 |
-
"chatCreatedOn": "This chat was created on"
|
204 |
-
}
|
205 |
-
},
|
206 |
-
"header": {
|
207 |
-
"chat": "Chat",
|
208 |
-
"readme": "Readme"
|
209 |
-
}
|
210 |
-
}
|
211 |
-
},
|
212 |
-
"hooks": {
|
213 |
-
"useLLMProviders": {
|
214 |
-
"failedToFetchProviders": "Failed to fetch providers:"
|
215 |
-
}
|
216 |
-
},
|
217 |
-
"pages": {
|
218 |
-
"Design": {},
|
219 |
-
"Env": {
|
220 |
-
"savedSuccessfully": "Saved successfully",
|
221 |
-
"requiredApiKeys": "Required API Keys",
|
222 |
-
"requiredApiKeysInfo": "To use this app, the following API keys are required. The keys are stored on your device's local storage."
|
223 |
-
},
|
224 |
-
"Page": {
|
225 |
-
"notPartOfProject": "You are not part of this project."
|
226 |
-
},
|
227 |
-
"ResumeButton": {
|
228 |
-
"resumeChat": "Resume Chat"
|
229 |
-
}
|
230 |
-
}
|
231 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.chainlit/translations/pt-BR.json
DELETED
@@ -1,155 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"components": {
|
3 |
-
"atoms": {
|
4 |
-
"buttons": {
|
5 |
-
"userButton": {
|
6 |
-
"menu": {
|
7 |
-
"settings": "Configura\u00e7\u00f5es",
|
8 |
-
"settingsKey": "S",
|
9 |
-
"APIKeys": "Chaves de API",
|
10 |
-
"logout": "Sair"
|
11 |
-
}
|
12 |
-
}
|
13 |
-
}
|
14 |
-
},
|
15 |
-
"molecules": {
|
16 |
-
"newChatButton": {
|
17 |
-
"newChat": "Nova Conversa"
|
18 |
-
},
|
19 |
-
"tasklist": {
|
20 |
-
"TaskList": {
|
21 |
-
"title": "\ud83d\uddd2\ufe0f Lista de Tarefas",
|
22 |
-
"loading": "Carregando...",
|
23 |
-
"error": "Ocorreu um erro"
|
24 |
-
}
|
25 |
-
},
|
26 |
-
"attachments": {
|
27 |
-
"cancelUpload": "Cancelar envio",
|
28 |
-
"removeAttachment": "Remover anexo"
|
29 |
-
},
|
30 |
-
"newChatDialog": {
|
31 |
-
"createNewChat": "Criar novo chat?",
|
32 |
-
"clearChat": "Isso limpar\u00e1 as mensagens atuais e iniciar\u00e1 uma nova conversa.",
|
33 |
-
"cancel": "Cancelar",
|
34 |
-
"confirm": "Confirmar"
|
35 |
-
},
|
36 |
-
"settingsModal": {
|
37 |
-
"expandMessages": "Expandir Mensagens",
|
38 |
-
"hideChainOfThought": "Esconder Sequ\u00eancia de Pensamento",
|
39 |
-
"darkMode": "Modo Escuro"
|
40 |
-
}
|
41 |
-
},
|
42 |
-
"organisms": {
|
43 |
-
"chat": {
|
44 |
-
"history": {
|
45 |
-
"index": {
|
46 |
-
"lastInputs": "\u00daltimas Entradas",
|
47 |
-
"noInputs": "Vazio...",
|
48 |
-
"loading": "Carregando..."
|
49 |
-
}
|
50 |
-
},
|
51 |
-
"inputBox": {
|
52 |
-
"input": {
|
53 |
-
"placeholder": "Digite sua mensagem aqui..."
|
54 |
-
},
|
55 |
-
"speechButton": {
|
56 |
-
"start": "Iniciar grava\u00e7\u00e3o",
|
57 |
-
"stop": "Parar grava\u00e7\u00e3o"
|
58 |
-
},
|
59 |
-
"SubmitButton": {
|
60 |
-
"sendMessage": "Enviar mensagem",
|
61 |
-
"stopTask": "Parar Tarefa"
|
62 |
-
},
|
63 |
-
"UploadButton": {
|
64 |
-
"attachFiles": "Anexar arquivos"
|
65 |
-
},
|
66 |
-
"waterMark": {
|
67 |
-
"text": "Constru\u00eddo com"
|
68 |
-
}
|
69 |
-
},
|
70 |
-
"Messages": {
|
71 |
-
"index": {
|
72 |
-
"running": "Executando",
|
73 |
-
"executedSuccessfully": "executado com sucesso",
|
74 |
-
"failed": "falhou",
|
75 |
-
"feedbackUpdated": "Feedback atualizado",
|
76 |
-
"updating": "Atualizando"
|
77 |
-
}
|
78 |
-
},
|
79 |
-
"dropScreen": {
|
80 |
-
"dropYourFilesHere": "Solte seus arquivos aqui"
|
81 |
-
},
|
82 |
-
"index": {
|
83 |
-
"failedToUpload": "Falha ao enviar",
|
84 |
-
"cancelledUploadOf": "Envio cancelado de",
|
85 |
-
"couldNotReachServer": "N\u00e3o foi poss\u00edvel conectar ao servidor",
|
86 |
-
"continuingChat": "Continuando o chat anterior"
|
87 |
-
},
|
88 |
-
"settings": {
|
89 |
-
"settingsPanel": "Painel de Configura\u00e7\u00f5es",
|
90 |
-
"reset": "Redefinir",
|
91 |
-
"cancel": "Cancelar",
|
92 |
-
"confirm": "Confirmar"
|
93 |
-
}
|
94 |
-
},
|
95 |
-
"threadHistory": {
|
96 |
-
"sidebar": {
|
97 |
-
"filters": {
|
98 |
-
"FeedbackSelect": {
|
99 |
-
"feedbackAll": "Feedback: Todos",
|
100 |
-
"feedbackPositive": "Feedback: Positivo",
|
101 |
-
"feedbackNegative": "Feedback: Negativo"
|
102 |
-
},
|
103 |
-
"SearchBar": {
|
104 |
-
"search": "Buscar"
|
105 |
-
}
|
106 |
-
},
|
107 |
-
"DeleteThreadButton": {
|
108 |
-
"confirmMessage": "Isso deletar\u00e1 a conversa, assim como suas mensagens e elementos.",
|
109 |
-
"cancel": "Cancelar",
|
110 |
-
"confirm": "Confirmar",
|
111 |
-
"deletingChat": "Deletando conversa",
|
112 |
-
"chatDeleted": "Conversa deletada"
|
113 |
-
},
|
114 |
-
"index": {
|
115 |
-
"pastChats": "Conversas Anteriores"
|
116 |
-
},
|
117 |
-
"ThreadList": {
|
118 |
-
"empty": "Vazio..."
|
119 |
-
},
|
120 |
-
"TriggerButton": {
|
121 |
-
"closeSidebar": "Fechar barra lateral",
|
122 |
-
"openSidebar": "Abrir barra lateral"
|
123 |
-
}
|
124 |
-
},
|
125 |
-
"Thread": {
|
126 |
-
"backToChat": "Voltar para a conversa",
|
127 |
-
"chatCreatedOn": "Esta conversa foi criada em"
|
128 |
-
}
|
129 |
-
},
|
130 |
-
"header": {
|
131 |
-
"chat": "Conversa",
|
132 |
-
"readme": "Leia-me"
|
133 |
-
}
|
134 |
-
},
|
135 |
-
"hooks": {
|
136 |
-
"useLLMProviders": {
|
137 |
-
"failedToFetchProviders": "Falha ao buscar provedores:"
|
138 |
-
}
|
139 |
-
},
|
140 |
-
"pages": {
|
141 |
-
"Design": {},
|
142 |
-
"Env": {
|
143 |
-
"savedSuccessfully": "Salvo com sucesso",
|
144 |
-
"requiredApiKeys": "Chaves de API necess\u00e1rias",
|
145 |
-
"requiredApiKeysInfo": "Para usar este aplicativo, as seguintes chaves de API s\u00e3o necess\u00e1rias. As chaves s\u00e3o armazenadas localmente em seu dispositivo."
|
146 |
-
},
|
147 |
-
"Page": {
|
148 |
-
"notPartOfProject": "Voc\u00ea n\u00e3o faz parte deste projeto."
|
149 |
-
},
|
150 |
-
"ResumeButton": {
|
151 |
-
"resumeChat": "Continuar Conversa"
|
152 |
-
}
|
153 |
-
}
|
154 |
-
}
|
155 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitignore
CHANGED
@@ -160,4 +160,12 @@ cython_debug/
|
|
160 |
#.idea/
|
161 |
|
162 |
# log files
|
163 |
-
*.log
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
#.idea/
|
161 |
|
162 |
# log files
|
163 |
+
*.log
|
164 |
+
|
165 |
+
.ragatouille/*
|
166 |
+
*/__pycache__/*
|
167 |
+
*/.chainlit/translations/*
|
168 |
+
storage/logs/*
|
169 |
+
vectorstores/*
|
170 |
+
|
171 |
+
*/.files/*
|
Dockerfile.dev
CHANGED
@@ -10,7 +10,8 @@ RUN pip install --no-cache-dir -r /code/requirements.txt
|
|
10 |
|
11 |
COPY . /code
|
12 |
|
13 |
-
|
|
|
14 |
|
15 |
# Change permissions to allow writing to the directory
|
16 |
RUN chmod -R 777 /code
|
@@ -21,7 +22,10 @@ RUN mkdir /code/logs && chmod 777 /code/logs
|
|
21 |
# Create a cache directory within the application's working directory
|
22 |
RUN mkdir /.cache && chmod -R 777 /.cache
|
23 |
|
|
|
|
|
24 |
# Expose the port the app runs on
|
25 |
EXPOSE 8051
|
26 |
|
27 |
-
|
|
|
|
10 |
|
11 |
COPY . /code
|
12 |
|
13 |
+
# List the contents of the /code directory to verify files are copied correctly
|
14 |
+
RUN ls -R /code
|
15 |
|
16 |
# Change permissions to allow writing to the directory
|
17 |
RUN chmod -R 777 /code
|
|
|
22 |
# Create a cache directory within the application's working directory
|
23 |
RUN mkdir /.cache && chmod -R 777 /.cache
|
24 |
|
25 |
+
WORKDIR /code/code
|
26 |
+
|
27 |
# Expose the port the app runs on
|
28 |
EXPOSE 8051
|
29 |
|
30 |
+
# Default command to run the application
|
31 |
+
CMD ["sh", "-c", "python -m modules.vectorstore.store_manager && chainlit run main.py --host 0.0.0.0 --port 8051"]
|
README.md
CHANGED
@@ -1,35 +1,84 @@
|
|
1 |
-
|
2 |
-
title: Dl4ds Tutor
|
3 |
-
emoji: π
|
4 |
-
colorFrom: green
|
5 |
-
colorTo: red
|
6 |
-
sdk: docker
|
7 |
-
pinned: false
|
8 |
-
hf_oauth: true
|
9 |
-
---
|
10 |
|
11 |
-
|
12 |
-
===========
|
13 |
|
14 |
-
|
15 |
|
16 |
-
|
17 |
|
18 |
-
|
|
|
|
|
|
|
19 |
|
20 |
-
|
|
|
|
|
21 |
|
22 |
-
|
|
|
|
|
|
|
|
|
23 |
|
24 |
-
|
25 |
-
```
|
26 |
-
|
|
|
|
|
|
|
|
|
27 |
|
28 |
-
|
29 |
-
```
|
|
|
|
|
30 |
|
31 |
See the [docs](https://github.com/DL4DS/dl4ds_tutor/tree/main/docs) for more information.
|
32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
## Contributing
|
34 |
|
35 |
-
Please create an issue if you have any suggestions or improvements, and start working on it by creating a branch and by making a pull request to the main branch.
|
|
|
1 |
+
# DL4DS Tutor π
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
+
Check out the configuration reference at [Hugging Face Spaces Config Reference](https://huggingface.co/docs/hub/spaces-config-reference).
|
|
|
4 |
|
5 |
+
You can find an implementation of the Tutor at [DL4DS Tutor on Hugging Face](https://dl4ds-dl4ds-tutor.hf.space/), which is hosted on Hugging Face [here](https://huggingface.co/spaces/dl4ds/dl4ds_tutor).
|
6 |
|
7 |
+
## Running Locally
|
8 |
|
9 |
+
1. **Clone the Repository**
|
10 |
+
```bash
|
11 |
+
git clone https://github.com/DL4DS/dl4ds_tutor
|
12 |
+
```
|
13 |
|
14 |
+
2. **Put your data under the `storage/data` directory**
|
15 |
+
- Add URLs in the `urls.txt` file.
|
16 |
+
- Add other PDF files in the `storage/data` directory.
|
17 |
|
18 |
+
3. **To test Data Loading (Optional)**
|
19 |
+
```bash
|
20 |
+
cd code
|
21 |
+
python -m modules.dataloader.data_loader
|
22 |
+
```
|
23 |
|
24 |
+
4. **Create the Vector Database**
|
25 |
+
```bash
|
26 |
+
cd code
|
27 |
+
python -m modules.vectorstore.store_manager
|
28 |
+
```
|
29 |
+
- Note: You need to run the above command when you add new data to the `storage/data` directory, or if the `storage/data/urls.txt` file is updated.
|
30 |
+
- Alternatively, you can set `["vectorstore"]["embedd_files"]` to `True` in the `code/modules/config/config.yaml` file, which will embed files from the storage directory every time you run the below chainlit command.
|
31 |
|
32 |
+
5. **Run the Chainlit App**
|
33 |
+
```bash
|
34 |
+
chainlit run main.py
|
35 |
+
```
|
36 |
|
37 |
See the [docs](https://github.com/DL4DS/dl4ds_tutor/tree/main/docs) for more information.
|
38 |
|
39 |
+
## File Structure
|
40 |
+
|
41 |
+
```plaintext
|
42 |
+
code/
|
43 |
+
βββ modules
|
44 |
+
β βββ chat # Contains the chatbot implementation
|
45 |
+
β βββ chat_processor # Contains the implementation to process and log the conversations
|
46 |
+
β βββ config # Contains the configuration files
|
47 |
+
β βββ dataloader # Contains the implementation to load the data from the storage directory
|
48 |
+
β βββ retriever # Contains the implementation to create the retriever
|
49 |
+
β βββ vectorstore # Contains the implementation to create the vector database
|
50 |
+
βββ public
|
51 |
+
β βββ logo_dark.png # Dark theme logo
|
52 |
+
β βββ logo_light.png # Light theme logo
|
53 |
+
β βββ test.css # Custom CSS file
|
54 |
+
βββ main.py
|
55 |
+
|
56 |
+
|
57 |
+
docs/ # Contains the documentation to the codebase and methods used
|
58 |
+
|
59 |
+
storage/
|
60 |
+
βββ data # Store files and URLs here
|
61 |
+
βββ logs # Logs directory, includes logs on vector DB creation, tutor logs, and chunks logged in JSON files
|
62 |
+
βββ models # Local LLMs are loaded from here
|
63 |
+
|
64 |
+
vectorstores/ # Stores the created vector databases
|
65 |
+
|
66 |
+
.env # This needs to be created, store the API keys here
|
67 |
+
```
|
68 |
+
- `code/modules/vectorstore/vectorstore.py`: Instantiates the `VectorStore` class to create the vector database.
|
69 |
+
- `code/modules/vectorstore/store_manager.py`: Instantiates the `VectorStoreManager:` class to manage the vector database, and all associated methods.
|
70 |
+
- `code/modules/retriever/retriever.py`: Instantiates the `Retriever` class to create the retriever.
|
71 |
+
|
72 |
+
|
73 |
+
## Docker
|
74 |
+
|
75 |
+
The HuggingFace Space is built using the `Dockerfile` in the repository. To run it locally, use the `Dockerfile.dev` file.
|
76 |
+
|
77 |
+
```bash
|
78 |
+
docker build --tag dev -f Dockerfile.dev .
|
79 |
+
docker run -it --rm -p 8051:8051 dev
|
80 |
+
```
|
81 |
+
|
82 |
## Contributing
|
83 |
|
84 |
+
Please create an issue if you have any suggestions or improvements, and start working on it by creating a branch and by making a pull request to the main branch.
|
{.chainlit β code/.chainlit}/config.toml
RENAMED
@@ -19,9 +19,6 @@ allow_origins = ["*"]
|
|
19 |
# follow_symlink = false
|
20 |
|
21 |
[features]
|
22 |
-
# Show the prompt playground
|
23 |
-
prompt_playground = true
|
24 |
-
|
25 |
# Process and display HTML in messages. This can be a security risk (see https://stackoverflow.com/questions/19603097/why-is-it-dangerous-to-render-user-generated-html-or-javascript)
|
26 |
unsafe_allow_html = false
|
27 |
|
@@ -53,26 +50,20 @@ auto_tag_thread = true
|
|
53 |
sample_rate = 44100
|
54 |
|
55 |
[UI]
|
56 |
-
# Name of the
|
57 |
name = "AI Tutor"
|
58 |
|
59 |
-
#
|
60 |
-
|
61 |
-
|
62 |
-
# Description of the app and chatbot. This is used for HTML tags.
|
63 |
-
# description = "AI Tutor - DS598"
|
64 |
|
65 |
# Large size content are by default collapsed for a cleaner ui
|
66 |
default_collapse_content = true
|
67 |
|
68 |
-
# The default value for the expand messages settings.
|
69 |
-
default_expand_messages = false
|
70 |
-
|
71 |
# Hide the chain of thought details from the user in the UI.
|
72 |
-
hide_cot =
|
73 |
|
74 |
# Link to your github repo. This will add a github button in the UI's header.
|
75 |
-
# github = ""
|
76 |
|
77 |
# Specify a CSS file that can be used to customize the user interface.
|
78 |
# The CSS file can be served from the public directory or via an external link.
|
@@ -86,7 +77,7 @@ custom_css = "/public/test.css"
|
|
86 |
# custom_font = "https://fonts.googleapis.com/css2?family=Inter:wght@400;500;700&display=swap"
|
87 |
|
88 |
# Specify a custom meta image url.
|
89 |
-
|
90 |
|
91 |
# Specify a custom build directory for the frontend.
|
92 |
# This can be used to customize the frontend code.
|
@@ -94,18 +85,21 @@ custom_css = "/public/test.css"
|
|
94 |
# custom_build = "./public/build"
|
95 |
|
96 |
[UI.theme]
|
|
|
97 |
#layout = "wide"
|
98 |
#font_family = "Inter, sans-serif"
|
99 |
# Override default MUI light theme. (Check theme.ts)
|
100 |
[UI.theme.light]
|
101 |
-
|
102 |
-
|
103 |
|
104 |
[UI.theme.light.primary]
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
|
|
|
|
109 |
# Override default MUI dark theme. (Check theme.ts)
|
110 |
[UI.theme.dark]
|
111 |
background = "#1C1C1C" # Slightly lighter dark background color
|
@@ -118,4 +112,4 @@ custom_css = "/public/test.css"
|
|
118 |
|
119 |
|
120 |
[meta]
|
121 |
-
generated_by = "1.1.
|
|
|
19 |
# follow_symlink = false
|
20 |
|
21 |
[features]
|
|
|
|
|
|
|
22 |
# Process and display HTML in messages. This can be a security risk (see https://stackoverflow.com/questions/19603097/why-is-it-dangerous-to-render-user-generated-html-or-javascript)
|
23 |
unsafe_allow_html = false
|
24 |
|
|
|
50 |
sample_rate = 44100
|
51 |
|
52 |
[UI]
|
53 |
+
# Name of the assistant.
|
54 |
name = "AI Tutor"
|
55 |
|
56 |
+
# Description of the assistant. This is used for HTML tags.
|
57 |
+
# description = ""
|
|
|
|
|
|
|
58 |
|
59 |
# Large size content are by default collapsed for a cleaner ui
|
60 |
default_collapse_content = true
|
61 |
|
|
|
|
|
|
|
62 |
# Hide the chain of thought details from the user in the UI.
|
63 |
+
hide_cot = true
|
64 |
|
65 |
# Link to your github repo. This will add a github button in the UI's header.
|
66 |
+
# github = "https://github.com/DL4DS/dl4ds_tutor"
|
67 |
|
68 |
# Specify a CSS file that can be used to customize the user interface.
|
69 |
# The CSS file can be served from the public directory or via an external link.
|
|
|
77 |
# custom_font = "https://fonts.googleapis.com/css2?family=Inter:wght@400;500;700&display=swap"
|
78 |
|
79 |
# Specify a custom meta image url.
|
80 |
+
custom_meta_image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/f/f5/Boston_University_seal.svg/1200px-Boston_University_seal.svg.png"
|
81 |
|
82 |
# Specify a custom build directory for the frontend.
|
83 |
# This can be used to customize the frontend code.
|
|
|
85 |
# custom_build = "./public/build"
|
86 |
|
87 |
[UI.theme]
|
88 |
+
default = "light"
|
89 |
#layout = "wide"
|
90 |
#font_family = "Inter, sans-serif"
|
91 |
# Override default MUI light theme. (Check theme.ts)
|
92 |
[UI.theme.light]
|
93 |
+
background = "#FAFAFA"
|
94 |
+
paper = "#FFFFFF"
|
95 |
|
96 |
[UI.theme.light.primary]
|
97 |
+
main = "#b22222" # Brighter shade of red
|
98 |
+
dark = "#8b0000" # Darker shade of the brighter red
|
99 |
+
light = "#ff6347" # Lighter shade of the brighter red
|
100 |
+
[UI.theme.light.text]
|
101 |
+
primary = "#212121"
|
102 |
+
secondary = "#616161"
|
103 |
# Override default MUI dark theme. (Check theme.ts)
|
104 |
[UI.theme.dark]
|
105 |
background = "#1C1C1C" # Slightly lighter dark background color
|
|
|
112 |
|
113 |
|
114 |
[meta]
|
115 |
+
generated_by = "1.1.302"
|
code/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .modules import *
|
chainlit.md β code/chainlit.md
RENAMED
File without changes
|
code/main.py
CHANGED
@@ -1,9 +1,8 @@
|
|
1 |
-
from
|
2 |
-
from
|
3 |
-
from
|
4 |
-
from
|
5 |
from langchain.chains import RetrievalQA
|
6 |
-
from langchain.llms import CTransformers
|
7 |
import chainlit as cl
|
8 |
from langchain_community.chat_models import ChatOpenAI
|
9 |
from langchain_community.embeddings import OpenAIEmbeddings
|
@@ -11,27 +10,48 @@ import yaml
|
|
11 |
import logging
|
12 |
from dotenv import load_dotenv
|
13 |
|
14 |
-
from modules.llm_tutor import LLMTutor
|
15 |
-
from modules.constants import *
|
16 |
-
from modules.helpers import get_sources
|
17 |
-
|
18 |
|
|
|
|
|
19 |
logger = logging.getLogger(__name__)
|
20 |
logger.setLevel(logging.INFO)
|
|
|
21 |
|
22 |
# Console Handler
|
23 |
console_handler = logging.StreamHandler()
|
24 |
console_handler.setLevel(logging.INFO)
|
25 |
-
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
26 |
console_handler.setFormatter(formatter)
|
27 |
logger.addHandler(console_handler)
|
28 |
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
|
37 |
# Adding option to select the chat profile
|
@@ -66,12 +86,26 @@ def rename(orig_author: str):
|
|
66 |
# chainlit code
|
67 |
@cl.on_chat_start
|
68 |
async def start():
|
69 |
-
with open("
|
70 |
config = yaml.safe_load(f)
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
|
76 |
chat_profile = cl.user_session.get("chat_profile")
|
77 |
if chat_profile is not None:
|
@@ -93,36 +127,50 @@ async def start():
|
|
93 |
llm_tutor = LLMTutor(config, logger=logger)
|
94 |
|
95 |
chain = llm_tutor.qa_bot()
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
msg.
|
100 |
-
await msg.update()
|
101 |
|
|
|
|
|
102 |
cl.user_session.set("chain", chain)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
|
104 |
|
105 |
@cl.on_message
|
106 |
async def main(message):
|
|
|
107 |
user = cl.user_session.get("user")
|
108 |
chain = cl.user_session.get("chain")
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
try:
|
117 |
answer = res["answer"]
|
118 |
except:
|
119 |
answer = res["result"]
|
120 |
-
print(f"answer: {answer}")
|
121 |
-
|
122 |
-
logger.info(f"Question: {res['question']}")
|
123 |
-
logger.info(f"History: {res['chat_history']}")
|
124 |
-
logger.info(f"Answer: {answer}\n")
|
125 |
|
126 |
-
answer_with_sources, source_elements = get_sources(res, answer)
|
|
|
127 |
|
128 |
await cl.Message(content=answer_with_sources, elements=source_elements).send()
|
|
|
1 |
+
from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
|
2 |
+
from langchain_core.prompts import PromptTemplate
|
3 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings
|
4 |
+
from langchain_community.vectorstores import FAISS
|
5 |
from langchain.chains import RetrievalQA
|
|
|
6 |
import chainlit as cl
|
7 |
from langchain_community.chat_models import ChatOpenAI
|
8 |
from langchain_community.embeddings import OpenAIEmbeddings
|
|
|
10 |
import logging
|
11 |
from dotenv import load_dotenv
|
12 |
|
13 |
+
from modules.chat.llm_tutor import LLMTutor
|
14 |
+
from modules.config.constants import *
|
15 |
+
from modules.chat.helpers import get_sources
|
16 |
+
from modules.chat_processor.chat_processor import ChatProcessor
|
17 |
|
18 |
+
global logger
|
19 |
+
# Initialize logger
|
20 |
logger = logging.getLogger(__name__)
|
21 |
logger.setLevel(logging.INFO)
|
22 |
+
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
23 |
|
24 |
# Console Handler
|
25 |
console_handler = logging.StreamHandler()
|
26 |
console_handler.setLevel(logging.INFO)
|
|
|
27 |
console_handler.setFormatter(formatter)
|
28 |
logger.addHandler(console_handler)
|
29 |
|
30 |
+
|
31 |
+
@cl.set_starters
|
32 |
+
async def set_starters():
|
33 |
+
return [
|
34 |
+
cl.Starter(
|
35 |
+
label="recording on CNNs?",
|
36 |
+
message="Where can I find the recording for the lecture on Transfromers?",
|
37 |
+
icon="/public/adv-screen-recorder-svgrepo-com.svg",
|
38 |
+
),
|
39 |
+
cl.Starter(
|
40 |
+
label="where's the slides?",
|
41 |
+
message="When are the lectures? I can't find the schedule.",
|
42 |
+
icon="/public/alarmy-svgrepo-com.svg",
|
43 |
+
),
|
44 |
+
cl.Starter(
|
45 |
+
label="Due Date?",
|
46 |
+
message="When is the final project due?",
|
47 |
+
icon="/public/calendar-samsung-17-svgrepo-com.svg",
|
48 |
+
),
|
49 |
+
cl.Starter(
|
50 |
+
label="Explain backprop.",
|
51 |
+
message="I didnt understand the math behind backprop, could you explain it?",
|
52 |
+
icon="/public/acastusphoton-svgrepo-com.svg",
|
53 |
+
),
|
54 |
+
]
|
55 |
|
56 |
|
57 |
# Adding option to select the chat profile
|
|
|
86 |
# chainlit code
|
87 |
@cl.on_chat_start
|
88 |
async def start():
|
89 |
+
with open("modules/config/config.yml", "r") as f:
|
90 |
config = yaml.safe_load(f)
|
91 |
+
|
92 |
+
# Ensure log directory exists
|
93 |
+
log_directory = config["log_dir"]
|
94 |
+
if not os.path.exists(log_directory):
|
95 |
+
os.makedirs(log_directory)
|
96 |
+
|
97 |
+
# File Handler
|
98 |
+
log_file_path = (
|
99 |
+
f"{log_directory}/tutor.log" # Change this to your desired log file path
|
100 |
+
)
|
101 |
+
file_handler = logging.FileHandler(log_file_path, mode="w")
|
102 |
+
file_handler.setLevel(logging.INFO)
|
103 |
+
file_handler.setFormatter(formatter)
|
104 |
+
logger.addHandler(file_handler)
|
105 |
+
|
106 |
+
logger.info("Config file loaded")
|
107 |
+
logger.info(f"Config: {config}")
|
108 |
+
logger.info("Creating llm_tutor instance")
|
109 |
|
110 |
chat_profile = cl.user_session.get("chat_profile")
|
111 |
if chat_profile is not None:
|
|
|
127 |
llm_tutor = LLMTutor(config, logger=logger)
|
128 |
|
129 |
chain = llm_tutor.qa_bot()
|
130 |
+
# msg = cl.Message(content=f"Starting the bot {chat_profile}...")
|
131 |
+
# await msg.send()
|
132 |
+
# msg.content = opening_message
|
133 |
+
# await msg.update()
|
|
|
134 |
|
135 |
+
tags = [chat_profile, config["vectorstore"]["db_option"]]
|
136 |
+
chat_processor = ChatProcessor(config, tags=tags)
|
137 |
cl.user_session.set("chain", chain)
|
138 |
+
cl.user_session.set("counter", 0)
|
139 |
+
cl.user_session.set("chat_processor", chat_processor)
|
140 |
+
|
141 |
+
|
142 |
+
@cl.on_chat_end
|
143 |
+
async def on_chat_end():
|
144 |
+
await cl.Message(content="Sorry, I have to go now. Goodbye!").send()
|
145 |
|
146 |
|
147 |
@cl.on_message
|
148 |
async def main(message):
|
149 |
+
global logger
|
150 |
user = cl.user_session.get("user")
|
151 |
chain = cl.user_session.get("chain")
|
152 |
+
|
153 |
+
counter = cl.user_session.get("counter")
|
154 |
+
counter += 1
|
155 |
+
cl.user_session.set("counter", counter)
|
156 |
+
|
157 |
+
# if counter >= 3: # Ensure the counter condition is checked
|
158 |
+
# await cl.Message(content="Your credits are up!").send()
|
159 |
+
# await on_chat_end() # Call the on_chat_end function to handle the end of the chat
|
160 |
+
# return # Exit the function to stop further processing
|
161 |
+
# else:
|
162 |
+
|
163 |
+
cb = cl.AsyncLangchainCallbackHandler() # TODO: fix streaming here
|
164 |
+
cb.answer_reached = True
|
165 |
+
|
166 |
+
processor = cl.user_session.get("chat_processor")
|
167 |
+
res = await processor.rag(message.content, chain, cb)
|
168 |
try:
|
169 |
answer = res["answer"]
|
170 |
except:
|
171 |
answer = res["result"]
|
|
|
|
|
|
|
|
|
|
|
172 |
|
173 |
+
answer_with_sources, source_elements, sources_dict = get_sources(res, answer)
|
174 |
+
processor._process(message.content, answer, sources_dict)
|
175 |
|
176 |
await cl.Message(content=answer_with_sources, elements=source_elements).send()
|
code/modules/chat/__init__.py
ADDED
File without changes
|
code/modules/{chat_model_loader.py β chat/chat_model_loader.py}
RENAMED
@@ -1,8 +1,7 @@
|
|
1 |
from langchain_community.chat_models import ChatOpenAI
|
2 |
-
from
|
3 |
-
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
|
4 |
from transformers import AutoTokenizer, TextStreamer
|
5 |
-
from
|
6 |
import torch
|
7 |
import transformers
|
8 |
import os
|
|
|
1 |
from langchain_community.chat_models import ChatOpenAI
|
2 |
+
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
|
|
|
3 |
from transformers import AutoTokenizer, TextStreamer
|
4 |
+
from langchain_community.llms import LlamaCpp
|
5 |
import torch
|
6 |
import transformers
|
7 |
import os
|
code/modules/chat/helpers.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules.config.constants import *
|
2 |
+
import chainlit as cl
|
3 |
+
from langchain_core.prompts import PromptTemplate
|
4 |
+
|
5 |
+
|
6 |
+
def get_sources(res, answer):
|
7 |
+
source_elements = []
|
8 |
+
source_dict = {} # Dictionary to store URL elements
|
9 |
+
|
10 |
+
for idx, source in enumerate(res["source_documents"]):
|
11 |
+
source_metadata = source.metadata
|
12 |
+
url = source_metadata.get("source", "N/A")
|
13 |
+
score = source_metadata.get("score", "N/A")
|
14 |
+
page = source_metadata.get("page", 1)
|
15 |
+
|
16 |
+
lecture_tldr = source_metadata.get("tldr", "N/A")
|
17 |
+
lecture_recording = source_metadata.get("lecture_recording", "N/A")
|
18 |
+
suggested_readings = source_metadata.get("suggested_readings", "N/A")
|
19 |
+
date = source_metadata.get("date", "N/A")
|
20 |
+
|
21 |
+
source_type = source_metadata.get("source_type", "N/A")
|
22 |
+
|
23 |
+
url_name = f"{url}_{page}"
|
24 |
+
if url_name not in source_dict:
|
25 |
+
source_dict[url_name] = {
|
26 |
+
"text": source.page_content,
|
27 |
+
"url": url,
|
28 |
+
"score": score,
|
29 |
+
"page": page,
|
30 |
+
"lecture_tldr": lecture_tldr,
|
31 |
+
"lecture_recording": lecture_recording,
|
32 |
+
"suggested_readings": suggested_readings,
|
33 |
+
"date": date,
|
34 |
+
"source_type": source_type,
|
35 |
+
}
|
36 |
+
else:
|
37 |
+
source_dict[url_name]["text"] += f"\n\n{source.page_content}"
|
38 |
+
|
39 |
+
# First, display the answer
|
40 |
+
full_answer = "**Answer:**\n"
|
41 |
+
full_answer += answer
|
42 |
+
|
43 |
+
# Then, display the sources
|
44 |
+
full_answer += "\n\n**Sources:**\n"
|
45 |
+
for idx, (url_name, source_data) in enumerate(source_dict.items()):
|
46 |
+
full_answer += f"\nSource {idx + 1} (Score: {source_data['score']}): {source_data['url']}\n"
|
47 |
+
|
48 |
+
name = f"Source {idx + 1} Text\n"
|
49 |
+
full_answer += name
|
50 |
+
source_elements.append(
|
51 |
+
cl.Text(name=name, content=source_data["text"], display="side")
|
52 |
+
)
|
53 |
+
|
54 |
+
# Add a PDF element if the source is a PDF file
|
55 |
+
if source_data["url"].lower().endswith(".pdf"):
|
56 |
+
name = f"Source {idx + 1} PDF\n"
|
57 |
+
full_answer += name
|
58 |
+
pdf_url = f"{source_data['url']}#page={source_data['page']+1}"
|
59 |
+
source_elements.append(cl.Pdf(name=name, url=pdf_url, display="side"))
|
60 |
+
|
61 |
+
full_answer += "\n**Metadata:**\n"
|
62 |
+
for idx, (url_name, source_data) in enumerate(source_dict.items()):
|
63 |
+
full_answer += f"\nSource {idx + 1} Metadata:\n"
|
64 |
+
source_elements.append(
|
65 |
+
cl.Text(
|
66 |
+
name=f"Source {idx + 1} Metadata",
|
67 |
+
content=f"Source: {source_data['url']}\n"
|
68 |
+
f"Page: {source_data['page']}\n"
|
69 |
+
f"Type: {source_data['source_type']}\n"
|
70 |
+
f"Date: {source_data['date']}\n"
|
71 |
+
f"TL;DR: {source_data['lecture_tldr']}\n"
|
72 |
+
f"Lecture Recording: {source_data['lecture_recording']}\n"
|
73 |
+
f"Suggested Readings: {source_data['suggested_readings']}\n",
|
74 |
+
display="side",
|
75 |
+
)
|
76 |
+
)
|
77 |
+
|
78 |
+
return full_answer, source_elements, source_dict
|
79 |
+
|
80 |
+
|
81 |
+
def get_prompt(config):
|
82 |
+
if config["llm_params"]["use_history"]:
|
83 |
+
if config["llm_params"]["llm_loader"] == "local_llm":
|
84 |
+
custom_prompt_template = tinyllama_prompt_template_with_history
|
85 |
+
elif config["llm_params"]["llm_loader"] == "openai":
|
86 |
+
custom_prompt_template = openai_prompt_template_with_history
|
87 |
+
# else:
|
88 |
+
# custom_prompt_template = tinyllama_prompt_template_with_history # default
|
89 |
+
prompt = PromptTemplate(
|
90 |
+
template=custom_prompt_template,
|
91 |
+
input_variables=["context", "chat_history", "question"],
|
92 |
+
)
|
93 |
+
else:
|
94 |
+
if config["llm_params"]["llm_loader"] == "local_llm":
|
95 |
+
custom_prompt_template = tinyllama_prompt_template
|
96 |
+
elif config["llm_params"]["llm_loader"] == "openai":
|
97 |
+
custom_prompt_template = openai_prompt_template
|
98 |
+
# else:
|
99 |
+
# custom_prompt_template = tinyllama_prompt_template
|
100 |
+
prompt = PromptTemplate(
|
101 |
+
template=custom_prompt_template,
|
102 |
+
input_variables=["context", "question"],
|
103 |
+
)
|
104 |
+
return prompt
|
code/modules/{llm_tutor.py β chat/llm_tutor.py}
RENAMED
@@ -1,24 +1,52 @@
|
|
1 |
-
from langchain import PromptTemplate
|
2 |
-
from langchain.embeddings import HuggingFaceEmbeddings
|
3 |
-
from langchain_community.chat_models import ChatOpenAI
|
4 |
-
from langchain_community.embeddings import OpenAIEmbeddings
|
5 |
-
from langchain.vectorstores import FAISS
|
6 |
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
|
7 |
-
from langchain.
|
8 |
-
|
|
|
|
|
9 |
from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
|
10 |
import os
|
11 |
-
from modules.constants import *
|
12 |
-
from modules.helpers import get_prompt
|
13 |
-
from modules.chat_model_loader import ChatModelLoader
|
14 |
-
from modules.
|
15 |
-
|
|
|
|
|
|
|
16 |
from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
|
17 |
import inspect
|
18 |
from langchain.chains.conversational_retrieval.base import _get_chat_history
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
|
21 |
class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
async def _acall(
|
23 |
self,
|
24 |
inputs: Dict[str, Any],
|
@@ -26,13 +54,34 @@ class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
|
|
26 |
) -> Dict[str, Any]:
|
27 |
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
28 |
question = inputs["question"]
|
29 |
-
get_chat_history = self.
|
30 |
chat_history_str = get_chat_history(inputs["chat_history"])
|
31 |
-
print(f"chat_history_str: {chat_history_str}")
|
32 |
if chat_history_str:
|
33 |
-
callbacks = _run_manager.get_child()
|
34 |
-
new_question = await self.question_generator.arun(
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
)
|
37 |
else:
|
38 |
new_question = question
|
@@ -45,6 +94,7 @@ class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
|
|
45 |
docs = await self._aget_docs(new_question, inputs) # type: ignore[call-arg]
|
46 |
|
47 |
output: Dict[str, Any] = {}
|
|
|
48 |
if self.response_if_no_docs_found is not None and len(docs) == 0:
|
49 |
output[self.output_key] = self.response_if_no_docs_found
|
50 |
else:
|
@@ -56,31 +106,25 @@ class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
|
|
56 |
# Prepare the final prompt with metadata
|
57 |
context = "\n\n".join(
|
58 |
[
|
59 |
-
f"Document content: {doc.page_content}\nMetadata: {doc.metadata}"
|
60 |
-
for doc in docs
|
61 |
]
|
62 |
)
|
63 |
-
final_prompt =
|
64 |
-
You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos.
|
65 |
-
If you don't know the answer,
|
66 |
-
Use
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
Context:
|
75 |
-
{context}
|
76 |
-
|
77 |
-
Question: {new_question}
|
78 |
-
AI Tutor:
|
79 |
-
"""
|
80 |
|
81 |
-
new_inputs["input"] = final_prompt
|
82 |
new_inputs["question"] = final_prompt
|
83 |
-
output["final_prompt"] = final_prompt
|
84 |
|
85 |
answer = await self.combine_docs_chain.arun(
|
86 |
input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
|
@@ -89,8 +133,7 @@ class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
|
|
89 |
|
90 |
if self.return_source_documents:
|
91 |
output["source_documents"] = docs
|
92 |
-
|
93 |
-
output["generated_question"] = new_question
|
94 |
return output
|
95 |
|
96 |
|
@@ -98,8 +141,9 @@ class LLMTutor:
|
|
98 |
def __init__(self, config, logger=None):
|
99 |
self.config = config
|
100 |
self.llm = self.load_llm()
|
101 |
-
self.
|
102 |
-
|
|
|
103 |
self.vector_db.create_database()
|
104 |
self.vector_db.save_database()
|
105 |
|
@@ -114,24 +158,11 @@ class LLMTutor:
|
|
114 |
|
115 |
# Retrieval QA Chain
|
116 |
def retrieval_qa_chain(self, llm, prompt, db):
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
# search_type="similarity_score_threshold",
|
121 |
-
# search_kwargs={
|
122 |
-
# "score_threshold": self.config["embedding_options"][
|
123 |
-
# "score_threshold"
|
124 |
-
# ],
|
125 |
-
# "k": self.config["embedding_options"]["search_top_k"],
|
126 |
-
# },
|
127 |
-
)
|
128 |
-
elif self.config["embedding_options"]["db_option"] == "RAGatouille":
|
129 |
-
retriever = db.as_langchain_retriever(
|
130 |
-
k=self.config["embedding_options"]["search_top_k"]
|
131 |
-
)
|
132 |
if self.config["llm_params"]["use_history"]:
|
133 |
-
memory =
|
134 |
-
llm = llm,
|
135 |
k=self.config["llm_params"]["memory_window"],
|
136 |
memory_key="chat_history",
|
137 |
return_messages=True,
|
@@ -145,6 +176,7 @@ class LLMTutor:
|
|
145 |
return_source_documents=True,
|
146 |
memory=memory,
|
147 |
combine_docs_chain_kwargs={"prompt": prompt},
|
|
|
148 |
)
|
149 |
else:
|
150 |
qa_chain = RetrievalQA.from_chain_type(
|
@@ -166,7 +198,9 @@ class LLMTutor:
|
|
166 |
def qa_bot(self):
|
167 |
db = self.vector_db.load_database()
|
168 |
qa_prompt = self.set_custom_prompt()
|
169 |
-
qa = self.retrieval_qa_chain(
|
|
|
|
|
170 |
|
171 |
return qa
|
172 |
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
|
2 |
+
from langchain.memory import (
|
3 |
+
ConversationBufferWindowMemory,
|
4 |
+
ConversationSummaryBufferMemory,
|
5 |
+
)
|
6 |
from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
|
7 |
import os
|
8 |
+
from modules.config.constants import *
|
9 |
+
from modules.chat.helpers import get_prompt
|
10 |
+
from modules.chat.chat_model_loader import ChatModelLoader
|
11 |
+
from modules.vectorstore.store_manager import VectorStoreManager
|
12 |
+
|
13 |
+
from modules.retriever.retriever import Retriever
|
14 |
+
|
15 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
|
16 |
from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
|
17 |
import inspect
|
18 |
from langchain.chains.conversational_retrieval.base import _get_chat_history
|
19 |
+
from langchain_core.messages import BaseMessage
|
20 |
+
|
21 |
+
CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]
|
22 |
+
|
23 |
+
from langchain_core.output_parsers import StrOutputParser
|
24 |
+
from langchain_core.prompts import ChatPromptTemplate
|
25 |
+
from langchain_community.chat_models import ChatOpenAI
|
26 |
|
27 |
|
28 |
class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
|
29 |
+
|
30 |
+
def _get_chat_history(self, chat_history: List[CHAT_TURN_TYPE]) -> str:
|
31 |
+
_ROLE_MAP = {"human": "Student: ", "ai": "AI Tutor: "}
|
32 |
+
buffer = ""
|
33 |
+
for dialogue_turn in chat_history:
|
34 |
+
if isinstance(dialogue_turn, BaseMessage):
|
35 |
+
role_prefix = _ROLE_MAP.get(
|
36 |
+
dialogue_turn.type, f"{dialogue_turn.type}: "
|
37 |
+
)
|
38 |
+
buffer += f"\n{role_prefix}{dialogue_turn.content}"
|
39 |
+
elif isinstance(dialogue_turn, tuple):
|
40 |
+
human = "Student: " + dialogue_turn[0]
|
41 |
+
ai = "AI Tutor: " + dialogue_turn[1]
|
42 |
+
buffer += "\n" + "\n".join([human, ai])
|
43 |
+
else:
|
44 |
+
raise ValueError(
|
45 |
+
f"Unsupported chat history format: {type(dialogue_turn)}."
|
46 |
+
f" Full chat history: {chat_history} "
|
47 |
+
)
|
48 |
+
return buffer
|
49 |
+
|
50 |
async def _acall(
|
51 |
self,
|
52 |
inputs: Dict[str, Any],
|
|
|
54 |
) -> Dict[str, Any]:
|
55 |
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
56 |
question = inputs["question"]
|
57 |
+
get_chat_history = self._get_chat_history
|
58 |
chat_history_str = get_chat_history(inputs["chat_history"])
|
|
|
59 |
if chat_history_str:
|
60 |
+
# callbacks = _run_manager.get_child()
|
61 |
+
# new_question = await self.question_generator.arun(
|
62 |
+
# question=question, chat_history=chat_history_str, callbacks=callbacks
|
63 |
+
# )
|
64 |
+
system = (
|
65 |
+
"You are someone that rephrases statements. Rephrase the student's question to add context from their chat history if relevant, ensuring it remains from the student's point of view. "
|
66 |
+
"Incorporate relevant details from the chat history to make the question clearer and more specific."
|
67 |
+
"Do not change the meaning of the original statement, and maintain the student's tone and perspective. "
|
68 |
+
"If the question is conversational and doesn't require context, do not rephrase it. "
|
69 |
+
"Example: If the student previously asked about backpropagation in the context of deep learning and now asks 'what is it', rephrase to 'What is backprogatation.'. "
|
70 |
+
"Example: Do not rephrase if the user is asking something specific like 'cool, suggest a project with transformers to use as my final project'"
|
71 |
+
"Chat history: \n{chat_history_str}\n"
|
72 |
+
"Rephrase the following question only if necessary: '{question}'"
|
73 |
+
)
|
74 |
+
|
75 |
+
prompt = ChatPromptTemplate.from_messages(
|
76 |
+
[
|
77 |
+
("system", system),
|
78 |
+
("human", "{question}, {chat_history_str}"),
|
79 |
+
]
|
80 |
+
)
|
81 |
+
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
|
82 |
+
step_back = prompt | llm | StrOutputParser()
|
83 |
+
new_question = step_back.invoke(
|
84 |
+
{"question": question, "chat_history_str": chat_history_str}
|
85 |
)
|
86 |
else:
|
87 |
new_question = question
|
|
|
94 |
docs = await self._aget_docs(new_question, inputs) # type: ignore[call-arg]
|
95 |
|
96 |
output: Dict[str, Any] = {}
|
97 |
+
output["original_question"] = question
|
98 |
if self.response_if_no_docs_found is not None and len(docs) == 0:
|
99 |
output[self.output_key] = self.response_if_no_docs_found
|
100 |
else:
|
|
|
106 |
# Prepare the final prompt with metadata
|
107 |
context = "\n\n".join(
|
108 |
[
|
109 |
+
f"Context {idx+1}: \n(Document content: {doc.page_content}\nMetadata: (source_file: {doc.metadata['source'] if 'source' in doc.metadata else 'unknown'}))"
|
110 |
+
for idx, doc in enumerate(docs)
|
111 |
]
|
112 |
)
|
113 |
+
final_prompt = (
|
114 |
+
"You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance."
|
115 |
+
"If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally. "
|
116 |
+
"Use chat history and context as guides but avoid repeating past responses. Provide links from the source_file metadata. Use the source context that is most relevent."
|
117 |
+
"Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n\n"
|
118 |
+
f"Chat History:\n{chat_history_str}\n\n"
|
119 |
+
f"Context:\n{context}\n\n"
|
120 |
+
"Answer the student's question below in a friendly, concise, and engaging manner. Use the context and history only if relevant, otherwise, engage in a free-flowing conversation.\n"
|
121 |
+
f"Student: {question}\n"
|
122 |
+
"AI Tutor:"
|
123 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
|
125 |
+
# new_inputs["input"] = final_prompt
|
126 |
new_inputs["question"] = final_prompt
|
127 |
+
# output["final_prompt"] = final_prompt
|
128 |
|
129 |
answer = await self.combine_docs_chain.arun(
|
130 |
input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
|
|
|
133 |
|
134 |
if self.return_source_documents:
|
135 |
output["source_documents"] = docs
|
136 |
+
output["rephrased_question"] = new_question
|
|
|
137 |
return output
|
138 |
|
139 |
|
|
|
141 |
def __init__(self, config, logger=None):
|
142 |
self.config = config
|
143 |
self.llm = self.load_llm()
|
144 |
+
self.logger = logger
|
145 |
+
self.vector_db = VectorStoreManager(config, logger=self.logger)
|
146 |
+
if self.config["vectorstore"]["embedd_files"]:
|
147 |
self.vector_db.create_database()
|
148 |
self.vector_db.save_database()
|
149 |
|
|
|
158 |
|
159 |
# Retrieval QA Chain
|
160 |
def retrieval_qa_chain(self, llm, prompt, db):
|
161 |
+
|
162 |
+
retriever = Retriever(self.config)._return_retriever(db)
|
163 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
if self.config["llm_params"]["use_history"]:
|
165 |
+
memory = ConversationBufferWindowMemory(
|
|
|
166 |
k=self.config["llm_params"]["memory_window"],
|
167 |
memory_key="chat_history",
|
168 |
return_messages=True,
|
|
|
176 |
return_source_documents=True,
|
177 |
memory=memory,
|
178 |
combine_docs_chain_kwargs={"prompt": prompt},
|
179 |
+
response_if_no_docs_found="No context found",
|
180 |
)
|
181 |
else:
|
182 |
qa_chain = RetrievalQA.from_chain_type(
|
|
|
198 |
def qa_bot(self):
|
199 |
db = self.vector_db.load_database()
|
200 |
qa_prompt = self.set_custom_prompt()
|
201 |
+
qa = self.retrieval_qa_chain(
|
202 |
+
self.llm, qa_prompt, db
|
203 |
+
) # TODO: PROMPT is overwritten in CustomConversationalRetrievalChain
|
204 |
|
205 |
return qa
|
206 |
|
code/modules/chat_processor/__init__.py
ADDED
File without changes
|
code/modules/chat_processor/base.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Template for chat processor classes
|
2 |
+
|
3 |
+
|
4 |
+
class ChatProcessorBase:
|
5 |
+
def __init__(self, config):
|
6 |
+
self.config = config
|
7 |
+
|
8 |
+
def process(self, message):
|
9 |
+
"""
|
10 |
+
Processes and Logs the message
|
11 |
+
"""
|
12 |
+
raise NotImplementedError("process method not implemented")
|
code/modules/chat_processor/chat_processor.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules.chat_processor.literal_ai import LiteralaiChatProcessor
|
2 |
+
|
3 |
+
|
4 |
+
class ChatProcessor:
|
5 |
+
def __init__(self, config, tags=None):
|
6 |
+
self.chat_processor_type = config["chat_logging"]["platform"]
|
7 |
+
self.logging = config["chat_logging"]["log_chat"]
|
8 |
+
self.tags = tags
|
9 |
+
self._init_processor()
|
10 |
+
|
11 |
+
def _init_processor(self):
|
12 |
+
if self.chat_processor_type == "literalai":
|
13 |
+
self.processor = LiteralaiChatProcessor(self.tags)
|
14 |
+
else:
|
15 |
+
raise ValueError(
|
16 |
+
f"Chat processor type {self.chat_processor_type} not supported"
|
17 |
+
)
|
18 |
+
|
19 |
+
def _process(self, user_message, assistant_message, source_dict):
|
20 |
+
if self.logging:
|
21 |
+
return self.processor.process(user_message, assistant_message, source_dict)
|
22 |
+
else:
|
23 |
+
pass
|
24 |
+
|
25 |
+
async def rag(self, user_query: str, chain, cb):
|
26 |
+
if self.logging:
|
27 |
+
return await self.processor.rag(user_query, chain, cb)
|
28 |
+
else:
|
29 |
+
return await chain.acall(user_query, callbacks=[cb])
|
code/modules/chat_processor/literal_ai.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from literalai import LiteralClient
|
2 |
+
import os
|
3 |
+
from .base import ChatProcessorBase
|
4 |
+
|
5 |
+
|
6 |
+
class LiteralaiChatProcessor(ChatProcessorBase):
|
7 |
+
def __init__(self, tags=None):
|
8 |
+
self.literal_client = LiteralClient(api_key=os.getenv("LITERAL_API_KEY"))
|
9 |
+
self.literal_client.reset_context()
|
10 |
+
with self.literal_client.thread(name="TEST") as thread:
|
11 |
+
self.thread_id = thread.id
|
12 |
+
self.thread = thread
|
13 |
+
if tags is not None and type(tags) == list:
|
14 |
+
self.thread.tags = tags
|
15 |
+
print(f"Thread ID: {self.thread}")
|
16 |
+
|
17 |
+
def process(self, user_message, assistant_message, source_dict):
|
18 |
+
with self.literal_client.thread(thread_id=self.thread_id) as thread:
|
19 |
+
self.literal_client.message(
|
20 |
+
content=user_message,
|
21 |
+
type="user_message",
|
22 |
+
name="User",
|
23 |
+
)
|
24 |
+
self.literal_client.message(
|
25 |
+
content=assistant_message,
|
26 |
+
type="assistant_message",
|
27 |
+
name="AI_Tutor",
|
28 |
+
)
|
29 |
+
|
30 |
+
async def rag(self, user_query: str, chain, cb):
|
31 |
+
with self.literal_client.step(
|
32 |
+
type="retrieval", name="RAG", thread_id=self.thread_id
|
33 |
+
) as step:
|
34 |
+
step.input = {"question": user_query}
|
35 |
+
res = await chain.acall(user_query, callbacks=[cb])
|
36 |
+
step.output = res
|
37 |
+
return res
|
code/modules/config/__init__.py
ADDED
File without changes
|
code/{config.yml β modules/config/config.yml}
RENAMED
@@ -1,13 +1,28 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
2 |
embedd_files: False # bool
|
3 |
-
data_path: 'storage/data' # str
|
4 |
-
url_file_path: 'storage/data/urls.txt' # str
|
5 |
expand_urls: True # bool
|
6 |
-
db_option : '
|
7 |
-
db_path : 'vectorstores' # str
|
8 |
model : 'sentence-transformers/all-MiniLM-L6-v2' # str [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002']
|
9 |
search_top_k : 3 # int
|
10 |
score_threshold : 0.2 # float
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
llm_params:
|
12 |
use_history: True # bool
|
13 |
memory_window: 3 # int
|
@@ -15,9 +30,13 @@ llm_params:
|
|
15 |
openai_params:
|
16 |
model: 'gpt-3.5-turbo-1106' # str [gpt-3.5-turbo-1106, gpt-4]
|
17 |
local_llm_params:
|
18 |
-
model:
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
21 |
splitter_options:
|
22 |
use_splitter: True # bool
|
23 |
split_by_token : True # bool
|
|
|
1 |
+
log_dir: '../storage/logs' # str
|
2 |
+
log_chunk_dir: '../storage/logs/chunks' # str
|
3 |
+
device: 'cpu' # str [cuda, cpu]
|
4 |
+
|
5 |
+
vectorstore:
|
6 |
embedd_files: False # bool
|
7 |
+
data_path: '../storage/data' # str
|
8 |
+
url_file_path: '../storage/data/urls.txt' # str
|
9 |
expand_urls: True # bool
|
10 |
+
db_option : 'FAISS' # str [FAISS, Chroma, RAGatouille, RAPTOR]
|
11 |
+
db_path : '../vectorstores' # str
|
12 |
model : 'sentence-transformers/all-MiniLM-L6-v2' # str [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002']
|
13 |
search_top_k : 3 # int
|
14 |
score_threshold : 0.2 # float
|
15 |
+
|
16 |
+
faiss_params: # Not used as of now
|
17 |
+
index_path: '../vectorstores/faiss.index' # str
|
18 |
+
index_type: 'Flat' # str [Flat, HNSW, IVF]
|
19 |
+
index_dimension: 384 # int
|
20 |
+
index_nlist: 100 # int
|
21 |
+
index_nprobe: 10 # int
|
22 |
+
|
23 |
+
colbert_params:
|
24 |
+
index_name: "new_idx" # str
|
25 |
+
|
26 |
llm_params:
|
27 |
use_history: True # bool
|
28 |
memory_window: 3 # int
|
|
|
30 |
openai_params:
|
31 |
model: 'gpt-3.5-turbo-1106' # str [gpt-3.5-turbo-1106, gpt-4]
|
32 |
local_llm_params:
|
33 |
+
model: 'tiny-llama'
|
34 |
+
temperature: 0.7
|
35 |
+
|
36 |
+
chat_logging:
|
37 |
+
log_chat: False # bool
|
38 |
+
platform: 'literalai'
|
39 |
+
|
40 |
splitter_options:
|
41 |
use_splitter: True # bool
|
42 |
split_by_token : True # bool
|
code/modules/{constants.py β config/constants.py}
RENAMED
@@ -7,6 +7,7 @@ load_dotenv()
|
|
7 |
|
8 |
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
9 |
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
|
|
|
10 |
|
11 |
opening_message = f"Hey, What Can I Help You With?\n\nYou can me ask me questions about the course logistics, course content, about the final project, or anything else!"
|
12 |
|
@@ -77,5 +78,5 @@ Question: {question}
|
|
77 |
|
78 |
# Model Paths
|
79 |
|
80 |
-
LLAMA_PATH = "storage/models/tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf"
|
81 |
MISTRAL_PATH = "storage/models/mistral-7b-v0.1.Q4_K_M.gguf"
|
|
|
7 |
|
8 |
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
9 |
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
|
10 |
+
LITERAL_API_KEY = os.getenv("LITERAL_API_KEY")
|
11 |
|
12 |
opening_message = f"Hey, What Can I Help You With?\n\nYou can me ask me questions about the course logistics, course content, about the final project, or anything else!"
|
13 |
|
|
|
78 |
|
79 |
# Model Paths
|
80 |
|
81 |
+
LLAMA_PATH = "../storage/models/tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf"
|
82 |
MISTRAL_PATH = "storage/models/mistral-7b-v0.1.Q4_K_M.gguf"
|
code/modules/dataloader/__init__.py
ADDED
File without changes
|
code/modules/{data_loader.py β dataloader/data_loader.py}
RENAMED
@@ -16,15 +16,12 @@ import logging
|
|
16 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
17 |
from ragatouille import RAGPretrainedModel
|
18 |
from langchain.chains import LLMChain
|
19 |
-
from
|
20 |
from langchain import PromptTemplate
|
|
|
|
|
21 |
|
22 |
-
|
23 |
-
from modules.helpers import get_metadata
|
24 |
-
except:
|
25 |
-
from helpers import get_metadata
|
26 |
-
|
27 |
-
logger = logging.getLogger(__name__)
|
28 |
|
29 |
|
30 |
class PDFReader:
|
@@ -40,8 +37,9 @@ class PDFReader:
|
|
40 |
|
41 |
|
42 |
class FileReader:
|
43 |
-
def __init__(self):
|
44 |
self.pdf_reader = PDFReader()
|
|
|
45 |
|
46 |
def extract_text_from_pdf(self, pdf_path):
|
47 |
text = ""
|
@@ -61,7 +59,7 @@ class FileReader:
|
|
61 |
temp_file_path = temp_file.name
|
62 |
return temp_file_path
|
63 |
else:
|
64 |
-
|
65 |
return None
|
66 |
|
67 |
def read_pdf(self, temp_file_path: str):
|
@@ -99,13 +97,18 @@ class FileReader:
|
|
99 |
if response.status_code == 200:
|
100 |
return [Document(page_content=response.text)]
|
101 |
else:
|
102 |
-
|
103 |
return None
|
104 |
|
105 |
|
106 |
class ChunkProcessor:
|
107 |
-
def __init__(self, config):
|
108 |
self.config = config
|
|
|
|
|
|
|
|
|
|
|
109 |
|
110 |
if config["splitter_options"]["use_splitter"]:
|
111 |
if config["splitter_options"]["split_by_token"]:
|
@@ -124,7 +127,7 @@ class ChunkProcessor:
|
|
124 |
)
|
125 |
else:
|
126 |
self.splitter = None
|
127 |
-
logger.info("ChunkProcessor instance created")
|
128 |
|
129 |
def remove_delimiters(self, document_chunks: list):
|
130 |
for chunk in document_chunks:
|
@@ -139,7 +142,6 @@ class ChunkProcessor:
|
|
139 |
del document_chunks[0]
|
140 |
for _ in range(end):
|
141 |
document_chunks.pop()
|
142 |
-
logger.info(f"\tNumber of pages after skipping: {len(document_chunks)}")
|
143 |
return document_chunks
|
144 |
|
145 |
def process_chunks(
|
@@ -172,122 +174,187 @@ class ChunkProcessor:
|
|
172 |
|
173 |
return document_chunks
|
174 |
|
175 |
-
def
|
176 |
-
self.document_chunks_full = []
|
177 |
-
self.parent_document_names = []
|
178 |
-
self.child_document_names = []
|
179 |
-
self.documents = []
|
180 |
-
self.document_metadata = []
|
181 |
-
|
182 |
addl_metadata = get_metadata(
|
183 |
"https://dl4ds.github.io/sp2024/lectures/",
|
184 |
"https://dl4ds.github.io/sp2024/schedule/",
|
185 |
) # For any additional metadata
|
186 |
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
self.parent_document_names.append(file_name)
|
219 |
-
if self.config["embedding_options"]["db_option"] not in [
|
220 |
-
"RAGatouille"
|
221 |
-
]:
|
222 |
-
document_chunks = self.process_chunks(
|
223 |
-
self.documents[-1],
|
224 |
-
file_type,
|
225 |
-
source=file_path,
|
226 |
-
page=page_num,
|
227 |
-
metadata=metadata,
|
228 |
-
)
|
229 |
-
self.document_chunks_full.extend(document_chunks)
|
230 |
-
|
231 |
-
# except Exception as e:
|
232 |
-
# logger.error(f"Error processing file {file_name}: {str(e)}")
|
233 |
-
|
234 |
-
self.process_weblinks(file_reader, weblinks)
|
235 |
-
|
236 |
-
logger.info(
|
237 |
f"Total document chunks extracted: {len(self.document_chunks_full)}"
|
238 |
)
|
239 |
-
return (
|
240 |
-
self.document_chunks_full,
|
241 |
-
self.child_document_names,
|
242 |
-
self.documents,
|
243 |
-
self.document_metadata,
|
244 |
-
)
|
245 |
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
283 |
|
284 |
|
285 |
class DataLoader:
|
286 |
-
def __init__(self, config):
|
287 |
-
self.file_reader = FileReader()
|
288 |
-
self.chunk_processor = ChunkProcessor(config)
|
289 |
|
290 |
def get_chunks(self, uploaded_files, weblinks):
|
291 |
-
return self.chunk_processor.
|
292 |
self.file_reader, uploaded_files, weblinks
|
293 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
17 |
from ragatouille import RAGPretrainedModel
|
18 |
from langchain.chains import LLMChain
|
19 |
+
from langchain_community.llms import OpenAI
|
20 |
from langchain import PromptTemplate
|
21 |
+
import json
|
22 |
+
from concurrent.futures import ThreadPoolExecutor
|
23 |
|
24 |
+
from modules.dataloader.helpers import get_metadata
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
|
27 |
class PDFReader:
|
|
|
37 |
|
38 |
|
39 |
class FileReader:
|
40 |
+
def __init__(self, logger):
|
41 |
self.pdf_reader = PDFReader()
|
42 |
+
self.logger = logger
|
43 |
|
44 |
def extract_text_from_pdf(self, pdf_path):
|
45 |
text = ""
|
|
|
59 |
temp_file_path = temp_file.name
|
60 |
return temp_file_path
|
61 |
else:
|
62 |
+
self.logger.error(f"Failed to download PDF from URL: {pdf_url}")
|
63 |
return None
|
64 |
|
65 |
def read_pdf(self, temp_file_path: str):
|
|
|
97 |
if response.status_code == 200:
|
98 |
return [Document(page_content=response.text)]
|
99 |
else:
|
100 |
+
self.logger.error(f"Failed to fetch .tex file from URL: {tex_url}")
|
101 |
return None
|
102 |
|
103 |
|
104 |
class ChunkProcessor:
|
105 |
+
def __init__(self, config, logger):
|
106 |
self.config = config
|
107 |
+
self.logger = logger
|
108 |
+
|
109 |
+
self.document_data = {}
|
110 |
+
self.document_metadata = {}
|
111 |
+
self.document_chunks_full = []
|
112 |
|
113 |
if config["splitter_options"]["use_splitter"]:
|
114 |
if config["splitter_options"]["split_by_token"]:
|
|
|
127 |
)
|
128 |
else:
|
129 |
self.splitter = None
|
130 |
+
self.logger.info("ChunkProcessor instance created")
|
131 |
|
132 |
def remove_delimiters(self, document_chunks: list):
|
133 |
for chunk in document_chunks:
|
|
|
142 |
del document_chunks[0]
|
143 |
for _ in range(end):
|
144 |
document_chunks.pop()
|
|
|
145 |
return document_chunks
|
146 |
|
147 |
def process_chunks(
|
|
|
174 |
|
175 |
return document_chunks
|
176 |
|
177 |
+
def chunk_docs(self, file_reader, uploaded_files, weblinks):
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
addl_metadata = get_metadata(
|
179 |
"https://dl4ds.github.io/sp2024/lectures/",
|
180 |
"https://dl4ds.github.io/sp2024/schedule/",
|
181 |
) # For any additional metadata
|
182 |
|
183 |
+
with ThreadPoolExecutor() as executor:
|
184 |
+
executor.map(
|
185 |
+
self.process_file,
|
186 |
+
uploaded_files,
|
187 |
+
range(len(uploaded_files)),
|
188 |
+
[file_reader] * len(uploaded_files),
|
189 |
+
[addl_metadata] * len(uploaded_files),
|
190 |
+
)
|
191 |
+
executor.map(
|
192 |
+
self.process_weblink,
|
193 |
+
weblinks,
|
194 |
+
range(len(weblinks)),
|
195 |
+
[file_reader] * len(weblinks),
|
196 |
+
[addl_metadata] * len(weblinks),
|
197 |
+
)
|
198 |
+
|
199 |
+
document_names = [
|
200 |
+
f"{file_name}_{page_num}"
|
201 |
+
for file_name, pages in self.document_data.items()
|
202 |
+
for page_num in pages.keys()
|
203 |
+
]
|
204 |
+
documents = [
|
205 |
+
page for doc in self.document_data.values() for page in doc.values()
|
206 |
+
]
|
207 |
+
document_metadata = [
|
208 |
+
page for doc in self.document_metadata.values() for page in doc.values()
|
209 |
+
]
|
210 |
+
|
211 |
+
self.save_document_data()
|
212 |
+
|
213 |
+
self.logger.info(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
214 |
f"Total document chunks extracted: {len(self.document_chunks_full)}"
|
215 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
|
217 |
+
return self.document_chunks_full, document_names, documents, document_metadata
|
218 |
+
|
219 |
+
def process_documents(
|
220 |
+
self, documents, file_path, file_type, metadata_source, addl_metadata
|
221 |
+
):
|
222 |
+
file_data = {}
|
223 |
+
file_metadata = {}
|
224 |
+
|
225 |
+
for doc in documents:
|
226 |
+
# if len(doc.page_content) <= 400: # better approach to filter out non-informative documents
|
227 |
+
# continue
|
228 |
+
|
229 |
+
page_num = doc.metadata.get("page", 0)
|
230 |
+
file_data[page_num] = doc.page_content
|
231 |
+
metadata = (
|
232 |
+
addl_metadata.get(file_path, {})
|
233 |
+
if metadata_source == "file"
|
234 |
+
else {"source": file_path, "page": page_num}
|
235 |
+
)
|
236 |
+
file_metadata[page_num] = metadata
|
237 |
+
|
238 |
+
if self.config["vectorstore"]["db_option"] not in ["RAGatouille"]:
|
239 |
+
document_chunks = self.process_chunks(
|
240 |
+
doc.page_content,
|
241 |
+
file_type,
|
242 |
+
source=file_path,
|
243 |
+
page=page_num,
|
244 |
+
metadata=metadata,
|
245 |
+
)
|
246 |
+
self.document_chunks_full.extend(document_chunks)
|
247 |
+
|
248 |
+
self.document_data[file_path] = file_data
|
249 |
+
self.document_metadata[file_path] = file_metadata
|
250 |
+
|
251 |
+
def process_file(self, file_path, file_index, file_reader, addl_metadata):
|
252 |
+
file_name = os.path.basename(file_path)
|
253 |
+
if file_name in self.document_data:
|
254 |
+
return
|
255 |
+
|
256 |
+
file_type = file_name.split(".")[-1].lower()
|
257 |
+
self.logger.info(f"Reading file {file_index + 1}: {file_path}")
|
258 |
+
|
259 |
+
read_methods = {
|
260 |
+
"pdf": file_reader.read_pdf,
|
261 |
+
"txt": file_reader.read_txt,
|
262 |
+
"docx": file_reader.read_docx,
|
263 |
+
"srt": file_reader.read_srt,
|
264 |
+
"tex": file_reader.read_tex_from_url,
|
265 |
+
}
|
266 |
+
if file_type not in read_methods:
|
267 |
+
self.logger.warning(f"Unsupported file type: {file_type}")
|
268 |
+
return
|
269 |
+
|
270 |
+
try:
|
271 |
+
documents = read_methods[file_type](file_path)
|
272 |
+
self.process_documents(
|
273 |
+
documents, file_path, file_type, "file", addl_metadata
|
274 |
+
)
|
275 |
+
except Exception as e:
|
276 |
+
self.logger.error(f"Error processing file {file_name}: {str(e)}")
|
277 |
+
|
278 |
+
def process_weblink(self, link, link_index, file_reader, addl_metadata):
|
279 |
+
if link in self.document_data:
|
280 |
+
return
|
281 |
+
|
282 |
+
self.logger.info(f"Reading link {link_index + 1} : {link}")
|
283 |
+
|
284 |
+
try:
|
285 |
+
if "youtube" in link:
|
286 |
+
documents = file_reader.read_youtube_transcript(link)
|
287 |
+
else:
|
288 |
+
documents = file_reader.read_html(link)
|
289 |
+
|
290 |
+
self.process_documents(documents, link, "txt", "link", addl_metadata)
|
291 |
+
except Exception as e:
|
292 |
+
self.logger.error(f"Error Reading link {link_index + 1} : {link}: {str(e)}")
|
293 |
+
|
294 |
+
def save_document_data(self):
|
295 |
+
if not os.path.exists(f"{self.config['log_chunk_dir']}/docs"):
|
296 |
+
os.makedirs(f"{self.config['log_chunk_dir']}/docs")
|
297 |
+
self.logger.info(
|
298 |
+
f"Creating directory {self.config['log_chunk_dir']}/docs for document data"
|
299 |
+
)
|
300 |
+
self.logger.info(
|
301 |
+
f"Saving document content to {self.config['log_chunk_dir']}/docs/doc_content.json"
|
302 |
+
)
|
303 |
+
if not os.path.exists(f"{self.config['log_chunk_dir']}/metadata"):
|
304 |
+
os.makedirs(f"{self.config['log_chunk_dir']}/metadata")
|
305 |
+
self.logger.info(
|
306 |
+
f"Creating directory {self.config['log_chunk_dir']}/metadata for document metadata"
|
307 |
+
)
|
308 |
+
self.logger.info(
|
309 |
+
f"Saving document metadata to {self.config['log_chunk_dir']}/metadata/doc_metadata.json"
|
310 |
+
)
|
311 |
+
with open(
|
312 |
+
f"{self.config['log_chunk_dir']}/docs/doc_content.json", "w"
|
313 |
+
) as json_file:
|
314 |
+
json.dump(self.document_data, json_file, indent=4)
|
315 |
+
with open(
|
316 |
+
f"{self.config['log_chunk_dir']}/metadata/doc_metadata.json", "w"
|
317 |
+
) as json_file:
|
318 |
+
json.dump(self.document_metadata, json_file, indent=4)
|
319 |
+
|
320 |
+
def load_document_data(self):
|
321 |
+
with open(
|
322 |
+
f"{self.config['log_chunk_dir']}/docs/doc_content.json", "r"
|
323 |
+
) as json_file:
|
324 |
+
self.document_data = json.load(json_file)
|
325 |
+
with open(
|
326 |
+
f"{self.config['log_chunk_dir']}/metadata/doc_metadata.json", "r"
|
327 |
+
) as json_file:
|
328 |
+
self.document_metadata = json.load(json_file)
|
329 |
|
330 |
|
331 |
class DataLoader:
|
332 |
+
def __init__(self, config, logger=None):
|
333 |
+
self.file_reader = FileReader(logger=logger)
|
334 |
+
self.chunk_processor = ChunkProcessor(config, logger=logger)
|
335 |
|
336 |
def get_chunks(self, uploaded_files, weblinks):
|
337 |
+
return self.chunk_processor.chunk_docs(
|
338 |
self.file_reader, uploaded_files, weblinks
|
339 |
)
|
340 |
+
|
341 |
+
|
342 |
+
if __name__ == "__main__":
|
343 |
+
import yaml
|
344 |
+
|
345 |
+
logger = logging.getLogger(__name__)
|
346 |
+
logger.setLevel(logging.INFO)
|
347 |
+
|
348 |
+
with open("../code/modules/config/config.yml", "r") as f:
|
349 |
+
config = yaml.safe_load(f)
|
350 |
+
|
351 |
+
data_loader = DataLoader(config, logger=logger)
|
352 |
+
document_chunks, document_names, documents, document_metadata = (
|
353 |
+
data_loader.get_chunks(
|
354 |
+
[],
|
355 |
+
["https://dl4ds.github.io/sp2024/"],
|
356 |
+
)
|
357 |
+
)
|
358 |
+
|
359 |
+
print(document_names)
|
360 |
+
print(len(document_chunks))
|
code/modules/dataloader/helpers.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
from bs4 import BeautifulSoup
|
3 |
+
from tqdm import tqdm
|
4 |
+
|
5 |
+
|
6 |
+
def get_urls_from_file(file_path: str):
|
7 |
+
"""
|
8 |
+
Function to get urls from a file
|
9 |
+
"""
|
10 |
+
with open(file_path, "r") as f:
|
11 |
+
urls = f.readlines()
|
12 |
+
urls = [url.strip() for url in urls]
|
13 |
+
return urls
|
14 |
+
|
15 |
+
|
16 |
+
def get_base_url(url):
|
17 |
+
parsed_url = urlparse(url)
|
18 |
+
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}/"
|
19 |
+
return base_url
|
20 |
+
|
21 |
+
|
22 |
+
def get_metadata(lectures_url, schedule_url):
|
23 |
+
"""
|
24 |
+
Function to get the lecture metadata from the lectures and schedule URLs.
|
25 |
+
"""
|
26 |
+
lecture_metadata = {}
|
27 |
+
|
28 |
+
# Get the main lectures page content
|
29 |
+
r_lectures = requests.get(lectures_url)
|
30 |
+
soup_lectures = BeautifulSoup(r_lectures.text, "html.parser")
|
31 |
+
|
32 |
+
# Get the main schedule page content
|
33 |
+
r_schedule = requests.get(schedule_url)
|
34 |
+
soup_schedule = BeautifulSoup(r_schedule.text, "html.parser")
|
35 |
+
|
36 |
+
# Find all lecture blocks
|
37 |
+
lecture_blocks = soup_lectures.find_all("div", class_="lecture-container")
|
38 |
+
|
39 |
+
# Create a mapping from slides link to date
|
40 |
+
date_mapping = {}
|
41 |
+
schedule_rows = soup_schedule.find_all("li", class_="table-row-lecture")
|
42 |
+
for row in schedule_rows:
|
43 |
+
try:
|
44 |
+
date = (
|
45 |
+
row.find("div", {"data-label": "Date"}).get_text(separator=" ").strip()
|
46 |
+
)
|
47 |
+
description_div = row.find("div", {"data-label": "Description"})
|
48 |
+
slides_link_tag = description_div.find("a", title="Download slides")
|
49 |
+
slides_link = slides_link_tag["href"].strip() if slides_link_tag else None
|
50 |
+
slides_link = (
|
51 |
+
f"https://dl4ds.github.io{slides_link}" if slides_link else None
|
52 |
+
)
|
53 |
+
if slides_link:
|
54 |
+
date_mapping[slides_link] = date
|
55 |
+
except Exception as e:
|
56 |
+
print(f"Error processing schedule row: {e}")
|
57 |
+
continue
|
58 |
+
|
59 |
+
for block in lecture_blocks:
|
60 |
+
try:
|
61 |
+
# Extract the lecture title
|
62 |
+
title = block.find("span", style="font-weight: bold;").text.strip()
|
63 |
+
|
64 |
+
# Extract the TL;DR
|
65 |
+
tldr = block.find("strong", text="tl;dr:").next_sibling.strip()
|
66 |
+
|
67 |
+
# Extract the link to the slides
|
68 |
+
slides_link_tag = block.find("a", title="Download slides")
|
69 |
+
slides_link = slides_link_tag["href"].strip() if slides_link_tag else None
|
70 |
+
slides_link = (
|
71 |
+
f"https://dl4ds.github.io{slides_link}" if slides_link else None
|
72 |
+
)
|
73 |
+
|
74 |
+
# Extract the link to the lecture recording
|
75 |
+
recording_link_tag = block.find("a", title="Download lecture recording")
|
76 |
+
recording_link = (
|
77 |
+
recording_link_tag["href"].strip() if recording_link_tag else None
|
78 |
+
)
|
79 |
+
|
80 |
+
# Extract suggested readings or summary if available
|
81 |
+
suggested_readings_tag = block.find("p", text="Suggested Readings:")
|
82 |
+
if suggested_readings_tag:
|
83 |
+
suggested_readings = suggested_readings_tag.find_next_sibling("ul")
|
84 |
+
if suggested_readings:
|
85 |
+
suggested_readings = suggested_readings.get_text(
|
86 |
+
separator="\n"
|
87 |
+
).strip()
|
88 |
+
else:
|
89 |
+
suggested_readings = "No specific readings provided."
|
90 |
+
else:
|
91 |
+
suggested_readings = "No specific readings provided."
|
92 |
+
|
93 |
+
# Get the date from the schedule
|
94 |
+
date = date_mapping.get(slides_link, "No date available")
|
95 |
+
|
96 |
+
# Add to the dictionary
|
97 |
+
lecture_metadata[slides_link] = {
|
98 |
+
"date": date,
|
99 |
+
"tldr": tldr,
|
100 |
+
"title": title,
|
101 |
+
"lecture_recording": recording_link,
|
102 |
+
"suggested_readings": suggested_readings,
|
103 |
+
}
|
104 |
+
except Exception as e:
|
105 |
+
print(f"Error processing block: {e}")
|
106 |
+
continue
|
107 |
+
|
108 |
+
return lecture_metadata
|
code/modules/{helpers.py β dataloader/webpage_crawler.py}
RENAMED
@@ -1,25 +1,9 @@
|
|
1 |
-
import
|
2 |
-
from
|
3 |
-
|
4 |
-
import chainlit as cl
|
5 |
-
from langchain import PromptTemplate
|
6 |
import requests
|
7 |
from bs4 import BeautifulSoup
|
8 |
from urllib.parse import urlparse, urljoin, urldefrag
|
9 |
-
import asyncio
|
10 |
-
import aiohttp
|
11 |
-
from aiohttp import ClientSession
|
12 |
-
from typing import Dict, Any, List
|
13 |
-
|
14 |
-
try:
|
15 |
-
from modules.constants import *
|
16 |
-
except:
|
17 |
-
from constants import *
|
18 |
-
|
19 |
-
"""
|
20 |
-
Ref: https://python.plainenglish.io/scraping-the-subpages-on-a-website-ea2d4e3db113
|
21 |
-
"""
|
22 |
-
|
23 |
|
24 |
class WebpageCrawler:
|
25 |
def __init__(self):
|
@@ -129,209 +113,3 @@ class WebpageCrawler:
|
|
129 |
# Strip the fragment identifier
|
130 |
defragged_url, _ = urldefrag(url)
|
131 |
return defragged_url
|
132 |
-
|
133 |
-
|
134 |
-
def get_urls_from_file(file_path: str):
|
135 |
-
"""
|
136 |
-
Function to get urls from a file
|
137 |
-
"""
|
138 |
-
with open(file_path, "r") as f:
|
139 |
-
urls = f.readlines()
|
140 |
-
urls = [url.strip() for url in urls]
|
141 |
-
return urls
|
142 |
-
|
143 |
-
|
144 |
-
def get_base_url(url):
|
145 |
-
parsed_url = urlparse(url)
|
146 |
-
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}/"
|
147 |
-
return base_url
|
148 |
-
|
149 |
-
|
150 |
-
def get_prompt(config):
|
151 |
-
if config["llm_params"]["use_history"]:
|
152 |
-
if config["llm_params"]["llm_loader"] == "local_llm":
|
153 |
-
custom_prompt_template = tinyllama_prompt_template_with_history
|
154 |
-
elif config["llm_params"]["llm_loader"] == "openai":
|
155 |
-
custom_prompt_template = openai_prompt_template_with_history
|
156 |
-
# else:
|
157 |
-
# custom_prompt_template = tinyllama_prompt_template_with_history # default
|
158 |
-
prompt = PromptTemplate(
|
159 |
-
template=custom_prompt_template,
|
160 |
-
input_variables=["context", "chat_history", "question"],
|
161 |
-
)
|
162 |
-
else:
|
163 |
-
if config["llm_params"]["llm_loader"] == "local_llm":
|
164 |
-
custom_prompt_template = tinyllama_prompt_template
|
165 |
-
elif config["llm_params"]["llm_loader"] == "openai":
|
166 |
-
custom_prompt_template = openai_prompt_template
|
167 |
-
# else:
|
168 |
-
# custom_prompt_template = tinyllama_prompt_template
|
169 |
-
prompt = PromptTemplate(
|
170 |
-
template=custom_prompt_template,
|
171 |
-
input_variables=["context", "question"],
|
172 |
-
)
|
173 |
-
return prompt
|
174 |
-
|
175 |
-
|
176 |
-
def get_sources(res, answer):
|
177 |
-
source_elements = []
|
178 |
-
source_dict = {} # Dictionary to store URL elements
|
179 |
-
|
180 |
-
for idx, source in enumerate(res["source_documents"]):
|
181 |
-
source_metadata = source.metadata
|
182 |
-
url = source_metadata["source"]
|
183 |
-
score = source_metadata.get("score", "N/A")
|
184 |
-
page = source_metadata.get("page", 1)
|
185 |
-
|
186 |
-
lecture_tldr = source_metadata.get("tldr", "N/A")
|
187 |
-
lecture_recording = source_metadata.get("lecture_recording", "N/A")
|
188 |
-
suggested_readings = source_metadata.get("suggested_readings", "N/A")
|
189 |
-
date = source_metadata.get("date", "N/A")
|
190 |
-
|
191 |
-
source_type = source_metadata.get("source_type", "N/A")
|
192 |
-
|
193 |
-
url_name = f"{url}_{page}"
|
194 |
-
if url_name not in source_dict:
|
195 |
-
source_dict[url_name] = {
|
196 |
-
"text": source.page_content,
|
197 |
-
"url": url,
|
198 |
-
"score": score,
|
199 |
-
"page": page,
|
200 |
-
"lecture_tldr": lecture_tldr,
|
201 |
-
"lecture_recording": lecture_recording,
|
202 |
-
"suggested_readings": suggested_readings,
|
203 |
-
"date": date,
|
204 |
-
"source_type": source_type,
|
205 |
-
}
|
206 |
-
else:
|
207 |
-
source_dict[url_name]["text"] += f"\n\n{source.page_content}"
|
208 |
-
|
209 |
-
# First, display the answer
|
210 |
-
full_answer = "**Answer:**\n"
|
211 |
-
full_answer += answer
|
212 |
-
|
213 |
-
# Then, display the sources
|
214 |
-
full_answer += "\n\n**Sources:**\n"
|
215 |
-
for idx, (url_name, source_data) in enumerate(source_dict.items()):
|
216 |
-
full_answer += f"\nSource {idx + 1} (Score: {source_data['score']}): {source_data['url']}\n"
|
217 |
-
|
218 |
-
name = f"Source {idx + 1} Text\n"
|
219 |
-
full_answer += name
|
220 |
-
source_elements.append(
|
221 |
-
cl.Text(name=name, content=source_data["text"], display="side")
|
222 |
-
)
|
223 |
-
|
224 |
-
# Add a PDF element if the source is a PDF file
|
225 |
-
if source_data["url"].lower().endswith(".pdf"):
|
226 |
-
name = f"Source {idx + 1} PDF\n"
|
227 |
-
full_answer += name
|
228 |
-
pdf_url = f"{source_data['url']}#page={source_data['page']+1}"
|
229 |
-
source_elements.append(cl.Pdf(name=name, url=pdf_url, display="side"))
|
230 |
-
|
231 |
-
full_answer += "\n**Metadata:**\n"
|
232 |
-
for idx, (url_name, source_data) in enumerate(source_dict.items()):
|
233 |
-
full_answer += f"\nSource {idx + 1} Metadata:\n"
|
234 |
-
source_elements.append(
|
235 |
-
cl.Text(
|
236 |
-
name=f"Source {idx + 1} Metadata",
|
237 |
-
content=f"Source: {source_data['url']}\n"
|
238 |
-
f"Page: {source_data['page']}\n"
|
239 |
-
f"Type: {source_data['source_type']}\n"
|
240 |
-
f"Date: {source_data['date']}\n"
|
241 |
-
f"TL;DR: {source_data['lecture_tldr']}\n"
|
242 |
-
f"Lecture Recording: {source_data['lecture_recording']}\n"
|
243 |
-
f"Suggested Readings: {source_data['suggested_readings']}\n",
|
244 |
-
display="side",
|
245 |
-
)
|
246 |
-
)
|
247 |
-
|
248 |
-
return full_answer, source_elements
|
249 |
-
|
250 |
-
|
251 |
-
def get_metadata(lectures_url, schedule_url):
|
252 |
-
"""
|
253 |
-
Function to get the lecture metadata from the lectures and schedule URLs.
|
254 |
-
"""
|
255 |
-
lecture_metadata = {}
|
256 |
-
|
257 |
-
# Get the main lectures page content
|
258 |
-
r_lectures = requests.get(lectures_url)
|
259 |
-
soup_lectures = BeautifulSoup(r_lectures.text, "html.parser")
|
260 |
-
|
261 |
-
# Get the main schedule page content
|
262 |
-
r_schedule = requests.get(schedule_url)
|
263 |
-
soup_schedule = BeautifulSoup(r_schedule.text, "html.parser")
|
264 |
-
|
265 |
-
# Find all lecture blocks
|
266 |
-
lecture_blocks = soup_lectures.find_all("div", class_="lecture-container")
|
267 |
-
|
268 |
-
# Create a mapping from slides link to date
|
269 |
-
date_mapping = {}
|
270 |
-
schedule_rows = soup_schedule.find_all("li", class_="table-row-lecture")
|
271 |
-
for row in schedule_rows:
|
272 |
-
try:
|
273 |
-
date = (
|
274 |
-
row.find("div", {"data-label": "Date"}).get_text(separator=" ").strip()
|
275 |
-
)
|
276 |
-
description_div = row.find("div", {"data-label": "Description"})
|
277 |
-
slides_link_tag = description_div.find("a", title="Download slides")
|
278 |
-
slides_link = slides_link_tag["href"].strip() if slides_link_tag else None
|
279 |
-
slides_link = (
|
280 |
-
f"https://dl4ds.github.io{slides_link}" if slides_link else None
|
281 |
-
)
|
282 |
-
if slides_link:
|
283 |
-
date_mapping[slides_link] = date
|
284 |
-
except Exception as e:
|
285 |
-
print(f"Error processing schedule row: {e}")
|
286 |
-
continue
|
287 |
-
|
288 |
-
for block in lecture_blocks:
|
289 |
-
try:
|
290 |
-
# Extract the lecture title
|
291 |
-
title = block.find("span", style="font-weight: bold;").text.strip()
|
292 |
-
|
293 |
-
# Extract the TL;DR
|
294 |
-
tldr = block.find("strong", text="tl;dr:").next_sibling.strip()
|
295 |
-
|
296 |
-
# Extract the link to the slides
|
297 |
-
slides_link_tag = block.find("a", title="Download slides")
|
298 |
-
slides_link = slides_link_tag["href"].strip() if slides_link_tag else None
|
299 |
-
slides_link = (
|
300 |
-
f"https://dl4ds.github.io{slides_link}" if slides_link else None
|
301 |
-
)
|
302 |
-
|
303 |
-
# Extract the link to the lecture recording
|
304 |
-
recording_link_tag = block.find("a", title="Download lecture recording")
|
305 |
-
recording_link = (
|
306 |
-
recording_link_tag["href"].strip() if recording_link_tag else None
|
307 |
-
)
|
308 |
-
|
309 |
-
# Extract suggested readings or summary if available
|
310 |
-
suggested_readings_tag = block.find("p", text="Suggested Readings:")
|
311 |
-
if suggested_readings_tag:
|
312 |
-
suggested_readings = suggested_readings_tag.find_next_sibling("ul")
|
313 |
-
if suggested_readings:
|
314 |
-
suggested_readings = suggested_readings.get_text(
|
315 |
-
separator="\n"
|
316 |
-
).strip()
|
317 |
-
else:
|
318 |
-
suggested_readings = "No specific readings provided."
|
319 |
-
else:
|
320 |
-
suggested_readings = "No specific readings provided."
|
321 |
-
|
322 |
-
# Get the date from the schedule
|
323 |
-
date = date_mapping.get(slides_link, "No date available")
|
324 |
-
|
325 |
-
# Add to the dictionary
|
326 |
-
lecture_metadata[slides_link] = {
|
327 |
-
"date": date,
|
328 |
-
"tldr": tldr,
|
329 |
-
"title": title,
|
330 |
-
"lecture_recording": recording_link,
|
331 |
-
"suggested_readings": suggested_readings,
|
332 |
-
}
|
333 |
-
except Exception as e:
|
334 |
-
print(f"Error processing block: {e}")
|
335 |
-
continue
|
336 |
-
|
337 |
-
return lecture_metadata
|
|
|
1 |
+
import aiohttp
|
2 |
+
from aiohttp import ClientSession
|
3 |
+
import asyncio
|
|
|
|
|
4 |
import requests
|
5 |
from bs4 import BeautifulSoup
|
6 |
from urllib.parse import urlparse, urljoin, urldefrag
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
class WebpageCrawler:
|
9 |
def __init__(self):
|
|
|
113 |
# Strip the fragment identifier
|
114 |
defragged_url, _ = urldefrag(url)
|
115 |
return defragged_url
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
code/modules/retriever/__init__.py
ADDED
File without changes
|
code/modules/retriever/base.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# template for retriever classes
|
2 |
+
|
3 |
+
|
4 |
+
class BaseRetriever:
|
5 |
+
def __init__(self, config):
|
6 |
+
self.config = config
|
7 |
+
|
8 |
+
def return_retriever(self):
|
9 |
+
"""
|
10 |
+
Returns the retriever object
|
11 |
+
"""
|
12 |
+
raise NotImplementedError
|
code/modules/retriever/chroma_retriever.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .helpers import VectorStoreRetrieverScore
|
2 |
+
from .base import BaseRetriever
|
3 |
+
|
4 |
+
|
5 |
+
class ChromaRetriever(BaseRetriever):
|
6 |
+
def __init__(self):
|
7 |
+
pass
|
8 |
+
|
9 |
+
def return_retriever(self, db, config):
|
10 |
+
retriever = VectorStoreRetrieverScore(
|
11 |
+
vectorstore=db,
|
12 |
+
# search_type="similarity_score_threshold",
|
13 |
+
# search_kwargs={
|
14 |
+
# "score_threshold": self.config["vectorstore"][
|
15 |
+
# "score_threshold"
|
16 |
+
# ],
|
17 |
+
# "k": self.config["vectorstore"]["search_top_k"],
|
18 |
+
# },
|
19 |
+
search_kwargs={
|
20 |
+
"k": config["vectorstore"]["search_top_k"],
|
21 |
+
},
|
22 |
+
)
|
23 |
+
|
24 |
+
return retriever
|
code/modules/retriever/colbert_retriever.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .base import BaseRetriever
|
2 |
+
|
3 |
+
|
4 |
+
class ColbertRetriever(BaseRetriever):
|
5 |
+
def __init__(self):
|
6 |
+
pass
|
7 |
+
|
8 |
+
def return_retriever(self, db, config):
|
9 |
+
retriever = db.as_langchain_retriever(k=config["vectorstore"]["search_top_k"])
|
10 |
+
return retriever
|
code/modules/retriever/faiss_retriever.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .helpers import VectorStoreRetrieverScore
|
2 |
+
from .base import BaseRetriever
|
3 |
+
|
4 |
+
|
5 |
+
class FaissRetriever(BaseRetriever):
|
6 |
+
def __init__(self):
|
7 |
+
pass
|
8 |
+
|
9 |
+
def return_retriever(self, db, config):
|
10 |
+
retriever = VectorStoreRetrieverScore(
|
11 |
+
vectorstore=db,
|
12 |
+
# search_type="similarity_score_threshold",
|
13 |
+
# search_kwargs={
|
14 |
+
# "score_threshold": self.config["vectorstore"][
|
15 |
+
# "score_threshold"
|
16 |
+
# ],
|
17 |
+
# "k": self.config["vectorstore"]["search_top_k"],
|
18 |
+
# },
|
19 |
+
search_kwargs={
|
20 |
+
"k": config["vectorstore"]["search_top_k"],
|
21 |
+
},
|
22 |
+
)
|
23 |
+
return retriever
|
code/modules/retriever/helpers.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.schema.vectorstore import VectorStoreRetriever
|
2 |
+
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
3 |
+
from langchain.schema.document import Document
|
4 |
+
from langchain_core.callbacks import AsyncCallbackManagerForRetrieverRun
|
5 |
+
from typing import List
|
6 |
+
|
7 |
+
|
8 |
+
class VectorStoreRetrieverScore(VectorStoreRetriever):
|
9 |
+
|
10 |
+
# See https://github.com/langchain-ai/langchain/blob/61dd92f8215daef3d9cf1734b0d1f8c70c1571c3/libs/langchain/langchain/vectorstores/base.py#L500
|
11 |
+
def _get_relevant_documents(
|
12 |
+
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
13 |
+
) -> List[Document]:
|
14 |
+
docs_and_similarities = (
|
15 |
+
self.vectorstore.similarity_search_with_relevance_scores(
|
16 |
+
query, **self.search_kwargs
|
17 |
+
)
|
18 |
+
)
|
19 |
+
# Make the score part of the document metadata
|
20 |
+
for doc, similarity in docs_and_similarities:
|
21 |
+
doc.metadata["score"] = similarity
|
22 |
+
|
23 |
+
docs = [doc for doc, _ in docs_and_similarities]
|
24 |
+
return docs
|
25 |
+
|
26 |
+
async def _aget_relevant_documents(
|
27 |
+
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
28 |
+
) -> List[Document]:
|
29 |
+
docs_and_similarities = (
|
30 |
+
self.vectorstore.similarity_search_with_relevance_scores(
|
31 |
+
query, **self.search_kwargs
|
32 |
+
)
|
33 |
+
)
|
34 |
+
# Make the score part of the document metadata
|
35 |
+
for doc, similarity in docs_and_similarities:
|
36 |
+
doc.metadata["score"] = similarity
|
37 |
+
|
38 |
+
docs = [doc for doc, _ in docs_and_similarities]
|
39 |
+
return docs
|
code/modules/retriever/raptor_retriever.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .helpers import VectorStoreRetrieverScore
|
2 |
+
from .base import BaseRetriever
|
3 |
+
|
4 |
+
|
5 |
+
class RaptorRetriever(BaseRetriever):
|
6 |
+
def __init__(self):
|
7 |
+
pass
|
8 |
+
|
9 |
+
def return_retriever(self, db, config):
|
10 |
+
retriever = VectorStoreRetrieverScore(
|
11 |
+
vectorstore=db,
|
12 |
+
search_kwargs={
|
13 |
+
"k": config["vectorstore"]["search_top_k"],
|
14 |
+
},
|
15 |
+
)
|
16 |
+
return retriever
|
code/modules/retriever/retriever.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules.retriever.faiss_retriever import FaissRetriever
|
2 |
+
from modules.retriever.chroma_retriever import ChromaRetriever
|
3 |
+
from modules.retriever.colbert_retriever import ColbertRetriever
|
4 |
+
from modules.retriever.raptor_retriever import RaptorRetriever
|
5 |
+
|
6 |
+
|
7 |
+
class Retriever:
|
8 |
+
def __init__(self, config):
|
9 |
+
self.config = config
|
10 |
+
self.retriever_classes = {
|
11 |
+
"FAISS": FaissRetriever,
|
12 |
+
"Chroma": ChromaRetriever,
|
13 |
+
"RAGatouille": ColbertRetriever,
|
14 |
+
"RAPTOR": RaptorRetriever,
|
15 |
+
}
|
16 |
+
self._create_retriever()
|
17 |
+
|
18 |
+
def _create_retriever(self):
|
19 |
+
db_option = self.config["vectorstore"]["db_option"]
|
20 |
+
retriever_class = self.retriever_classes.get(db_option)
|
21 |
+
if not retriever_class:
|
22 |
+
raise ValueError(f"Invalid db_option: {db_option}")
|
23 |
+
self.retriever = retriever_class()
|
24 |
+
|
25 |
+
def _return_retriever(self, db):
|
26 |
+
return self.retriever.return_retriever(db, self.config)
|
code/modules/vector_db.py
DELETED
@@ -1,226 +0,0 @@
|
|
1 |
-
import logging
|
2 |
-
import os
|
3 |
-
import yaml
|
4 |
-
from langchain_community.vectorstores import FAISS, Chroma
|
5 |
-
from langchain.schema.vectorstore import VectorStoreRetriever
|
6 |
-
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
7 |
-
from langchain.schema.document import Document
|
8 |
-
from langchain_core.callbacks import AsyncCallbackManagerForRetrieverRun
|
9 |
-
from ragatouille import RAGPretrainedModel
|
10 |
-
|
11 |
-
try:
|
12 |
-
from modules.embedding_model_loader import EmbeddingModelLoader
|
13 |
-
from modules.data_loader import DataLoader
|
14 |
-
from modules.constants import *
|
15 |
-
from modules.helpers import *
|
16 |
-
except:
|
17 |
-
from embedding_model_loader import EmbeddingModelLoader
|
18 |
-
from data_loader import DataLoader
|
19 |
-
from constants import *
|
20 |
-
from helpers import *
|
21 |
-
|
22 |
-
from typing import List
|
23 |
-
|
24 |
-
|
25 |
-
class VectorDBScore(VectorStoreRetriever):
|
26 |
-
|
27 |
-
# See https://github.com/langchain-ai/langchain/blob/61dd92f8215daef3d9cf1734b0d1f8c70c1571c3/libs/langchain/langchain/vectorstores/base.py#L500
|
28 |
-
def _get_relevant_documents(
|
29 |
-
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
30 |
-
) -> List[Document]:
|
31 |
-
docs_and_similarities = (
|
32 |
-
self.vectorstore.similarity_search_with_relevance_scores(
|
33 |
-
query, **self.search_kwargs
|
34 |
-
)
|
35 |
-
)
|
36 |
-
# Make the score part of the document metadata
|
37 |
-
for doc, similarity in docs_and_similarities:
|
38 |
-
doc.metadata["score"] = similarity
|
39 |
-
|
40 |
-
docs = [doc for doc, _ in docs_and_similarities]
|
41 |
-
return docs
|
42 |
-
|
43 |
-
async def _aget_relevant_documents(
|
44 |
-
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
45 |
-
) -> List[Document]:
|
46 |
-
docs_and_similarities = (
|
47 |
-
self.vectorstore.similarity_search_with_relevance_scores(
|
48 |
-
query, **self.search_kwargs
|
49 |
-
)
|
50 |
-
)
|
51 |
-
# Make the score part of the document metadata
|
52 |
-
for doc, similarity in docs_and_similarities:
|
53 |
-
doc.metadata["score"] = similarity
|
54 |
-
|
55 |
-
docs = [doc for doc, _ in docs_and_similarities]
|
56 |
-
return docs
|
57 |
-
|
58 |
-
|
59 |
-
class VectorDB:
|
60 |
-
def __init__(self, config, logger=None):
|
61 |
-
self.config = config
|
62 |
-
self.db_option = config["embedding_options"]["db_option"]
|
63 |
-
self.document_names = None
|
64 |
-
self.webpage_crawler = WebpageCrawler()
|
65 |
-
|
66 |
-
# Set up logging to both console and a file
|
67 |
-
if logger is None:
|
68 |
-
self.logger = logging.getLogger(__name__)
|
69 |
-
self.logger.setLevel(logging.INFO)
|
70 |
-
|
71 |
-
# Console Handler
|
72 |
-
console_handler = logging.StreamHandler()
|
73 |
-
console_handler.setLevel(logging.INFO)
|
74 |
-
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
75 |
-
console_handler.setFormatter(formatter)
|
76 |
-
self.logger.addHandler(console_handler)
|
77 |
-
|
78 |
-
# File Handler
|
79 |
-
log_file_path = "vector_db.log" # Change this to your desired log file path
|
80 |
-
file_handler = logging.FileHandler(log_file_path, mode="w")
|
81 |
-
file_handler.setLevel(logging.INFO)
|
82 |
-
file_handler.setFormatter(formatter)
|
83 |
-
self.logger.addHandler(file_handler)
|
84 |
-
else:
|
85 |
-
self.logger = logger
|
86 |
-
|
87 |
-
self.logger.info("VectorDB instance instantiated")
|
88 |
-
|
89 |
-
def load_files(self):
|
90 |
-
files = os.listdir(self.config["embedding_options"]["data_path"])
|
91 |
-
files = [
|
92 |
-
os.path.join(self.config["embedding_options"]["data_path"], file)
|
93 |
-
for file in files
|
94 |
-
]
|
95 |
-
urls = get_urls_from_file(self.config["embedding_options"]["url_file_path"])
|
96 |
-
if self.config["embedding_options"]["expand_urls"]:
|
97 |
-
all_urls = []
|
98 |
-
for url in urls:
|
99 |
-
loop = asyncio.get_event_loop()
|
100 |
-
all_urls.extend(
|
101 |
-
loop.run_until_complete(
|
102 |
-
self.webpage_crawler.get_all_pages(
|
103 |
-
url, url
|
104 |
-
) # only get child urls, if you want to get all urls, replace the second argument with the base url
|
105 |
-
)
|
106 |
-
)
|
107 |
-
urls = all_urls
|
108 |
-
return files, urls
|
109 |
-
|
110 |
-
def create_embedding_model(self):
|
111 |
-
self.logger.info("Creating embedding function")
|
112 |
-
self.embedding_model_loader = EmbeddingModelLoader(self.config)
|
113 |
-
self.embedding_model = self.embedding_model_loader.load_embedding_model()
|
114 |
-
|
115 |
-
def initialize_database(
|
116 |
-
self,
|
117 |
-
document_chunks: list,
|
118 |
-
document_names: list,
|
119 |
-
documents: list,
|
120 |
-
document_metadata: list,
|
121 |
-
):
|
122 |
-
if self.db_option in ["FAISS", "Chroma"]:
|
123 |
-
self.create_embedding_model()
|
124 |
-
# Track token usage
|
125 |
-
self.logger.info("Initializing vector_db")
|
126 |
-
self.logger.info("\tUsing {} as db_option".format(self.db_option))
|
127 |
-
if self.db_option == "FAISS":
|
128 |
-
self.vector_db = FAISS.from_documents(
|
129 |
-
documents=document_chunks, embedding=self.embedding_model
|
130 |
-
)
|
131 |
-
elif self.db_option == "Chroma":
|
132 |
-
self.vector_db = Chroma.from_documents(
|
133 |
-
documents=document_chunks,
|
134 |
-
embedding=self.embedding_model,
|
135 |
-
persist_directory=os.path.join(
|
136 |
-
self.config["embedding_options"]["db_path"],
|
137 |
-
"db_"
|
138 |
-
+ self.config["embedding_options"]["db_option"]
|
139 |
-
+ "_"
|
140 |
-
+ self.config["embedding_options"]["model"],
|
141 |
-
),
|
142 |
-
)
|
143 |
-
elif self.db_option == "RAGatouille":
|
144 |
-
self.RAG = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
|
145 |
-
index_path = self.RAG.index(
|
146 |
-
index_name="new_idx",
|
147 |
-
collection=documents,
|
148 |
-
document_ids=document_names,
|
149 |
-
document_metadatas=document_metadata,
|
150 |
-
)
|
151 |
-
self.logger.info("Completed initializing vector_db")
|
152 |
-
|
153 |
-
def create_database(self):
|
154 |
-
data_loader = DataLoader(self.config)
|
155 |
-
self.logger.info("Loading data")
|
156 |
-
files, urls = self.load_files()
|
157 |
-
files, webpages = self.webpage_crawler.clean_url_list(urls)
|
158 |
-
if "storage/data/urls.txt" in files:
|
159 |
-
files.remove("storage/data/urls.txt")
|
160 |
-
document_chunks, document_names, documents, document_metadata = (
|
161 |
-
data_loader.get_chunks(files, webpages)
|
162 |
-
)
|
163 |
-
self.logger.info("Completed loading data")
|
164 |
-
self.initialize_database(
|
165 |
-
document_chunks, document_names, documents, document_metadata
|
166 |
-
)
|
167 |
-
|
168 |
-
def save_database(self):
|
169 |
-
if self.db_option == "FAISS":
|
170 |
-
self.vector_db.save_local(
|
171 |
-
os.path.join(
|
172 |
-
self.config["embedding_options"]["db_path"],
|
173 |
-
"db_"
|
174 |
-
+ self.config["embedding_options"]["db_option"]
|
175 |
-
+ "_"
|
176 |
-
+ self.config["embedding_options"]["model"],
|
177 |
-
)
|
178 |
-
)
|
179 |
-
elif self.db_option == "Chroma":
|
180 |
-
# db is saved in the persist directory during initialization
|
181 |
-
pass
|
182 |
-
elif self.db_option == "RAGatouille":
|
183 |
-
# index is saved during initialization
|
184 |
-
pass
|
185 |
-
self.logger.info("Saved database")
|
186 |
-
|
187 |
-
def load_database(self):
|
188 |
-
self.create_embedding_model()
|
189 |
-
if self.db_option == "FAISS":
|
190 |
-
self.vector_db = FAISS.load_local(
|
191 |
-
os.path.join(
|
192 |
-
self.config["embedding_options"]["db_path"],
|
193 |
-
"db_"
|
194 |
-
+ self.config["embedding_options"]["db_option"]
|
195 |
-
+ "_"
|
196 |
-
+ self.config["embedding_options"]["model"],
|
197 |
-
),
|
198 |
-
self.embedding_model,
|
199 |
-
allow_dangerous_deserialization=True,
|
200 |
-
)
|
201 |
-
elif self.db_option == "Chroma":
|
202 |
-
self.vector_db = Chroma(
|
203 |
-
persist_directory=os.path.join(
|
204 |
-
self.config["embedding_options"]["db_path"],
|
205 |
-
"db_"
|
206 |
-
+ self.config["embedding_options"]["db_option"]
|
207 |
-
+ "_"
|
208 |
-
+ self.config["embedding_options"]["model"],
|
209 |
-
),
|
210 |
-
embedding_function=self.embedding_model,
|
211 |
-
)
|
212 |
-
elif self.db_option == "RAGatouille":
|
213 |
-
self.vector_db = RAGPretrainedModel.from_index(
|
214 |
-
".ragatouille/colbert/indexes/new_idx"
|
215 |
-
)
|
216 |
-
self.logger.info("Loaded database")
|
217 |
-
return self.vector_db
|
218 |
-
|
219 |
-
|
220 |
-
if __name__ == "__main__":
|
221 |
-
with open("code/config.yml", "r") as f:
|
222 |
-
config = yaml.safe_load(f)
|
223 |
-
print(config)
|
224 |
-
vector_db = VectorDB(config)
|
225 |
-
vector_db.create_database()
|
226 |
-
vector_db.save_database()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
code/modules/vectorstore/__init__.py
ADDED
File without changes
|
code/modules/vectorstore/base.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# template for vector store classes
|
2 |
+
|
3 |
+
|
4 |
+
class VectorStoreBase:
|
5 |
+
def __init__(self, config):
|
6 |
+
self.config = config
|
7 |
+
|
8 |
+
def _init_vector_db(self):
|
9 |
+
"""
|
10 |
+
Creates a vector store object
|
11 |
+
"""
|
12 |
+
raise NotImplementedError
|
13 |
+
|
14 |
+
def create_database(self):
|
15 |
+
"""
|
16 |
+
Populates the vector store with documents
|
17 |
+
"""
|
18 |
+
raise NotImplementedError
|
19 |
+
|
20 |
+
def load_database(self):
|
21 |
+
"""
|
22 |
+
Loads the vector store from disk
|
23 |
+
"""
|
24 |
+
raise NotImplementedError
|
25 |
+
|
26 |
+
def as_retriever(self):
|
27 |
+
"""
|
28 |
+
Returns the vector store as a retriever
|
29 |
+
"""
|
30 |
+
raise NotImplementedError
|
31 |
+
|
32 |
+
def __str__(self):
|
33 |
+
return self.__class__.__name__
|
code/modules/vectorstore/chroma.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_community.vectorstores import Chroma
|
2 |
+
from modules.vectorstore.base import VectorStoreBase
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
class ChromaVectorStore(VectorStoreBase):
|
7 |
+
def __init__(self, config):
|
8 |
+
self.config = config
|
9 |
+
self._init_vector_db()
|
10 |
+
|
11 |
+
def _init_vector_db(self):
|
12 |
+
self.chroma = Chroma()
|
13 |
+
|
14 |
+
def create_database(self, document_chunks, embedding_model):
|
15 |
+
self.vectorstore = self.chroma.from_documents(
|
16 |
+
documents=document_chunks,
|
17 |
+
embedding=embedding_model,
|
18 |
+
persist_directory=os.path.join(
|
19 |
+
self.config["vectorstore"]["db_path"],
|
20 |
+
"db_"
|
21 |
+
+ self.config["vectorstore"]["db_option"]
|
22 |
+
+ "_"
|
23 |
+
+ self.config["vectorstore"]["model"],
|
24 |
+
),
|
25 |
+
)
|
26 |
+
|
27 |
+
def load_database(self, embedding_model):
|
28 |
+
self.vectorstore = Chroma(
|
29 |
+
persist_directory=os.path.join(
|
30 |
+
self.config["vectorstore"]["db_path"],
|
31 |
+
"db_"
|
32 |
+
+ self.config["vectorstore"]["db_option"]
|
33 |
+
+ "_"
|
34 |
+
+ self.config["vectorstore"]["model"],
|
35 |
+
),
|
36 |
+
embedding_function=embedding_model,
|
37 |
+
)
|
38 |
+
return self.vectorstore
|
39 |
+
|
40 |
+
def as_retriever(self):
|
41 |
+
return self.vectorstore.as_retriever()
|
code/modules/vectorstore/colbert.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ragatouille import RAGPretrainedModel
|
2 |
+
from modules.vectorstore.base import VectorStoreBase
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
class ColbertVectorStore(VectorStoreBase):
|
7 |
+
def __init__(self, config):
|
8 |
+
self.config = config
|
9 |
+
self._init_vector_db()
|
10 |
+
|
11 |
+
def _init_vector_db(self):
|
12 |
+
self.colbert = RAGPretrainedModel.from_pretrained(
|
13 |
+
"colbert-ir/colbertv2.0",
|
14 |
+
index_root=os.path.join(
|
15 |
+
self.config["vectorstore"]["db_path"],
|
16 |
+
"db_" + self.config["vectorstore"]["db_option"],
|
17 |
+
),
|
18 |
+
)
|
19 |
+
|
20 |
+
def create_database(self, documents, document_names, document_metadata):
|
21 |
+
index_path = self.colbert.index(
|
22 |
+
index_name="new_idx",
|
23 |
+
collection=documents,
|
24 |
+
document_ids=document_names,
|
25 |
+
document_metadatas=document_metadata,
|
26 |
+
)
|
27 |
+
|
28 |
+
def load_database(self):
|
29 |
+
path = os.path.join(
|
30 |
+
self.config["vectorstore"]["db_path"],
|
31 |
+
"db_" + self.config["vectorstore"]["db_option"],
|
32 |
+
)
|
33 |
+
self.vectorstore = RAGPretrainedModel.from_index(
|
34 |
+
f"{path}/colbert/indexes/new_idx"
|
35 |
+
)
|
36 |
+
return self.vectorstore
|
37 |
+
|
38 |
+
def as_retriever(self):
|
39 |
+
return self.vectorstore.as_retriever()
|
code/modules/{embedding_model_loader.py β vectorstore/embedding_model_loader.py}
RENAMED
@@ -2,10 +2,7 @@ from langchain_community.embeddings import OpenAIEmbeddings
|
|
2 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
3 |
from langchain_community.embeddings import LlamaCppEmbeddings
|
4 |
|
5 |
-
|
6 |
-
from modules.constants import *
|
7 |
-
except:
|
8 |
-
from constants import *
|
9 |
import os
|
10 |
|
11 |
|
@@ -14,19 +11,19 @@ class EmbeddingModelLoader:
|
|
14 |
self.config = config
|
15 |
|
16 |
def load_embedding_model(self):
|
17 |
-
if self.config["
|
18 |
embedding_model = OpenAIEmbeddings(
|
19 |
deployment="SL-document_embedder",
|
20 |
-
model=self.config["
|
21 |
show_progress_bar=True,
|
22 |
openai_api_key=OPENAI_API_KEY,
|
23 |
disallowed_special=(),
|
24 |
)
|
25 |
else:
|
26 |
embedding_model = HuggingFaceEmbeddings(
|
27 |
-
model_name=self.config["
|
28 |
model_kwargs={
|
29 |
-
"device": "
|
30 |
"token": f"{HUGGINGFACE_TOKEN}",
|
31 |
"trust_remote_code": True,
|
32 |
},
|
|
|
2 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
3 |
from langchain_community.embeddings import LlamaCppEmbeddings
|
4 |
|
5 |
+
from modules.config.constants import *
|
|
|
|
|
|
|
6 |
import os
|
7 |
|
8 |
|
|
|
11 |
self.config = config
|
12 |
|
13 |
def load_embedding_model(self):
|
14 |
+
if self.config["vectorstore"]["model"] in ["text-embedding-ada-002"]:
|
15 |
embedding_model = OpenAIEmbeddings(
|
16 |
deployment="SL-document_embedder",
|
17 |
+
model=self.config["vectorestore"]["model"],
|
18 |
show_progress_bar=True,
|
19 |
openai_api_key=OPENAI_API_KEY,
|
20 |
disallowed_special=(),
|
21 |
)
|
22 |
else:
|
23 |
embedding_model = HuggingFaceEmbeddings(
|
24 |
+
model_name=self.config["vectorstore"]["model"],
|
25 |
model_kwargs={
|
26 |
+
"device": f"{self.config['device']}",
|
27 |
"token": f"{HUGGINGFACE_TOKEN}",
|
28 |
"trust_remote_code": True,
|
29 |
},
|
code/modules/vectorstore/faiss.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_community.vectorstores import FAISS
|
2 |
+
from modules.vectorstore.base import VectorStoreBase
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
class FaissVectorStore(VectorStoreBase):
|
7 |
+
def __init__(self, config):
|
8 |
+
self.config = config
|
9 |
+
self._init_vector_db()
|
10 |
+
|
11 |
+
def _init_vector_db(self):
|
12 |
+
self.faiss = FAISS(
|
13 |
+
embedding_function=None, index=0, index_to_docstore_id={}, docstore={}
|
14 |
+
)
|
15 |
+
|
16 |
+
def create_database(self, document_chunks, embedding_model):
|
17 |
+
self.vectorstore = self.faiss.from_documents(
|
18 |
+
documents=document_chunks, embedding=embedding_model
|
19 |
+
)
|
20 |
+
self.vectorstore.save_local(
|
21 |
+
os.path.join(
|
22 |
+
self.config["vectorstore"]["db_path"],
|
23 |
+
"db_"
|
24 |
+
+ self.config["vectorstore"]["db_option"]
|
25 |
+
+ "_"
|
26 |
+
+ self.config["vectorstore"]["model"],
|
27 |
+
)
|
28 |
+
)
|
29 |
+
|
30 |
+
def load_database(self, embedding_model):
|
31 |
+
self.vectorstore = self.faiss.load_local(
|
32 |
+
os.path.join(
|
33 |
+
self.config["vectorstore"]["db_path"],
|
34 |
+
"db_"
|
35 |
+
+ self.config["vectorstore"]["db_option"]
|
36 |
+
+ "_"
|
37 |
+
+ self.config["vectorstore"]["model"],
|
38 |
+
),
|
39 |
+
embedding_model,
|
40 |
+
allow_dangerous_deserialization=True,
|
41 |
+
)
|
42 |
+
return self.vectorstore
|
43 |
+
|
44 |
+
def as_retriever(self):
|
45 |
+
return self.vectorstore.as_retriever()
|
code/modules/vectorstore/helpers.py
ADDED
File without changes
|
code/modules/vectorstore/raptor.py
ADDED
@@ -0,0 +1,438 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# code modified from https://github.com/langchain-ai/langchain/blob/master/cookbook/RAPTOR.ipynb
|
2 |
+
|
3 |
+
from typing import Dict, List, Optional, Tuple
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
import pandas as pd
|
7 |
+
import umap
|
8 |
+
from langchain_core.prompts import ChatPromptTemplate
|
9 |
+
from langchain_core.output_parsers import StrOutputParser
|
10 |
+
from sklearn.mixture import GaussianMixture
|
11 |
+
from langchain_community.chat_models import ChatOpenAI
|
12 |
+
from langchain_community.vectorstores import FAISS
|
13 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
14 |
+
from modules.vectorstore.base import VectorStoreBase
|
15 |
+
|
16 |
+
RANDOM_SEED = 42
|
17 |
+
|
18 |
+
|
19 |
+
class RAPTORVectoreStore(VectorStoreBase):
|
20 |
+
def __init__(self, config, documents=[], text_splitter=None, embedding_model=None):
|
21 |
+
self.documents = documents
|
22 |
+
self.config = config
|
23 |
+
self.text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
|
24 |
+
chunk_size=self.config["splitter_options"]["chunk_size"],
|
25 |
+
chunk_overlap=self.config["splitter_options"]["chunk_overlap"],
|
26 |
+
separators=self.config["splitter_options"]["chunk_separators"],
|
27 |
+
disallowed_special=(),
|
28 |
+
)
|
29 |
+
self.embd = embedding_model
|
30 |
+
self.model = ChatOpenAI(
|
31 |
+
model="gpt-3.5-turbo",
|
32 |
+
)
|
33 |
+
|
34 |
+
def concat_documents(self, documents):
|
35 |
+
d_sorted = sorted(documents, key=lambda x: x.metadata["source"])
|
36 |
+
d_reversed = list(reversed(d_sorted))
|
37 |
+
concatenated_content = "\n\n\n --- \n\n\n".join(
|
38 |
+
[doc.page_content for doc in d_reversed]
|
39 |
+
)
|
40 |
+
return concatenated_content
|
41 |
+
|
42 |
+
def split_documents(self, documents):
|
43 |
+
concatenated_content = self.concat_documents(documents)
|
44 |
+
texts_split = self.text_splitter.split_text(concatenated_content)
|
45 |
+
return texts_split
|
46 |
+
|
47 |
+
def add_documents(self, documents):
|
48 |
+
self.documents.extend(documents)
|
49 |
+
|
50 |
+
def global_cluster_embeddings(
|
51 |
+
self,
|
52 |
+
embeddings: np.ndarray,
|
53 |
+
dim: int,
|
54 |
+
n_neighbors: Optional[int] = None,
|
55 |
+
metric: str = "cosine",
|
56 |
+
) -> np.ndarray:
|
57 |
+
"""
|
58 |
+
Perform global dimensionality reduction on the embeddings using UMAP.
|
59 |
+
|
60 |
+
Parameters:
|
61 |
+
- embeddings: The input embeddings as a numpy array.
|
62 |
+
- dim: The target dimensionality for the reduced space.
|
63 |
+
- n_neighbors: Optional; the number of neighbors to consider for each point.
|
64 |
+
If not provided, it defaults to the square root of the number of embeddings.
|
65 |
+
- metric: The distance metric to use for UMAP.
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
- A numpy array of the embeddings reduced to the specified dimensionality.
|
69 |
+
"""
|
70 |
+
if n_neighbors is None:
|
71 |
+
n_neighbors = int((len(embeddings) - 1) ** 0.5)
|
72 |
+
return umap.UMAP(
|
73 |
+
n_neighbors=n_neighbors, n_components=dim, metric=metric
|
74 |
+
).fit_transform(embeddings)
|
75 |
+
|
76 |
+
def local_cluster_embeddings(
|
77 |
+
self,
|
78 |
+
embeddings: np.ndarray,
|
79 |
+
dim: int,
|
80 |
+
num_neighbors: int = 10,
|
81 |
+
metric: str = "cosine",
|
82 |
+
) -> np.ndarray:
|
83 |
+
"""
|
84 |
+
Perform local dimensionality reduction on the embeddings using UMAP, typically after global clustering.
|
85 |
+
|
86 |
+
Parameters:
|
87 |
+
- embeddings: The input embeddings as a numpy array.
|
88 |
+
- dim: The target dimensionality for the reduced space.
|
89 |
+
- num_neighbors: The number of neighbors to consider for each point.
|
90 |
+
- metric: The distance metric to use for UMAP.
|
91 |
+
|
92 |
+
Returns:
|
93 |
+
- A numpy array of the embeddings reduced to the specified dimensionality.
|
94 |
+
"""
|
95 |
+
return umap.UMAP(
|
96 |
+
n_neighbors=num_neighbors, n_components=dim, metric=metric
|
97 |
+
).fit_transform(embeddings)
|
98 |
+
|
99 |
+
def get_optimal_clusters(
|
100 |
+
self,
|
101 |
+
embeddings: np.ndarray,
|
102 |
+
max_clusters: int = 50,
|
103 |
+
random_state: int = RANDOM_SEED,
|
104 |
+
) -> int:
|
105 |
+
"""
|
106 |
+
Determine the optimal number of clusters using the Bayesian Information Criterion (BIC) with a Gaussian Mixture Model.
|
107 |
+
|
108 |
+
Parameters:
|
109 |
+
- embeddings: The input embeddings as a numpy array.
|
110 |
+
- max_clusters: The maximum number of clusters to consider.
|
111 |
+
- random_state: Seed for reproducibility.
|
112 |
+
|
113 |
+
Returns:
|
114 |
+
- An integer representing the optimal number of clusters found.
|
115 |
+
"""
|
116 |
+
max_clusters = min(max_clusters, len(embeddings))
|
117 |
+
n_clusters = np.arange(1, max_clusters)
|
118 |
+
bics = []
|
119 |
+
for n in n_clusters:
|
120 |
+
gm = GaussianMixture(n_components=n, random_state=random_state)
|
121 |
+
gm.fit(embeddings)
|
122 |
+
bics.append(gm.bic(embeddings))
|
123 |
+
return n_clusters[np.argmin(bics)]
|
124 |
+
|
125 |
+
def GMM_cluster(
|
126 |
+
self, embeddings: np.ndarray, threshold: float, random_state: int = 0
|
127 |
+
):
|
128 |
+
"""
|
129 |
+
Cluster embeddings using a Gaussian Mixture Model (GMM) based on a probability threshold.
|
130 |
+
|
131 |
+
Parameters:
|
132 |
+
- embeddings: The input embeddings as a numpy array.
|
133 |
+
- threshold: The probability threshold for assigning an embedding to a cluster.
|
134 |
+
- random_state: Seed for reproducibility.
|
135 |
+
|
136 |
+
Returns:
|
137 |
+
- A tuple containing the cluster labels and the number of clusters determined.
|
138 |
+
"""
|
139 |
+
n_clusters = self.get_optimal_clusters(embeddings)
|
140 |
+
gm = GaussianMixture(n_components=n_clusters, random_state=random_state)
|
141 |
+
gm.fit(embeddings)
|
142 |
+
probs = gm.predict_proba(embeddings)
|
143 |
+
labels = [np.where(prob > threshold)[0] for prob in probs]
|
144 |
+
return labels, n_clusters
|
145 |
+
|
146 |
+
def perform_clustering(
|
147 |
+
self,
|
148 |
+
embeddings: np.ndarray,
|
149 |
+
dim: int,
|
150 |
+
threshold: float,
|
151 |
+
) -> List[np.ndarray]:
|
152 |
+
"""
|
153 |
+
Perform clustering on the embeddings by first reducing their dimensionality globally, then clustering
|
154 |
+
using a Gaussian Mixture Model, and finally performing local clustering within each global cluster.
|
155 |
+
|
156 |
+
Parameters:
|
157 |
+
- embeddings: The input embeddings as a numpy array.
|
158 |
+
- dim: The target dimensionality for UMAP reduction.
|
159 |
+
- threshold: The probability threshold for assigning an embedding to a cluster in GMM.
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
- A list of numpy arrays, where each array contains the cluster IDs for each embedding.
|
163 |
+
"""
|
164 |
+
if len(embeddings) <= dim + 1:
|
165 |
+
# Avoid clustering when there's insufficient data
|
166 |
+
return [np.array([0]) for _ in range(len(embeddings))]
|
167 |
+
|
168 |
+
# Global dimensionality reduction
|
169 |
+
reduced_embeddings_global = self.global_cluster_embeddings(embeddings, dim)
|
170 |
+
# Global clustering
|
171 |
+
global_clusters, n_global_clusters = self.GMM_cluster(
|
172 |
+
reduced_embeddings_global, threshold
|
173 |
+
)
|
174 |
+
|
175 |
+
all_local_clusters = [np.array([]) for _ in range(len(embeddings))]
|
176 |
+
total_clusters = 0
|
177 |
+
|
178 |
+
# Iterate through each global cluster to perform local clustering
|
179 |
+
for i in range(n_global_clusters):
|
180 |
+
# Extract embeddings belonging to the current global cluster
|
181 |
+
global_cluster_embeddings_ = embeddings[
|
182 |
+
np.array([i in gc for gc in global_clusters])
|
183 |
+
]
|
184 |
+
|
185 |
+
if len(global_cluster_embeddings_) == 0:
|
186 |
+
continue
|
187 |
+
if len(global_cluster_embeddings_) <= dim + 1:
|
188 |
+
# Handle small clusters with direct assignment
|
189 |
+
local_clusters = [np.array([0]) for _ in global_cluster_embeddings_]
|
190 |
+
n_local_clusters = 1
|
191 |
+
else:
|
192 |
+
# Local dimensionality reduction and clustering
|
193 |
+
reduced_embeddings_local = self.local_cluster_embeddings(
|
194 |
+
global_cluster_embeddings_, dim
|
195 |
+
)
|
196 |
+
local_clusters, n_local_clusters = self.GMM_cluster(
|
197 |
+
reduced_embeddings_local, threshold
|
198 |
+
)
|
199 |
+
|
200 |
+
# Assign local cluster IDs, adjusting for total clusters already processed
|
201 |
+
for j in range(n_local_clusters):
|
202 |
+
local_cluster_embeddings_ = global_cluster_embeddings_[
|
203 |
+
np.array([j in lc for lc in local_clusters])
|
204 |
+
]
|
205 |
+
indices = np.where(
|
206 |
+
(embeddings == local_cluster_embeddings_[:, None]).all(-1)
|
207 |
+
)[1]
|
208 |
+
for idx in indices:
|
209 |
+
all_local_clusters[idx] = np.append(
|
210 |
+
all_local_clusters[idx], j + total_clusters
|
211 |
+
)
|
212 |
+
|
213 |
+
total_clusters += n_local_clusters
|
214 |
+
|
215 |
+
return all_local_clusters
|
216 |
+
|
217 |
+
def embed(self, texts):
|
218 |
+
"""
|
219 |
+
Generate embeddings for a list of text documents.
|
220 |
+
|
221 |
+
This function assumes the existence of an `embd` object with a method `embed_documents`
|
222 |
+
that takes a list of texts and returns their embeddings.
|
223 |
+
|
224 |
+
Parameters:
|
225 |
+
- texts: List[str], a list of text documents to be embedded.
|
226 |
+
|
227 |
+
Returns:
|
228 |
+
- numpy.ndarray: An array of embeddings for the given text documents.
|
229 |
+
"""
|
230 |
+
text_embeddings = self.embd.embed_documents(texts)
|
231 |
+
text_embeddings_np = np.array(text_embeddings)
|
232 |
+
return text_embeddings_np
|
233 |
+
|
234 |
+
def embed_cluster_texts(self, texts):
|
235 |
+
"""
|
236 |
+
Embeds a list of texts and clusters them, returning a DataFrame with texts, their embeddings, and cluster labels.
|
237 |
+
|
238 |
+
This function combines embedding generation and clustering into a single step. It assumes the existence
|
239 |
+
of a previously defined `perform_clustering` function that performs clustering on the embeddings.
|
240 |
+
|
241 |
+
Parameters:
|
242 |
+
- texts: List[str], a list of text documents to be processed.
|
243 |
+
|
244 |
+
Returns:
|
245 |
+
- pandas.DataFrame: A DataFrame containing the original texts, their embeddings, and the assigned cluster labels.
|
246 |
+
"""
|
247 |
+
text_embeddings_np = self.embed(texts) # Generate embeddings
|
248 |
+
cluster_labels = self.perform_clustering(
|
249 |
+
text_embeddings_np, 10, 0.1
|
250 |
+
) # Perform clustering on the embeddings
|
251 |
+
df = pd.DataFrame() # Initialize a DataFrame to store the results
|
252 |
+
df["text"] = texts # Store original texts
|
253 |
+
df["embd"] = list(
|
254 |
+
text_embeddings_np
|
255 |
+
) # Store embeddings as a list in the DataFrame
|
256 |
+
df["cluster"] = cluster_labels # Store cluster labels
|
257 |
+
return df
|
258 |
+
|
259 |
+
def fmt_txt(self, df: pd.DataFrame) -> str:
|
260 |
+
"""
|
261 |
+
Formats the text documents in a DataFrame into a single string.
|
262 |
+
|
263 |
+
Parameters:
|
264 |
+
- df: DataFrame containing the 'text' column with text documents to format.
|
265 |
+
|
266 |
+
Returns:
|
267 |
+
- A single string where all text documents are joined by a specific delimiter.
|
268 |
+
"""
|
269 |
+
unique_txt = df["text"].tolist()
|
270 |
+
return "--- --- \n --- --- ".join(unique_txt)
|
271 |
+
|
272 |
+
def embed_cluster_summarize_texts(
|
273 |
+
self, texts: List[str], level: int
|
274 |
+
) -> Tuple[pd.DataFrame, pd.DataFrame]:
|
275 |
+
"""
|
276 |
+
Embeds, clusters, and summarizes a list of texts. This function first generates embeddings for the texts,
|
277 |
+
clusters them based on similarity, expands the cluster assignments for easier processing, and then summarizes
|
278 |
+
the content within each cluster.
|
279 |
+
|
280 |
+
Parameters:
|
281 |
+
- texts: A list of text documents to be processed.
|
282 |
+
- level: An integer parameter that could define the depth or detail of processing.
|
283 |
+
|
284 |
+
Returns:
|
285 |
+
- Tuple containing two DataFrames:
|
286 |
+
1. The first DataFrame (`df_clusters`) includes the original texts, their embeddings, and cluster assignments.
|
287 |
+
2. The second DataFrame (`df_summary`) contains summaries for each cluster, the specified level of detail,
|
288 |
+
and the cluster identifiers.
|
289 |
+
"""
|
290 |
+
|
291 |
+
# Embed and cluster the texts, resulting in a DataFrame with 'text', 'embd', and 'cluster' columns
|
292 |
+
df_clusters = self.embed_cluster_texts(texts)
|
293 |
+
|
294 |
+
# Prepare to expand the DataFrame for easier manipulation of clusters
|
295 |
+
expanded_list = []
|
296 |
+
|
297 |
+
# Expand DataFrame entries to document-cluster pairings for straightforward processing
|
298 |
+
for index, row in df_clusters.iterrows():
|
299 |
+
for cluster in row["cluster"]:
|
300 |
+
expanded_list.append(
|
301 |
+
{"text": row["text"], "embd": row["embd"], "cluster": cluster}
|
302 |
+
)
|
303 |
+
|
304 |
+
# Create a new DataFrame from the expanded list
|
305 |
+
expanded_df = pd.DataFrame(expanded_list)
|
306 |
+
|
307 |
+
# Retrieve unique cluster identifiers for processing
|
308 |
+
all_clusters = expanded_df["cluster"].unique()
|
309 |
+
|
310 |
+
print(f"--Generated {len(all_clusters)} clusters--")
|
311 |
+
|
312 |
+
# Summarization
|
313 |
+
template = """Here is content from the course DS598: Deep Learning for Data Science.
|
314 |
+
|
315 |
+
The content may be form webapge about the course, or lecture content, or any other relevant information.
|
316 |
+
If the content is in bullet points (from pdf lectre slides), you can summarize the bullet points.
|
317 |
+
|
318 |
+
Give a detailed summary of the content below.
|
319 |
+
|
320 |
+
Documentation:
|
321 |
+
{context}
|
322 |
+
"""
|
323 |
+
prompt = ChatPromptTemplate.from_template(template)
|
324 |
+
chain = prompt | self.model | StrOutputParser()
|
325 |
+
|
326 |
+
# Format text within each cluster for summarization
|
327 |
+
summaries = []
|
328 |
+
for i in all_clusters:
|
329 |
+
df_cluster = expanded_df[expanded_df["cluster"] == i]
|
330 |
+
formatted_txt = self.fmt_txt(df_cluster)
|
331 |
+
summaries.append(chain.invoke({"context": formatted_txt}))
|
332 |
+
|
333 |
+
# Create a DataFrame to store summaries with their corresponding cluster and level
|
334 |
+
df_summary = pd.DataFrame(
|
335 |
+
{
|
336 |
+
"summaries": summaries,
|
337 |
+
"level": [level] * len(summaries),
|
338 |
+
"cluster": list(all_clusters),
|
339 |
+
}
|
340 |
+
)
|
341 |
+
|
342 |
+
return df_clusters, df_summary
|
343 |
+
|
344 |
+
def recursive_embed_cluster_summarize(
|
345 |
+
self, texts: List[str], level: int = 1, n_levels: int = 3
|
346 |
+
) -> Dict[int, Tuple[pd.DataFrame, pd.DataFrame]]:
|
347 |
+
"""
|
348 |
+
Recursively embeds, clusters, and summarizes texts up to a specified level or until
|
349 |
+
the number of unique clusters becomes 1, storing the results at each level.
|
350 |
+
|
351 |
+
Parameters:
|
352 |
+
- texts: List[str], texts to be processed.
|
353 |
+
- level: int, current recursion level (starts at 1).
|
354 |
+
- n_levels: int, maximum depth of recursion.
|
355 |
+
|
356 |
+
Returns:
|
357 |
+
- Dict[int, Tuple[pd.DataFrame, pd.DataFrame]], a dictionary where keys are the recursion
|
358 |
+
levels and values are tuples containing the clusters DataFrame and summaries DataFrame at that level.
|
359 |
+
"""
|
360 |
+
results = {} # Dictionary to store results at each level
|
361 |
+
|
362 |
+
# Perform embedding, clustering, and summarization for the current level
|
363 |
+
df_clusters, df_summary = self.embed_cluster_summarize_texts(texts, level)
|
364 |
+
|
365 |
+
# Store the results of the current level
|
366 |
+
results[level] = (df_clusters, df_summary)
|
367 |
+
|
368 |
+
# Determine if further recursion is possible and meaningful
|
369 |
+
unique_clusters = df_summary["cluster"].nunique()
|
370 |
+
if level < n_levels and unique_clusters > 1:
|
371 |
+
# Use summaries as the input texts for the next level of recursion
|
372 |
+
new_texts = df_summary["summaries"].tolist()
|
373 |
+
next_level_results = self.recursive_embed_cluster_summarize(
|
374 |
+
new_texts, level + 1, n_levels
|
375 |
+
)
|
376 |
+
|
377 |
+
# Merge the results from the next level into the current results dictionary
|
378 |
+
results.update(next_level_results)
|
379 |
+
|
380 |
+
return results
|
381 |
+
|
382 |
+
def get_vector_db(self):
|
383 |
+
"""
|
384 |
+
Generate a retriever object from a list of documents.
|
385 |
+
|
386 |
+
Parameters:
|
387 |
+
- documents: List of document objects.
|
388 |
+
|
389 |
+
Returns:
|
390 |
+
- A retriever object.
|
391 |
+
"""
|
392 |
+
leaf_texts = self.split_documents(self.documents)
|
393 |
+
results = self.recursive_embed_cluster_summarize(
|
394 |
+
leaf_texts, level=1, n_levels=10
|
395 |
+
)
|
396 |
+
|
397 |
+
all_texts = leaf_texts.copy()
|
398 |
+
# Iterate through the results to extract summaries from each level and add them to all_texts
|
399 |
+
for level in sorted(results.keys()):
|
400 |
+
# Extract summaries from the current level's DataFrame
|
401 |
+
summaries = results[level][1]["summaries"].tolist()
|
402 |
+
# Extend all_texts with the summaries from the current level
|
403 |
+
all_texts.extend(summaries)
|
404 |
+
|
405 |
+
# Now, use all_texts to build the vectorstore
|
406 |
+
vectorstore = FAISS.from_texts(texts=all_texts, embedding=self.embd)
|
407 |
+
return vectorstore
|
408 |
+
|
409 |
+
def create_database(self, documents, embedding_model):
|
410 |
+
self.documents = documents
|
411 |
+
self.embd = embedding_model
|
412 |
+
self.vectorstore = self.get_vector_db()
|
413 |
+
self.vectorstore.save_local(
|
414 |
+
os.path.join(
|
415 |
+
self.config["vectorstore"]["db_path"],
|
416 |
+
"db_"
|
417 |
+
+ self.config["vectorstore"]["db_option"]
|
418 |
+
+ "_"
|
419 |
+
+ self.config["vectorstore"]["model"],
|
420 |
+
)
|
421 |
+
)
|
422 |
+
|
423 |
+
def load_database(self, embedding_model):
|
424 |
+
self.vectorstore = FAISS.load_local(
|
425 |
+
os.path.join(
|
426 |
+
self.config["vectorstore"]["db_path"],
|
427 |
+
"db_"
|
428 |
+
+ self.config["vectorstore"]["db_option"]
|
429 |
+
+ "_"
|
430 |
+
+ self.config["vectorstore"]["model"],
|
431 |
+
),
|
432 |
+
embedding_model,
|
433 |
+
allow_dangerous_deserialization=True,
|
434 |
+
)
|
435 |
+
return self.vectorstore
|
436 |
+
|
437 |
+
def as_retriever(self):
|
438 |
+
return self.vectorstore.as_retriever()
|
code/modules/vectorstore/store_manager.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules.vectorstore.vectorstore import VectorStore
|
2 |
+
from modules.vectorstore.helpers import *
|
3 |
+
from modules.dataloader.webpage_crawler import WebpageCrawler
|
4 |
+
from modules.dataloader.data_loader import DataLoader
|
5 |
+
from modules.dataloader.helpers import *
|
6 |
+
from modules.vectorstore.embedding_model_loader import EmbeddingModelLoader
|
7 |
+
import logging
|
8 |
+
import os
|
9 |
+
import time
|
10 |
+
import asyncio
|
11 |
+
|
12 |
+
|
13 |
+
class VectorStoreManager:
|
14 |
+
def __init__(self, config, logger=None):
|
15 |
+
self.config = config
|
16 |
+
self.document_names = None
|
17 |
+
|
18 |
+
# Set up logging to both console and a file
|
19 |
+
self.logger = logger or self._setup_logging()
|
20 |
+
self.webpage_crawler = WebpageCrawler()
|
21 |
+
self.vector_db = VectorStore(self.config)
|
22 |
+
|
23 |
+
self.logger.info("VectorDB instance instantiated")
|
24 |
+
|
25 |
+
def _setup_logging(self):
|
26 |
+
logger = logging.getLogger(__name__)
|
27 |
+
if not logger.hasHandlers():
|
28 |
+
logger.setLevel(logging.INFO)
|
29 |
+
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
30 |
+
|
31 |
+
# Console Handler
|
32 |
+
console_handler = logging.StreamHandler()
|
33 |
+
console_handler.setLevel(logging.INFO)
|
34 |
+
console_handler.setFormatter(formatter)
|
35 |
+
logger.addHandler(console_handler)
|
36 |
+
|
37 |
+
# Ensure log directory exists
|
38 |
+
log_directory = self.config["log_dir"]
|
39 |
+
os.makedirs(log_directory, exist_ok=True)
|
40 |
+
|
41 |
+
# File Handler
|
42 |
+
log_file_path = os.path.join(log_directory, "vector_db.log")
|
43 |
+
file_handler = logging.FileHandler(log_file_path, mode="w")
|
44 |
+
file_handler.setLevel(logging.INFO)
|
45 |
+
file_handler.setFormatter(formatter)
|
46 |
+
logger.addHandler(file_handler)
|
47 |
+
|
48 |
+
return logger
|
49 |
+
|
50 |
+
def load_files(self):
|
51 |
+
|
52 |
+
files = os.listdir(self.config["vectorstore"]["data_path"])
|
53 |
+
files = [
|
54 |
+
os.path.join(self.config["vectorstore"]["data_path"], file)
|
55 |
+
for file in files
|
56 |
+
]
|
57 |
+
urls = get_urls_from_file(self.config["vectorstore"]["url_file_path"])
|
58 |
+
if self.config["vectorstore"]["expand_urls"]:
|
59 |
+
all_urls = []
|
60 |
+
for url in urls:
|
61 |
+
loop = asyncio.get_event_loop()
|
62 |
+
all_urls.extend(
|
63 |
+
loop.run_until_complete(
|
64 |
+
self.webpage_crawler.get_all_pages(
|
65 |
+
url, url
|
66 |
+
) # only get child urls, if you want to get all urls, replace the second argument with the base url
|
67 |
+
)
|
68 |
+
)
|
69 |
+
urls = all_urls
|
70 |
+
return files, urls
|
71 |
+
|
72 |
+
def create_embedding_model(self):
|
73 |
+
|
74 |
+
self.logger.info("Creating embedding function")
|
75 |
+
embedding_model_loader = EmbeddingModelLoader(self.config)
|
76 |
+
embedding_model = embedding_model_loader.load_embedding_model()
|
77 |
+
return embedding_model
|
78 |
+
|
79 |
+
def initialize_database(
|
80 |
+
self,
|
81 |
+
document_chunks: list,
|
82 |
+
document_names: list,
|
83 |
+
documents: list,
|
84 |
+
document_metadata: list,
|
85 |
+
):
|
86 |
+
if self.config["vectorstore"]["db_option"] in ["FAISS", "Chroma", "RAPTOR"]:
|
87 |
+
self.embedding_model = self.create_embedding_model()
|
88 |
+
else:
|
89 |
+
self.embedding_model = None
|
90 |
+
|
91 |
+
self.logger.info("Initializing vector_db")
|
92 |
+
self.logger.info(
|
93 |
+
"\tUsing {} as db_option".format(self.config["vectorstore"]["db_option"])
|
94 |
+
)
|
95 |
+
self.vector_db._create_database(
|
96 |
+
document_chunks,
|
97 |
+
document_names,
|
98 |
+
documents,
|
99 |
+
document_metadata,
|
100 |
+
self.embedding_model,
|
101 |
+
)
|
102 |
+
|
103 |
+
def create_database(self):
|
104 |
+
|
105 |
+
start_time = time.time() # Start time for creating database
|
106 |
+
data_loader = DataLoader(self.config, self.logger)
|
107 |
+
self.logger.info("Loading data")
|
108 |
+
files, urls = self.load_files()
|
109 |
+
files, webpages = self.webpage_crawler.clean_url_list(urls)
|
110 |
+
self.logger.info(f"Number of files: {len(files)}")
|
111 |
+
self.logger.info(f"Number of webpages: {len(webpages)}")
|
112 |
+
if f"{self.config['vectorstore']['url_file_path']}" in files:
|
113 |
+
files.remove(f"{self.config['vectorstores']['url_file_path']}") # cleanup
|
114 |
+
document_chunks, document_names, documents, document_metadata = (
|
115 |
+
data_loader.get_chunks(files, webpages)
|
116 |
+
)
|
117 |
+
num_documents = len(document_chunks)
|
118 |
+
self.logger.info(f"Number of documents in the DB: {num_documents}")
|
119 |
+
metadata_keys = list(document_metadata[0].keys())
|
120 |
+
self.logger.info(f"Metadata keys: {metadata_keys}")
|
121 |
+
self.logger.info("Completed loading data")
|
122 |
+
self.initialize_database(
|
123 |
+
document_chunks, document_names, documents, document_metadata
|
124 |
+
)
|
125 |
+
end_time = time.time() # End time for creating database
|
126 |
+
self.logger.info("Created database")
|
127 |
+
self.logger.info(
|
128 |
+
f"Time taken to create database: {end_time - start_time} seconds"
|
129 |
+
)
|
130 |
+
|
131 |
+
def load_database(self):
|
132 |
+
|
133 |
+
start_time = time.time() # Start time for loading database
|
134 |
+
if self.config["vectorstore"]["db_option"] in ["FAISS", "Chroma", "RAPTOR"]:
|
135 |
+
self.embedding_model = self.create_embedding_model()
|
136 |
+
else:
|
137 |
+
self.embedding_model = None
|
138 |
+
self.loaded_vector_db = self.vector_db._load_database(self.embedding_model)
|
139 |
+
end_time = time.time() # End time for loading database
|
140 |
+
self.logger.info(
|
141 |
+
f"Time taken to load database: {end_time - start_time} seconds"
|
142 |
+
)
|
143 |
+
self.logger.info("Loaded database")
|
144 |
+
return self.loaded_vector_db
|
145 |
+
|
146 |
+
|
147 |
+
if __name__ == "__main__":
|
148 |
+
import yaml
|
149 |
+
|
150 |
+
with open("modules/config/config.yml", "r") as f:
|
151 |
+
config = yaml.safe_load(f)
|
152 |
+
print(config)
|
153 |
+
print(f"Trying to create database with config: {config}")
|
154 |
+
vector_db = VectorStoreManager(config)
|
155 |
+
vector_db.create_database()
|
156 |
+
print("Created database")
|
157 |
+
|
158 |
+
print(f"Trying to load the database")
|
159 |
+
vector_db = VectorStoreManager(config)
|
160 |
+
vector_db.load_database()
|
161 |
+
print("Loaded database")
|
162 |
+
|
163 |
+
print(f"View the logs at {config['log_dir']}/vector_db.log")
|
code/modules/vectorstore/vectorstore.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules.vectorstore.faiss import FaissVectorStore
|
2 |
+
from modules.vectorstore.chroma import ChromaVectorStore
|
3 |
+
from modules.vectorstore.colbert import ColbertVectorStore
|
4 |
+
from modules.vectorstore.raptor import RAPTORVectoreStore
|
5 |
+
|
6 |
+
|
7 |
+
class VectorStore:
|
8 |
+
def __init__(self, config):
|
9 |
+
self.config = config
|
10 |
+
self.vectorstore = None
|
11 |
+
self.vectorstore_classes = {
|
12 |
+
"FAISS": FaissVectorStore,
|
13 |
+
"Chroma": ChromaVectorStore,
|
14 |
+
"RAGatouille": ColbertVectorStore,
|
15 |
+
"RAPTOR": RAPTORVectoreStore,
|
16 |
+
}
|
17 |
+
|
18 |
+
def _create_database(
|
19 |
+
self,
|
20 |
+
document_chunks,
|
21 |
+
document_names,
|
22 |
+
documents,
|
23 |
+
document_metadata,
|
24 |
+
embedding_model,
|
25 |
+
):
|
26 |
+
db_option = self.config["vectorstore"]["db_option"]
|
27 |
+
vectorstore_class = self.vectorstore_classes.get(db_option)
|
28 |
+
if not vectorstore_class:
|
29 |
+
raise ValueError(f"Invalid db_option: {db_option}")
|
30 |
+
|
31 |
+
self.vectorstore = vectorstore_class(self.config)
|
32 |
+
|
33 |
+
if db_option == "RAGatouille":
|
34 |
+
self.vectorstore.create_database(
|
35 |
+
documents, document_names, document_metadata
|
36 |
+
)
|
37 |
+
else:
|
38 |
+
self.vectorstore.create_database(document_chunks, embedding_model)
|
39 |
+
|
40 |
+
def _load_database(self, embedding_model):
|
41 |
+
db_option = self.config["vectorstore"]["db_option"]
|
42 |
+
vectorstore_class = self.vectorstore_classes.get(db_option)
|
43 |
+
if not vectorstore_class:
|
44 |
+
raise ValueError(f"Invalid db_option: {db_option}")
|
45 |
+
|
46 |
+
self.vectorstore = vectorstore_class(self.config)
|
47 |
+
|
48 |
+
if db_option == "RAGatouille":
|
49 |
+
return self.vectorstore.load_database()
|
50 |
+
else:
|
51 |
+
return self.vectorstore.load_database(embedding_model)
|
52 |
+
|
53 |
+
def _as_retriever(self):
|
54 |
+
return self.vectorstore.as_retriever()
|
55 |
+
|
56 |
+
def _get_vectorstore(self):
|
57 |
+
return self.vectorstore
|
code/public/acastusphoton-svgrepo-com.svg
ADDED
code/public/adv-screen-recorder-svgrepo-com.svg
ADDED
code/public/alarmy-svgrepo-com.svg
ADDED
public/logo_dark.png β code/public/avatars/ai-tutor.png
RENAMED
File without changes
|
code/public/calendar-samsung-17-svgrepo-com.svg
ADDED
public/logo_light.png β code/public/logo_dark.png
RENAMED
File without changes
|
code/public/logo_light.png
ADDED