Spaces:
Runtime error
Runtime error
KonradSzafer
commited on
Commit
·
d22d549
1
Parent(s):
7a13d67
refactor update
Browse files- data/datasets/hf_repositories_urls_scraped.json +92 -70
- data/hugging_face_docs_dataset.py +2 -1
- data/{indexer.ipynb → index.ipynb} +60 -6
- data/{indexing_benchmark.ipynb → index_benchmark.ipynb} +0 -0
- data/{scrapers → stackoverflow_scrapers}/stack_overflow_scraper.py +0 -0
- data/{stackoverflow_python_dataset.py → stackoverflow_scrapers/stackoverflow_python_dataset.py} +0 -0
- data/{upload_csv_dataset.py → stackoverflow_scrapers/upload_csv_dataset.py} +0 -0
- qa_engine/config.py +1 -1
- qa_engine/qa_engine.py +24 -20
data/datasets/hf_repositories_urls_scraped.json
CHANGED
@@ -1,91 +1,113 @@
|
|
1 |
{
|
2 |
"urls": [
|
3 |
-
"https://github.com/huggingface/tokenizers",
|
4 |
-
"https://github.com/huggingface/datablations",
|
5 |
-
"https://github.com/huggingface/peft",
|
6 |
-
"https://github.com/huggingface/tflite-android-transformers",
|
7 |
-
"https://github.com/huggingface/simulate",
|
8 |
-
"https://github.com/huggingface/transformers",
|
9 |
-
"https://github.com/huggingface/deep-rl-class",
|
10 |
-
"https://github.com/huggingface/awesome-huggingface",
|
11 |
-
"https://github.com/huggingface/datasets-server",
|
12 |
-
"https://github.com/huggingface/setfit",
|
13 |
-
"https://github.com/huggingface/olm-training",
|
14 |
-
"https://github.com/huggingface/huggingface_sb3",
|
15 |
-
"https://github.com/huggingface/optimum-neuron",
|
16 |
-
"https://github.com/huggingface/blog",
|
17 |
-
"https://github.com/huggingface/100-times-faster-nlp",
|
18 |
-
"https://github.com/huggingface/bloom-jax-inference",
|
19 |
-
"https://github.com/huggingface/speechbox",
|
20 |
-
"https://github.com/huggingface/olm-datasets",
|
21 |
-
"https://github.com/huggingface/hub-docs",
|
22 |
-
"https://github.com/huggingface/torchMoji",
|
23 |
-
"https://github.com/huggingface/hffs",
|
24 |
"https://github.com/huggingface/trl",
|
25 |
-
"https://github.com/huggingface/
|
26 |
-
"https://github.com/huggingface/
|
27 |
"https://github.com/huggingface/education-toolkit",
|
28 |
-
"https://github.com/huggingface/
|
29 |
-
"https://github.com/huggingface/
|
30 |
-
"https://github.com/huggingface/
|
31 |
-
"https://github.com/huggingface/
|
32 |
-
"https://github.com/huggingface/
|
33 |
-
"https://github.com/huggingface/fuego",
|
34 |
-
"https://github.com/huggingface/diffusion-models-class",
|
35 |
-
"https://github.com/huggingface/disaggregators",
|
36 |
-
"https://github.com/huggingface/unity-api",
|
37 |
-
"https://github.com/huggingface/workshops",
|
38 |
-
"https://github.com/huggingface/llm-ls",
|
39 |
"https://github.com/huggingface/llm-vscode",
|
40 |
-
"https://github.com/huggingface/
|
41 |
-
"https://github.com/huggingface/
|
42 |
-
"https://github.com/huggingface/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
"https://github.com/huggingface/paper-style-guide",
|
44 |
-
"https://github.com/huggingface/
|
45 |
-
"https://github.com/huggingface/neuralcoref",
|
46 |
-
"https://github.com/huggingface/hfapi",
|
47 |
-
"https://github.com/huggingface/data-measurements-tool",
|
48 |
-
"https://github.com/huggingface/personas",
|
49 |
-
"https://github.com/huggingface/instruction-tuned-sd",
|
50 |
-
"https://github.com/huggingface/swift-transformers",
|
51 |
-
"https://github.com/huggingface/api-inference-community",
|
52 |
-
"https://github.com/huggingface/diffusers",
|
53 |
"https://github.com/huggingface/safetensors",
|
54 |
-
"https://github.com/huggingface/
|
55 |
-
"https://github.com/huggingface/OBELICS",
|
56 |
-
"https://github.com/huggingface/swift-coreml-diffusers",
|
57 |
"https://github.com/huggingface/naacl_transfer_learning_tutorial",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
"https://github.com/huggingface/nn_pruning",
|
|
|
|
|
|
|
|
|
|
|
59 |
"https://github.com/huggingface/awesome-papers",
|
60 |
-
"https://github.com/huggingface/optimum-intel",
|
61 |
-
"https://github.com/huggingface/autotrain-advanced",
|
62 |
-
"https://github.com/huggingface/pytorch-openai-transformer-lm",
|
63 |
-
"https://github.com/huggingface/node-question-answering",
|
64 |
"https://github.com/huggingface/optimum",
|
65 |
-
"https://github.com/huggingface/knockknock",
|
66 |
-
"https://github.com/huggingface/optimum-habana",
|
67 |
-
"https://github.com/huggingface/transfer-learning-conv-ai",
|
68 |
-
"https://github.com/huggingface/notebooks",
|
69 |
-
"https://github.com/huggingface/hmtl",
|
70 |
-
"https://github.com/huggingface/block_movement_pruning",
|
71 |
-
"https://github.com/huggingface/huggingface_hub",
|
72 |
"https://github.com/huggingface/transformers-bloom-inference",
|
73 |
-
"https://github.com/huggingface/
|
74 |
-
"https://github.com/huggingface/
|
|
|
|
|
75 |
"https://github.com/huggingface/large_language_model_training_playbook",
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
"https://github.com/huggingface/that_is_good_data",
|
77 |
-
"https://github.com/huggingface/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
"https://github.com/huggingface/datasets-viewer",
|
79 |
-
"https://github.com/huggingface/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
"https://github.com/huggingface/evaluate",
|
81 |
-
"https://github.com/huggingface/
|
82 |
-
"https://github.com/huggingface/
|
|
|
|
|
|
|
|
|
83 |
"https://github.com/huggingface/chat-ui",
|
84 |
-
"https://github.com/huggingface/
|
85 |
-
"https://github.com/huggingface/
|
86 |
-
"https://github.com/huggingface/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
"https://github.com/huggingface/exporters",
|
88 |
-
"https://github.com/huggingface/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
"https://github.com/huggingface/hf-endpoints-documentation",
|
90 |
"https://github.com/gradio-app/gradio"
|
91 |
]
|
|
|
1 |
{
|
2 |
"urls": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
"https://github.com/huggingface/trl",
|
4 |
+
"https://github.com/huggingface/bert-syntax",
|
5 |
+
"https://github.com/huggingface/pytorch_block_sparse",
|
6 |
"https://github.com/huggingface/education-toolkit",
|
7 |
+
"https://github.com/huggingface/diffusion-fast",
|
8 |
+
"https://github.com/huggingface/swift-transformers",
|
9 |
+
"https://github.com/huggingface/llm_training_handbook",
|
10 |
+
"https://github.com/huggingface/awesome-huggingface",
|
11 |
+
"https://github.com/huggingface/m4-logs",
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
"https://github.com/huggingface/llm-vscode",
|
13 |
+
"https://github.com/huggingface/huggingface_sb3",
|
14 |
+
"https://github.com/huggingface/audio-transformers-course",
|
15 |
+
"https://github.com/huggingface/huggingface_hub",
|
16 |
+
"https://github.com/huggingface/swift-chat",
|
17 |
+
"https://github.com/huggingface/swift-coreml-transformers",
|
18 |
+
"https://github.com/huggingface/notebooks",
|
19 |
+
"https://github.com/huggingface/datasets-server",
|
20 |
+
"https://github.com/huggingface/adversarialnlp",
|
21 |
+
"https://github.com/huggingface/alignment-handbook",
|
22 |
+
"https://github.com/huggingface/workshops",
|
23 |
+
"https://github.com/huggingface/torchMoji",
|
24 |
"https://github.com/huggingface/paper-style-guide",
|
25 |
+
"https://github.com/huggingface/optimum-intel",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
"https://github.com/huggingface/safetensors",
|
27 |
+
"https://github.com/huggingface/accelerate",
|
|
|
|
|
28 |
"https://github.com/huggingface/naacl_transfer_learning_tutorial",
|
29 |
+
"https://github.com/huggingface/hfapi",
|
30 |
+
"https://github.com/huggingface/optimum-neuron",
|
31 |
+
"https://github.com/huggingface/simulate",
|
32 |
+
"https://github.com/huggingface/unity-api",
|
33 |
+
"https://github.com/huggingface/instruction-tuned-sd",
|
34 |
+
"https://github.com/huggingface/disaggregators",
|
35 |
+
"https://github.com/huggingface/personas",
|
36 |
+
"https://github.com/huggingface/pytorch-openai-transformer-lm",
|
37 |
+
"https://github.com/huggingface/llm-ls",
|
38 |
"https://github.com/huggingface/nn_pruning",
|
39 |
+
"https://github.com/huggingface/speechbox",
|
40 |
+
"https://github.com/huggingface/community-events",
|
41 |
+
"https://github.com/huggingface/tflite-android-transformers",
|
42 |
+
"https://github.com/huggingface/neuralcoref-viz",
|
43 |
+
"https://github.com/huggingface/amused",
|
44 |
"https://github.com/huggingface/awesome-papers",
|
|
|
|
|
|
|
|
|
45 |
"https://github.com/huggingface/optimum",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
"https://github.com/huggingface/transformers-bloom-inference",
|
47 |
+
"https://github.com/huggingface/open-muse",
|
48 |
+
"https://github.com/huggingface/pytorch-image-models",
|
49 |
+
"https://github.com/huggingface/olm-datasets",
|
50 |
+
"https://github.com/huggingface/datablations",
|
51 |
"https://github.com/huggingface/large_language_model_training_playbook",
|
52 |
+
"https://github.com/huggingface/candle",
|
53 |
+
"https://github.com/huggingface/hf-hub",
|
54 |
+
"https://github.com/huggingface/transformers_bloom_parallel",
|
55 |
+
"https://github.com/huggingface/optimum-benchmark",
|
56 |
+
"https://github.com/huggingface/Mongoku",
|
57 |
+
"https://github.com/huggingface/hf_transfer",
|
58 |
"https://github.com/huggingface/that_is_good_data",
|
59 |
+
"https://github.com/huggingface/100-times-faster-nlp",
|
60 |
+
"https://github.com/huggingface/fuego",
|
61 |
+
"https://github.com/huggingface/optimum-graphcore",
|
62 |
+
"https://github.com/huggingface/peft",
|
63 |
+
"https://github.com/huggingface/tokenizers",
|
64 |
+
"https://github.com/huggingface/llm.nvim",
|
65 |
+
"https://github.com/huggingface/autotrain-advanced",
|
66 |
+
"https://github.com/huggingface/blog",
|
67 |
"https://github.com/huggingface/datasets-viewer",
|
68 |
+
"https://github.com/huggingface/huggingface.js",
|
69 |
+
"https://github.com/huggingface/diffusion-models-class",
|
70 |
+
"https://github.com/huggingface/rlhf-interface",
|
71 |
+
"https://github.com/huggingface/neuralcoref",
|
72 |
+
"https://github.com/huggingface/pytorch-pretrained-BigGAN",
|
73 |
+
"https://github.com/huggingface/distil-whisper",
|
74 |
+
"https://github.com/huggingface/quanto",
|
75 |
+
"https://github.com/huggingface/text-embeddings-inference",
|
76 |
+
"https://github.com/huggingface/course",
|
77 |
"https://github.com/huggingface/evaluate",
|
78 |
+
"https://github.com/huggingface/datasets",
|
79 |
+
"https://github.com/huggingface/optimum-habana",
|
80 |
+
"https://github.com/huggingface/hub-docs",
|
81 |
+
"https://github.com/huggingface/node-question-answering",
|
82 |
+
"https://github.com/huggingface/tune",
|
83 |
+
"https://github.com/huggingface/discord-bots",
|
84 |
"https://github.com/huggingface/chat-ui",
|
85 |
+
"https://github.com/huggingface/setfit",
|
86 |
+
"https://github.com/huggingface/transformers",
|
87 |
+
"https://github.com/huggingface/swift-coreml-diffusers",
|
88 |
+
"https://github.com/huggingface/OBELICS",
|
89 |
+
"https://github.com/huggingface/text-generation-inference",
|
90 |
+
"https://github.com/huggingface/transfer-learning-conv-ai",
|
91 |
+
"https://github.com/huggingface/llm-intellij",
|
92 |
+
"https://github.com/huggingface/api-inference-community",
|
93 |
+
"https://github.com/huggingface/optimum-nvidia",
|
94 |
+
"https://github.com/huggingface/sharp-transformers",
|
95 |
"https://github.com/huggingface/exporters",
|
96 |
+
"https://github.com/huggingface/doc-builder",
|
97 |
+
"https://github.com/huggingface/olm-training",
|
98 |
+
"https://github.com/huggingface/deep-rl-class",
|
99 |
+
"https://github.com/huggingface/zapier",
|
100 |
+
"https://github.com/huggingface/hffs",
|
101 |
+
"https://github.com/huggingface/hmtl",
|
102 |
+
"https://github.com/huggingface/block_movement_pruning",
|
103 |
+
"https://github.com/huggingface/data-measurements-tool",
|
104 |
+
"https://github.com/huggingface/knockknock",
|
105 |
+
"https://github.com/huggingface/bloom-jax-inference",
|
106 |
+
"https://github.com/huggingface/frp",
|
107 |
+
"https://github.com/huggingface/gsplat.js",
|
108 |
+
"https://github.com/huggingface/ml-agents",
|
109 |
+
"https://github.com/huggingface/competitions",
|
110 |
+
"https://github.com/huggingface/diffusers",
|
111 |
"https://github.com/huggingface/hf-endpoints-documentation",
|
112 |
"https://github.com/gradio-app/gradio"
|
113 |
]
|
data/hugging_face_docs_dataset.py
CHANGED
@@ -180,7 +180,8 @@ def markdown_cleaner(data: str):
|
|
180 |
|
181 |
|
182 |
if __name__ == '__main__':
|
183 |
-
repo_urls_file = "./datasets/hf_repositories_urls.json"
|
|
|
184 |
repo_dir = "./datasets/huggingface_repositories/"
|
185 |
docs_dir = "./datasets/huggingface_docs/"
|
186 |
download_repositories(repo_urls_file, repo_dir)
|
|
|
180 |
|
181 |
|
182 |
if __name__ == '__main__':
|
183 |
+
# repo_urls_file = "./datasets/hf_repositories_urls.json"
|
184 |
+
repo_urls_file = "./datasets/hf_repositories_urls_scraped.json"
|
185 |
repo_dir = "./datasets/huggingface_repositories/"
|
186 |
docs_dir = "./datasets/huggingface_docs/"
|
187 |
download_repositories(repo_urls_file, repo_dir)
|
data/{indexer.ipynb → index.ipynb}
RENAMED
@@ -8,6 +8,7 @@
|
|
8 |
"source": [
|
9 |
"import math\n",
|
10 |
"from pathlib import Path\n",
|
|
|
11 |
"from typing import Any\n",
|
12 |
"\n",
|
13 |
"import numpy as np\n",
|
@@ -18,7 +19,14 @@
|
|
18 |
"from langchain.indexes import VectorstoreIndexCreator\n",
|
19 |
"from langchain.text_splitter import CharacterTextSplitter\n",
|
20 |
"from langchain.vectorstores import FAISS\n",
|
21 |
-
"from huggingface_hub import HfApi"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
]
|
23 |
},
|
24 |
{
|
@@ -63,11 +71,13 @@
|
|
63 |
"metadata": {},
|
64 |
"outputs": [],
|
65 |
"source": [
|
66 |
-
"
|
|
|
|
|
67 |
"text_splitter = CharacterTextSplitter(\n",
|
68 |
" separator=\"\",\n",
|
69 |
-
" chunk_size=
|
70 |
-
" chunk_overlap=
|
71 |
" length_function=len,\n",
|
72 |
")\n",
|
73 |
"docs = text_splitter.create_documents(docs, metadata)\n",
|
@@ -127,7 +137,7 @@
|
|
127 |
" return all_embeddings\n",
|
128 |
"\n",
|
129 |
"\n",
|
130 |
-
"# max length fed to the
|
131 |
"# if longer than CHUNK_SIZE in previous steps: then N chunks + averaging of embeddings\n",
|
132 |
"max_length = 512\n",
|
133 |
"embedding_model = AverageInstructEmbeddings( \n",
|
@@ -156,13 +166,21 @@
|
|
156 |
"index = FAISS.from_documents(docs, embedding_model)"
|
157 |
]
|
158 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
{
|
160 |
"cell_type": "code",
|
161 |
"execution_count": null,
|
162 |
"metadata": {},
|
163 |
"outputs": [],
|
164 |
"source": [
|
165 |
-
"
|
|
|
166 |
"index_name = index_name.replace('/', '_')"
|
167 |
]
|
168 |
},
|
@@ -198,6 +216,7 @@
|
|
198 |
" print(f\"Document {i} of {len(docs)}\")\n",
|
199 |
" print(\"Page Content:\")\n",
|
200 |
" print(f\"\\n{'-'*100}\\n\")\n",
|
|
|
201 |
" print(doc.page_content, '\\n')\n",
|
202 |
" print(doc.metadata)"
|
203 |
]
|
@@ -221,6 +240,41 @@
|
|
221 |
" repo_type='dataset',\n",
|
222 |
")"
|
223 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
224 |
}
|
225 |
],
|
226 |
"metadata": {
|
|
|
8 |
"source": [
|
9 |
"import math\n",
|
10 |
"from pathlib import Path\n",
|
11 |
+
"from datetime import datetime\n",
|
12 |
"from typing import Any\n",
|
13 |
"\n",
|
14 |
"import numpy as np\n",
|
|
|
19 |
"from langchain.indexes import VectorstoreIndexCreator\n",
|
20 |
"from langchain.text_splitter import CharacterTextSplitter\n",
|
21 |
"from langchain.vectorstores import FAISS\n",
|
22 |
+
"from huggingface_hub import HfApi, snapshot_download"
|
23 |
+
]
|
24 |
+
},
|
25 |
+
{
|
26 |
+
"cell_type": "markdown",
|
27 |
+
"metadata": {},
|
28 |
+
"source": [
|
29 |
+
"## Index building"
|
30 |
]
|
31 |
},
|
32 |
{
|
|
|
71 |
"metadata": {},
|
72 |
"outputs": [],
|
73 |
"source": [
|
74 |
+
"# if split_chunk_size > 512 model is processing first 512 characters of the chunk\n",
|
75 |
+
"split_chunk_size = 800\n",
|
76 |
+
"chunk_overlap = 200\n",
|
77 |
"text_splitter = CharacterTextSplitter(\n",
|
78 |
" separator=\"\",\n",
|
79 |
+
" chunk_size=split_chunk_size,\n",
|
80 |
+
" chunk_overlap=chunk_overlap,\n",
|
81 |
" length_function=len,\n",
|
82 |
")\n",
|
83 |
"docs = text_splitter.create_documents(docs, metadata)\n",
|
|
|
137 |
" return all_embeddings\n",
|
138 |
"\n",
|
139 |
"\n",
|
140 |
+
"# max length fed to the model\n",
|
141 |
"# if longer than CHUNK_SIZE in previous steps: then N chunks + averaging of embeddings\n",
|
142 |
"max_length = 512\n",
|
143 |
"embedding_model = AverageInstructEmbeddings( \n",
|
|
|
166 |
"index = FAISS.from_documents(docs, embedding_model)"
|
167 |
]
|
168 |
},
|
169 |
+
{
|
170 |
+
"cell_type": "markdown",
|
171 |
+
"metadata": {},
|
172 |
+
"source": [
|
173 |
+
"## Index uploading"
|
174 |
+
]
|
175 |
+
},
|
176 |
{
|
177 |
"cell_type": "code",
|
178 |
"execution_count": null,
|
179 |
"metadata": {},
|
180 |
"outputs": [],
|
181 |
"source": [
|
182 |
+
"todays_date = datetime.now().strftime('%d_%b_%Y')\n",
|
183 |
+
"index_name = f'index-{model_name}-{split_chunk_size}-{chunk_overlap}-m{max_length}-{todays_date}'\n",
|
184 |
"index_name = index_name.replace('/', '_')"
|
185 |
]
|
186 |
},
|
|
|
216 |
" print(f\"Document {i} of {len(docs)}\")\n",
|
217 |
" print(\"Page Content:\")\n",
|
218 |
" print(f\"\\n{'-'*100}\\n\")\n",
|
219 |
+
" print(f'length of a chunk: {len(doc.page_content)}')\n",
|
220 |
" print(doc.page_content, '\\n')\n",
|
221 |
" print(doc.metadata)"
|
222 |
]
|
|
|
240 |
" repo_type='dataset',\n",
|
241 |
")"
|
242 |
]
|
243 |
+
},
|
244 |
+
{
|
245 |
+
"cell_type": "markdown",
|
246 |
+
"metadata": {},
|
247 |
+
"source": [
|
248 |
+
"## Index inference"
|
249 |
+
]
|
250 |
+
},
|
251 |
+
{
|
252 |
+
"cell_type": "code",
|
253 |
+
"execution_count": null,
|
254 |
+
"metadata": {},
|
255 |
+
"outputs": [],
|
256 |
+
"source": [
|
257 |
+
"index_repo_id = f'KonradSzafer/index-hkunlp_instructor-large-512-m512-11_Jan_2024'\n",
|
258 |
+
"\n",
|
259 |
+
"snapshot_download(\n",
|
260 |
+
" repo_id=index_repo_id,\n",
|
261 |
+
" allow_patterns=['*.faiss', '*.pkl'], \n",
|
262 |
+
" repo_type='dataset',\n",
|
263 |
+
" local_dir='../indexes/run/'\n",
|
264 |
+
")"
|
265 |
+
]
|
266 |
+
},
|
267 |
+
{
|
268 |
+
"cell_type": "code",
|
269 |
+
"execution_count": null,
|
270 |
+
"metadata": {},
|
271 |
+
"outputs": [],
|
272 |
+
"source": [
|
273 |
+
"index = FAISS.load_local('../indexes/run/', embedding_model)\n",
|
274 |
+
"docs = index.similarity_search(query='how to create a pipeline object?', k=5)\n",
|
275 |
+
"docs[0].metadata\n",
|
276 |
+
"docs[0].page_content"
|
277 |
+
]
|
278 |
}
|
279 |
],
|
280 |
"metadata": {
|
data/{indexing_benchmark.ipynb → index_benchmark.ipynb}
RENAMED
File without changes
|
data/{scrapers → stackoverflow_scrapers}/stack_overflow_scraper.py
RENAMED
File without changes
|
data/{stackoverflow_python_dataset.py → stackoverflow_scrapers/stackoverflow_python_dataset.py}
RENAMED
File without changes
|
data/{upload_csv_dataset.py → stackoverflow_scrapers/upload_csv_dataset.py}
RENAMED
File without changes
|
qa_engine/config.py
CHANGED
@@ -36,7 +36,7 @@ class Config:
|
|
36 |
|
37 |
# Discord bot config - optional
|
38 |
discord_token: str = get_env('DISCORD_TOKEN', '-', warn=False)
|
39 |
-
discord_channel_ids: list[int] = get_env('DISCORD_CHANNEL_IDS', field(default_factory=list), warn=
|
40 |
num_last_messages: int = int(get_env('NUM_LAST_MESSAGES', 2, warn=False))
|
41 |
use_names_in_context: bool = eval(get_env('USE_NAMES_IN_CONTEXT', 'False', warn=False))
|
42 |
enable_commands: bool = eval(get_env('ENABLE_COMMANDS', 'True', warn=False))
|
|
|
36 |
|
37 |
# Discord bot config - optional
|
38 |
discord_token: str = get_env('DISCORD_TOKEN', '-', warn=False)
|
39 |
+
discord_channel_ids: list[int] = get_env('DISCORD_CHANNEL_IDS', field(default_factory=list), warn=False)
|
40 |
num_last_messages: int = int(get_env('NUM_LAST_MESSAGES', 2, warn=False))
|
41 |
use_names_in_context: bool = eval(get_env('USE_NAMES_IN_CONTEXT', 'False', warn=False))
|
42 |
enable_commands: bool = eval(get_env('ENABLE_COMMANDS', 'True', warn=False))
|
qa_engine/qa_engine.py
CHANGED
@@ -181,30 +181,11 @@ class QAEngine():
|
|
181 |
self.first_stage_docs = first_stage_docs
|
182 |
self.debug = debug
|
183 |
|
184 |
-
if 'local_models/' in llm_model_id:
|
185 |
-
logger.info('using local binary model')
|
186 |
-
self.llm_model = LocalBinaryModel(
|
187 |
-
model_id=llm_model_id
|
188 |
-
)
|
189 |
-
elif 'api_models/' in llm_model_id:
|
190 |
-
logger.info('using api served model')
|
191 |
-
self.llm_model = APIServedModel(
|
192 |
-
model_url=llm_model_id.replace('api_models/', ''),
|
193 |
-
debug=self.debug
|
194 |
-
)
|
195 |
-
elif llm_model_id == 'mock':
|
196 |
-
logger.info('using mock model')
|
197 |
-
self.llm_model = MockLocalBinaryModel()
|
198 |
-
else:
|
199 |
-
logger.info('using transformers pipeline model')
|
200 |
-
self.llm_model = TransformersPipelineModel(
|
201 |
-
model_id=llm_model_id
|
202 |
-
)
|
203 |
-
|
204 |
prompt = PromptTemplate(
|
205 |
template=prompt_template,
|
206 |
input_variables=['question', 'context']
|
207 |
)
|
|
|
208 |
self.llm_chain = LLMChain(prompt=prompt, llm=self.llm_model)
|
209 |
|
210 |
if self.use_docs_for_context:
|
@@ -228,6 +209,29 @@ class QAEngine():
|
|
228 |
self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
|
229 |
|
230 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
231 |
@staticmethod
|
232 |
def _preprocess_question(question: str) -> str:
|
233 |
if question[-1] != '?':
|
|
|
181 |
self.first_stage_docs = first_stage_docs
|
182 |
self.debug = debug
|
183 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
184 |
prompt = PromptTemplate(
|
185 |
template=prompt_template,
|
186 |
input_variables=['question', 'context']
|
187 |
)
|
188 |
+
self.llm_model = QAEngine._get_model(llm_model_id)
|
189 |
self.llm_chain = LLMChain(prompt=prompt, llm=self.llm_model)
|
190 |
|
191 |
if self.use_docs_for_context:
|
|
|
209 |
self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
|
210 |
|
211 |
|
212 |
+
@staticmethod
|
213 |
+
def _get_model(llm_model_id: str):
|
214 |
+
if 'local_models/' in llm_model_id:
|
215 |
+
logger.info('using local binary model')
|
216 |
+
return LocalBinaryModel(
|
217 |
+
model_id=llm_model_id
|
218 |
+
)
|
219 |
+
elif 'api_models/' in llm_model_id:
|
220 |
+
logger.info('using api served model')
|
221 |
+
return APIServedModel(
|
222 |
+
model_url=llm_model_id.replace('api_models/', ''),
|
223 |
+
debug=self.debug
|
224 |
+
)
|
225 |
+
elif llm_model_id == 'mock':
|
226 |
+
logger.info('using mock model')
|
227 |
+
return MockLocalBinaryModel()
|
228 |
+
else:
|
229 |
+
logger.info('using transformers pipeline model')
|
230 |
+
return TransformersPipelineModel(
|
231 |
+
model_id=llm_model_id
|
232 |
+
)
|
233 |
+
|
234 |
+
|
235 |
@staticmethod
|
236 |
def _preprocess_question(question: str) -> str:
|
237 |
if question[-1] != '?':
|