JiakaiDu committed on
Commit cd1c110 (1 parent: aa28d6b)

Upload folder using huggingface_hub

Files changed (2)
  1. README.md +2 -9
  2. Test_RAG.py +878 -0
README.md CHANGED
@@ -1,13 +1,6 @@
 ---
-title: RAG Test
-emoji: 📉
-colorFrom: gray
-colorTo: gray
+title: RAG_Test
+app_file: Test_RAG.py
 sdk: gradio
 sdk_version: 4.44.0
-app_file: app.py
-pinned: false
-license: llama3
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
Test_RAG.py ADDED
@@ -0,0 +1,878 @@
import os
os.environ["GIT_CLONE_PROTECTION_ACTIVE"] = "false"
from pathlib import Path
import requests
import shutil
import io
import openvino as ov
import torch
import ipywidgets as widgets
from transformers import (
    TextIteratorStreamer,
    StoppingCriteria,
    StoppingCriteriaList,
)
from llm_config import (
    SUPPORTED_EMBEDDING_MODELS,
    SUPPORTED_RERANK_MODELS,
    SUPPORTED_LLM_MODELS,
)
from huggingface_hub import login


config_shared_path = Path("../../utils/llm_config.py")
config_dst_path = Path("llm_config.py")
text_example_en_path = Path("text_example_en.pdf")
text_example_cn_path = Path("text_example_cn.pdf")
text_example_en = "https://github.com/openvinotoolkit/openvino_notebooks/files/15039728/Platform.Brief_Intel.vPro.with.Intel.Core.Ultra_Final.pdf"
text_example_cn = "https://github.com/openvinotoolkit/openvino_notebooks/files/15039713/Platform.Brief_Intel.vPro.with.Intel.Core.Ultra_Final_CH.pdf"

# Fetch llm_config.py from the local openvino_notebooks checkout if available, otherwise from GitHub.
if not config_dst_path.exists():
    if config_shared_path.exists():
        try:
            os.symlink(config_shared_path, config_dst_path)
        except Exception:
            shutil.copy(config_shared_path, config_dst_path)
    else:
        r = requests.get(url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/llm_config.py")
        with open("llm_config.py", "w", encoding="utf-8") as f:
            f.write(r.text)
elif not os.path.islink(config_dst_path):
    print("LLM config will be updated")
    if config_shared_path.exists():
        shutil.copy(config_shared_path, config_dst_path)
    else:
        r = requests.get(url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/llm_config.py")
        with open("llm_config.py", "w", encoding="utf-8") as f:
            f.write(r.text)


# Download the example PDFs if they are not present yet.
if not text_example_en_path.exists():
    r = requests.get(url=text_example_en)
    content = io.BytesIO(r.content)
    with open("text_example_en.pdf", "wb") as f:
        f.write(content.read())

if not text_example_cn_path.exists():
    r = requests.get(url=text_example_cn)
    content = io.BytesIO(r.content)
    with open("text_example_cn.pdf", "wb") as f:
        f.write(content.read())

model_language = "English"
llm_model_id = "llama-3-8b-instruct"
llm_model_configuration = SUPPORTED_LLM_MODELS[model_language][llm_model_id]
print(f"Selected LLM model {llm_model_id}")
prepare_int4_model = True  # Prepare INT4 model
prepare_int8_model = False  # Do not prepare INT8 model
prepare_fp16_model = False  # Do not prepare FP16 model
enable_awq = False
# Get the token from the environment variable
hf_token = os.getenv("HUGGINGFACE_TOKEN")

if hf_token is None:
    raise ValueError(
        "HUGGINGFACE_TOKEN environment variable not set. "
        "Please set it in your environment variables or repository secrets."
    )

# Log in to Hugging Face Hub
login(token=hf_token)
pt_model_id = llm_model_configuration["model_id"]
# pt_model_name = llm_model_id.value.split("-")[0]
fp16_model_dir = Path(llm_model_id) / "FP16"
int8_model_dir = Path(llm_model_id) / "INT8_compressed_weights"
int4_model_dir = Path(llm_model_id) / "INT4_compressed_weights"


def convert_to_fp16():
    if (fp16_model_dir / "openvino_model.xml").exists():
        return
    remote_code = llm_model_configuration.get("remote_code", False)
    export_command_base = "optimum-cli export openvino --model {} --task text-generation-with-past --weight-format fp16".format(pt_model_id)
    if remote_code:
        export_command_base += " --trust-remote-code"
    export_command = export_command_base + " " + str(fp16_model_dir)
    # Print the export command and run it through the shell (plain-script equivalent of the
    # notebook's display(Markdown(...)) and "!" cell magic, which are not valid in a .py file).
    print(f"Export command: {export_command}")
    os.system(export_command)


def convert_to_int8():
    if (int8_model_dir / "openvino_model.xml").exists():
        return
    int8_model_dir.mkdir(parents=True, exist_ok=True)
    remote_code = llm_model_configuration.get("remote_code", False)
    export_command_base = "optimum-cli export openvino --model {} --task text-generation-with-past --weight-format int8".format(pt_model_id)
    if remote_code:
        export_command_base += " --trust-remote-code"
    export_command = export_command_base + " " + str(int8_model_dir)
    print(f"Export command: {export_command}")
    os.system(export_command)


def convert_to_int4():
    compression_configs = {
        "zephyr-7b-beta": {
            "sym": True,
            "group_size": 64,
            "ratio": 0.6,
        },
        "mistral-7b": {
            "sym": True,
            "group_size": 64,
            "ratio": 0.6,
        },
        "minicpm-2b-dpo": {
            "sym": True,
            "group_size": 64,
            "ratio": 0.6,
        },
        "gemma-2b-it": {
            "sym": True,
            "group_size": 64,
            "ratio": 0.6,
        },
        "notus-7b-v1": {
            "sym": True,
            "group_size": 64,
            "ratio": 0.6,
        },
        "neural-chat-7b-v3-1": {
            "sym": True,
            "group_size": 64,
            "ratio": 0.6,
        },
        "llama-2-chat-7b": {
            "sym": True,
            "group_size": 128,
            "ratio": 0.8,
        },
        "llama-3-8b-instruct": {
            "sym": True,
            "group_size": 128,
            "ratio": 0.8,
        },
        "gemma-7b-it": {
            "sym": True,
            "group_size": 128,
            "ratio": 0.8,
        },
        "chatglm2-6b": {
            "sym": True,
            "group_size": 128,
            "ratio": 0.72,
        },
        "qwen-7b-chat": {"sym": True, "group_size": 128, "ratio": 0.6},
        "red-pajama-3b-chat": {
            "sym": False,
            "group_size": 128,
            "ratio": 0.5,
        },
        "default": {
            "sym": False,
            "group_size": 128,
            "ratio": 0.8,
        },
    }

    model_compression_params = compression_configs.get(llm_model_id, compression_configs["default"])
    if (int4_model_dir / "openvino_model.xml").exists():
        return
    remote_code = llm_model_configuration.get("remote_code", False)
    export_command_base = "optimum-cli export openvino --model {} --task text-generation-with-past --weight-format int4".format(pt_model_id)
    int4_compression_args = " --group-size {} --ratio {}".format(model_compression_params["group_size"], model_compression_params["ratio"])
    if model_compression_params["sym"]:
        int4_compression_args += " --sym"
    # Optionally apply AWQ on top of INT4 compression (uses wikitext2 as the calibration set).
    if enable_awq:
        int4_compression_args += " --awq --dataset wikitext2 --num-samples 128"
    export_command_base += int4_compression_args
    if remote_code:
        export_command_base += " --trust-remote-code"
    export_command = export_command_base + " " + str(int4_model_dir)
    print(f"Export command: {export_command}")
    os.system(export_command)


if prepare_fp16_model:
    convert_to_fp16()
if prepare_int8_model:
    convert_to_int8()
if prepare_int4_model:
    convert_to_int4()
fp16_weights = fp16_model_dir / "openvino_model.bin"
int8_weights = int8_model_dir / "openvino_model.bin"
int4_weights = int4_model_dir / "openvino_model.bin"

if fp16_weights.exists():
    print(f"Size of FP16 model is {fp16_weights.stat().st_size / 1024 / 1024:.2f} MB")
for precision, compressed_weights in zip([8, 4], [int8_weights, int4_weights]):
    if compressed_weights.exists():
        print(f"Size of model with INT{precision} compressed weights is {compressed_weights.stat().st_size / 1024 / 1024:.2f} MB")
    if compressed_weights.exists() and fp16_weights.exists():
        print(f"Compression rate for INT{precision} model: {fp16_weights.stat().st_size / compressed_weights.stat().st_size:.3f}")

embedding_model_id = "bge-small-en-v1.5"  # options: 'bge-small-en-v1.5', 'bge-large-en-v1.5', 'bge-m3'
embedding_model_configuration = SUPPORTED_EMBEDDING_MODELS[model_language][embedding_model_id]
print(f"Selected {embedding_model_id} model")
export_command_base = "optimum-cli export openvino --model {} --task feature-extraction".format(embedding_model_configuration["model_id"])
export_command = export_command_base + " " + str(embedding_model_id)
rerank_model_id = "bge-reranker-v2-m3"  # options: 'bge-reranker-v2-m3', 'bge-reranker-large', 'bge-reranker-base'
rerank_model_configuration = SUPPORTED_RERANK_MODELS[rerank_model_id]
print(f"Selected {rerank_model_id} model")
export_command_base = "optimum-cli export openvino --model {} --task text-classification".format(rerank_model_configuration["model_id"])
export_command = export_command_base + " " + str(rerank_model_id)
embedding_device = "CPU"
USING_NPU = embedding_device == "NPU"

npu_embedding_dir = embedding_model_id + "-npu"
npu_embedding_path = Path(npu_embedding_dir) / "openvino_model.xml"
if USING_NPU and not Path(npu_embedding_dir).exists():
    r = requests.get(
        url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/notebook_utils.py",
    )
    with open("notebook_utils.py", "w") as f:
        f.write(r.text)
    import notebook_utils as utils

    shutil.copytree(embedding_model_id, npu_embedding_dir)
    utils.optimize_bge_embedding(Path(embedding_model_id) / "openvino_model.xml", npu_embedding_path)
rerank_device = "CPU"
llm_device = "CPU"
from langchain_community.embeddings import OpenVINOBgeEmbeddings

embedding_model_name = npu_embedding_dir if USING_NPU else embedding_model_id
batch_size = 1 if USING_NPU else 4
embedding_model_kwargs = {"device": embedding_device, "compile": False}
encode_kwargs = {
    "mean_pooling": embedding_model_configuration["mean_pooling"],
    "normalize_embeddings": embedding_model_configuration["normalize_embeddings"],
    "batch_size": batch_size,
}

embedding = OpenVINOBgeEmbeddings(
    model_name_or_path=embedding_model_name,
    model_kwargs=embedding_model_kwargs,
    encode_kwargs=encode_kwargs,
)
if USING_NPU:
    embedding.ov_model.reshape(1, 512)
    embedding.ov_model.compile()

text = "This is a test document."
embedding_result = embedding.embed_query(text)
print(embedding_result[:3])  # first few values of the test embedding
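
# A small illustration (assumes the embedding model above compiled successfully): embed two short
# documents and compare each one to the test query embedding with cosine similarity.
import numpy as np

doc_vectors = np.array(embedding.embed_documents(["OpenVINO runs inference on Intel hardware.", "Bananas are yellow."]))
query_vector = np.array(embedding_result)
cosine_scores = doc_vectors @ query_vector / (np.linalg.norm(doc_vectors, axis=1) * np.linalg.norm(query_vector))
print(f"Cosine similarity of the two sample documents to the test query: {cosine_scores}")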

from langchain_community.document_compressors.openvino_rerank import OpenVINOReranker

rerank_model_name = rerank_model_id
rerank_model_kwargs = {"device": rerank_device}
rerank_top_n = 2

reranker = OpenVINOReranker(
    model_name_or_path=rerank_model_name,
    model_kwargs=rerank_model_kwargs,
    top_n=rerank_top_n,
)
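
# Quick illustrative check of the reranker (a sketch, assuming the exported rerank model is
# available in the local folder used above): score two candidate passages against a query;
# only the top_n most relevant passages are returned.
from langchain_core.documents import Document as RerankDoc

rerank_demo = reranker.compress_documents(
    documents=[
        RerankDoc(page_content="Intel vPro Enterprise provides hardware-based security and manageability."),
        RerankDoc(page_content="The weather was sunny all week."),
    ],
    query="What does Intel vPro offer?",
)
print([doc.page_content for doc in rerank_demo])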

model_to_run = "INT4"
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline

if model_to_run == "INT4":
    model_dir = int4_model_dir
elif model_to_run == "INT8":
    model_dir = int8_model_dir
else:
    model_dir = fp16_model_dir
print(f"Loading model from {model_dir}")

ov_config = {"PERFORMANCE_HINT": "LATENCY", "NUM_STREAMS": "1", "CACHE_DIR": ""}

if "GPU" in llm_device and "qwen2-7b-instruct" in llm_model_id:
    ov_config["GPU_ENABLE_SDPA_OPTIMIZATION"] = "NO"

# On a GPU device the model is executed in FP16 precision. The red-pajama-3b-chat model has known
# accuracy issues in FP16, which are avoided by setting the inference precision hint to "f32".
core = ov.Core()
if llm_model_id == "red-pajama-3b-chat" and "GPU" in core.available_devices and llm_device in ["GPU", "AUTO"]:
    ov_config["INFERENCE_PRECISION_HINT"] = "f32"

llm = HuggingFacePipeline.from_model_id(
    model_id=str(model_dir),
    task="text-generation",
    backend="openvino",
    model_kwargs={
        "device": llm_device,
        "ov_config": ov_config,
        "trust_remote_code": True,
    },
    pipeline_kwargs={"max_new_tokens": 2},
)

llm.invoke("2 + 2 =")  # quick smoke test of the compiled pipeline
import re
from typing import List
from langchain.text_splitter import (
    CharacterTextSplitter,
    RecursiveCharacterTextSplitter,
    MarkdownTextSplitter,
)
from langchain.document_loaders import (
    CSVLoader,
    EverNoteLoader,
    PyPDFLoader,
    TextLoader,
    UnstructuredEPubLoader,
    UnstructuredHTMLLoader,
    UnstructuredMarkdownLoader,
    UnstructuredODTLoader,
    UnstructuredPowerPointLoader,
    UnstructuredWordDocumentLoader,
)


class ChineseTextSplitter(CharacterTextSplitter):
    def __init__(self, pdf: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.pdf = pdf

    def split_text(self, text: str) -> List[str]:
        if self.pdf:
            text = re.sub(r"\n{3,}", "\n", text)
            text = text.replace("\n\n", "")
        sent_sep_pattern = re.compile('([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))')
        sent_list = []
        for ele in sent_sep_pattern.split(text):
            if sent_sep_pattern.match(ele) and sent_list:
                sent_list[-1] += ele
            elif ele:
                sent_list.append(ele)
        return sent_list


TEXT_SPLITERS = {
    "Character": CharacterTextSplitter,
    "RecursiveCharacter": RecursiveCharacterTextSplitter,
    "Markdown": MarkdownTextSplitter,
    "Chinese": ChineseTextSplitter,
}
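
# Brief illustration of the ChineseTextSplitter defined above: sentences are split on Chinese and
# Western sentence-ending punctuation rather than on a fixed separator string.
print(ChineseTextSplitter().split_text("今天天气不错。我们去公园散步吧!"))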

LOADERS = {
    ".csv": (CSVLoader, {}),
    ".doc": (UnstructuredWordDocumentLoader, {}),
    ".docx": (UnstructuredWordDocumentLoader, {}),
    ".enex": (EverNoteLoader, {}),
    ".epub": (UnstructuredEPubLoader, {}),
    ".html": (UnstructuredHTMLLoader, {}),
    ".md": (UnstructuredMarkdownLoader, {}),
    ".odt": (UnstructuredODTLoader, {}),
    ".pdf": (PyPDFLoader, {}),
    ".ppt": (UnstructuredPowerPointLoader, {}),
    ".pptx": (UnstructuredPowerPointLoader, {}),
    ".txt": (TextLoader, {"encoding": "utf8"}),
}

chinese_examples = [
    ["英特尔®酷睿™ Ultra处理器可以降低多少功耗?"],
    ["相比英特尔之前的移动处理器产品,英特尔®酷睿™ Ultra处理器的AI推理性能提升了多少?"],
    ["英特尔博锐® Enterprise系统提供哪些功能?"],
]

english_examples = [
    ["How much power consumption can Intel® Core™ Ultra Processors help save?"],
    ["Compared to Intel’s previous mobile processor, what is the advantage of Intel® Core™ Ultra Processors for Artificial Intelligence?"],
    ["What can Intel vPro® Enterprise systems offer?"],
]

if model_language == "English":
    # text_example_path = "text_example_en.pdf"
    text_example_path = [
        "Supervisors-Guide-Accurate-Timekeeping_AH edits.docx",
        "Salary-vs-Hourly-Guide_AH edits.docx",
        "Employee-Guide-Accurate-Timekeeping_AH edits.docx",
        "Eller Overtime Guidelines.docx",
        "Eller FLSA information 9.2024_AH edits.docx",
        "Accurate Timekeeping Supervisors 12.2.20_AH edits.docx",
    ]
else:
    text_example_path = "text_example_cn.pdf"

examples = chinese_examples if (model_language == "Chinese") else english_examples
from langchain.prompts import PromptTemplate
from langchain_community.vectorstores import FAISS
from langchain.chains.retrieval import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.docstore.document import Document
from langchain.retrievers import ContextualCompressionRetriever
from threading import Thread
import gradio as gr

stop_tokens = llm_model_configuration.get("stop_tokens")
rag_prompt_template = llm_model_configuration["rag_prompt_template"]


class StopOnTokens(StoppingCriteria):
    def __init__(self, token_ids):
        self.token_ids = token_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_id in self.token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False


if stop_tokens is not None:
    if isinstance(stop_tokens[0], str):
        stop_tokens = llm.pipeline.tokenizer.convert_tokens_to_ids(stop_tokens)

    stop_tokens = [StopOnTokens(stop_tokens)]


def load_single_document(file_path: str) -> List[Document]:
    """
    helper for loading a single document

    Params:
      file_path: document path
    Returns:
      documents loaded

    """
    ext = "." + file_path.rsplit(".", 1)[-1]
    if ext in LOADERS:
        loader_class, loader_args = LOADERS[ext]
        loader = loader_class(file_path, **loader_args)
        return loader.load()

    raise ValueError(f"Unsupported file extension '{ext}'")
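
# Example usage (assumes the English sample PDF downloaded earlier is still present):
sample_docs = load_single_document("text_example_en.pdf")
print(f"Loaded {len(sample_docs)} document(s) from text_example_en.pdf")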


def default_partial_text_processor(partial_text: str, new_text: str):
    """
    helper for updating partially generated answer, used by default

    Params:
      partial_text: text buffer for storing previously generated text
      new_text: text update for the current step
    Returns:
      updated text string

    """
    partial_text += new_text
    return partial_text


text_processor = llm_model_configuration.get("partial_text_processor", default_partial_text_processor)


def create_vectordb(
    docs, spliter_name, chunk_size, chunk_overlap, vector_search_top_k, vector_rerank_top_n, run_rerank, search_method, score_threshold, progress=gr.Progress()
):
    """
    Initialize a vector database

    Params:
      docs: original documents provided by user
      spliter_name: splitter method
      chunk_size: size of a single sentence chunk
      chunk_overlap: overlap size between 2 chunks
      vector_search_top_k: Vector search top k
      vector_rerank_top_n: Search rerank top n
      run_rerank: whether to run the reranker
      search_method: top k search method
      score_threshold: score threshold when selecting 'similarity_score_threshold' method

    """
    global db
    global retriever
    global combine_docs_chain
    global rag_chain

    if vector_rerank_top_n > vector_search_top_k:
        gr.Warning("Search top k must >= Rerank top n")

    documents = []
    for doc in docs:
        if type(doc) is not str:
            doc = doc.name
        documents.extend(load_single_document(doc))

    text_splitter = TEXT_SPLITERS[spliter_name](chunk_size=chunk_size, chunk_overlap=chunk_overlap)

    texts = text_splitter.split_documents(documents)
    db = FAISS.from_documents(texts, embedding)
    if search_method == "similarity_score_threshold":
        search_kwargs = {"k": vector_search_top_k, "score_threshold": score_threshold}
    else:
        search_kwargs = {"k": vector_search_top_k}
    retriever = db.as_retriever(search_kwargs=search_kwargs, search_type=search_method)
    if run_rerank:
        reranker.top_n = vector_rerank_top_n
        retriever = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=retriever)
    prompt = PromptTemplate.from_template(rag_prompt_template)
    combine_docs_chain = create_stuff_documents_chain(llm, prompt)

    rag_chain = create_retrieval_chain(retriever, combine_docs_chain)

    return "Vector database is Ready"


def update_retriever(vector_search_top_k, vector_rerank_top_n, run_rerank, search_method, score_threshold):
    """
    Update retriever

    Params:
      vector_search_top_k: Vector search top k
      vector_rerank_top_n: Search rerank top n
      run_rerank: whether to run the reranker
      search_method: top k search method
      score_threshold: score threshold when selecting 'similarity_score_threshold' method

    """
    global db
    global retriever
    global combine_docs_chain
    global rag_chain

    if vector_rerank_top_n > vector_search_top_k:
        gr.Warning("Search top k must >= Rerank top n")

    if search_method == "similarity_score_threshold":
        search_kwargs = {"k": vector_search_top_k, "score_threshold": score_threshold}
    else:
        search_kwargs = {"k": vector_search_top_k}
    retriever = db.as_retriever(search_kwargs=search_kwargs, search_type=search_method)
    if run_rerank:
        retriever = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=retriever)
        reranker.top_n = vector_rerank_top_n
    rag_chain = create_retrieval_chain(retriever, combine_docs_chain)

    return "Vector database is Ready"


def user(message, history):
    """
    Callback for updating the user message in the interface on submit button click.

    Params:
      message: current message
      history: conversation history
    Returns:
      the cleared message box and the updated conversation history
    """
    # Append the user's message to the conversation history
    return "", history + [[message, ""]]


def bot(history, temperature, top_p, top_k, repetition_penalty, hide_full_prompt, do_rag):
    """
    Callback for running the chatbot on submit button click.

    Params:
      history: conversation history
      temperature: controls the level of creativity in AI-generated text.
        By adjusting the temperature, you can influence the model's probability distribution, making the text more focused or diverse.
      top_p: controls the range of tokens considered by the model based on their cumulative probability.
      top_k: limits sampling to the k tokens with the highest probability.
      repetition_penalty: penalizes tokens based on how frequently they occur in the text.
      hide_full_prompt: whether to hide the retrieved context in the prompt.
      do_rag: whether to do RAG when generating text.

    """
    streamer = TextIteratorStreamer(
        llm.pipeline.tokenizer,
        timeout=60.0,
        skip_prompt=hide_full_prompt,
        skip_special_tokens=True,
    )
    llm.pipeline._forward_params = dict(
        max_new_tokens=512,
        temperature=temperature,
        do_sample=temperature > 0.0,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        streamer=streamer,
    )
    if stop_tokens is not None:
        llm.pipeline._forward_params["stopping_criteria"] = StoppingCriteriaList(stop_tokens)

    if do_rag:
        t1 = Thread(target=rag_chain.invoke, args=({"input": history[-1][0]},))
    else:
        input_text = rag_prompt_template.format(input=history[-1][0], context="")
        t1 = Thread(target=llm.invoke, args=(input_text,))
    t1.start()

    # Initialize an empty string to store the generated text
    partial_text = ""
    for new_text in streamer:
        partial_text = text_processor(partial_text, new_text)
        history[-1][1] = partial_text
        yield history


def request_cancel():
    llm.pipeline.model.request.cancel()


def clear_files():
    return "Vector Store is Not ready"


# initialize the vector store with the example documents
create_vectordb(
    text_example_path,  # changed
    "RecursiveCharacter",
    chunk_size=400,
    chunk_overlap=50,
    vector_search_top_k=10,
    vector_rerank_top_n=2,
    run_rerank=True,
    search_method="similarity_score_threshold",
    score_threshold=0.5,
)
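
# Illustrative direct query (a sketch; the chat callback above drives the same chain): once the
# vector store is built, the retrieval chain can also be invoked outside of the UI.
startup_check = rag_chain.invoke({"input": english_examples[0][0]})
print(startup_check["answer"])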

with gr.Blocks(
    theme=gr.themes.Soft(),
    css=".disclaimer {font-variant-caps: all-small-caps;}",
) as demo:
    gr.Markdown("""<h1><center>QA over Document</center></h1>""")
    gr.Markdown(f"""<center>Powered by OpenVINO and {llm_model_id} </center>""")
    with gr.Row():
        with gr.Column(scale=1):
            docs = gr.File(
                label="Step 1: Load text files",
                value=text_example_path,  # changed
                file_count="multiple",
                file_types=[
                    ".csv",
                    ".doc",
                    ".docx",
                    ".enex",
                    ".epub",
                    ".html",
                    ".md",
                    ".odt",
                    ".pdf",
                    ".ppt",
                    ".pptx",
                    ".txt",
                ],
            )
            load_docs = gr.Button("Step 2: Build Vector Store", variant="primary")
            db_argument = gr.Accordion("Vector Store Configuration", open=False)
            with db_argument:
                spliter = gr.Dropdown(
                    ["Character", "RecursiveCharacter", "Markdown", "Chinese"],
                    value="RecursiveCharacter",
                    label="Text Splitter",
                    info="Method used to split the documents",
                    multiselect=False,
                )

                chunk_size = gr.Slider(
                    label="Chunk size",
                    value=400,
                    minimum=50,
                    maximum=2000,
                    step=50,
                    interactive=True,
                    info="Size of sentence chunk",
                )

                chunk_overlap = gr.Slider(
                    label="Chunk overlap",
                    value=50,
                    minimum=0,
                    maximum=400,
                    step=10,
                    interactive=True,
                    info=("Overlap between 2 chunks"),
                )

            langchain_status = gr.Textbox(
                label="Vector Store Status",
                value="Vector Store is Ready",
                interactive=False,
            )
            do_rag = gr.Checkbox(
                value=True,
                label="RAG is ON",
                interactive=True,
                info="Whether to do RAG for generation",
            )
            with gr.Accordion("Generation Configuration", open=False):
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            temperature = gr.Slider(
                                label="Temperature",
                                value=0.1,
                                minimum=0.0,
                                maximum=1.0,
                                step=0.1,
                                interactive=True,
                                info="Higher values produce more diverse outputs",
                            )
                    with gr.Column():
                        with gr.Row():
                            top_p = gr.Slider(
                                label="Top-p (nucleus sampling)",
                                value=1.0,
                                minimum=0.0,
                                maximum=1,
                                step=0.01,
                                interactive=True,
                                info=(
                                    "Sample from the smallest possible set of tokens whose cumulative probability "
                                    "exceeds top_p. Set to 1 to disable and sample from all tokens."
                                ),
                            )
                    with gr.Column():
                        with gr.Row():
                            top_k = gr.Slider(
                                label="Top-k",
                                value=50,
                                minimum=0.0,
                                maximum=200,
                                step=1,
                                interactive=True,
                                info="Sample from a shortlist of top-k tokens — 0 to disable and sample from all tokens.",
                            )
                    with gr.Column():
                        with gr.Row():
                            repetition_penalty = gr.Slider(
                                label="Repetition Penalty",
                                value=1.1,
                                minimum=1.0,
                                maximum=2.0,
                                step=0.1,
                                interactive=True,
                                info="Penalize repetition — 1.0 to disable.",
                            )
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                height=800,
                label="Step 3: Input Query",
            )
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        msg = gr.Textbox(
                            label="QA Message Box",
                            placeholder="Chat Message Box",
                            show_label=False,
                            container=False,
                        )
                with gr.Column():
                    with gr.Row():
                        submit = gr.Button("Submit", variant="primary")
                        stop = gr.Button("Stop")
                        clear = gr.Button("Clear")
            gr.Examples(examples, inputs=msg, label="Click on any example and press the 'Submit' button")
            retriever_argument = gr.Accordion("Retriever Configuration", open=True)
            with retriever_argument:
                with gr.Row():
                    with gr.Row():
                        do_rerank = gr.Checkbox(
                            value=True,
                            label="Rerank searching result",
                            interactive=True,
                        )
                        hide_context = gr.Checkbox(
                            value=True,
                            label="Hide searching result in prompt",
                            interactive=True,
                        )
                    with gr.Row():
                        search_method = gr.Dropdown(
                            ["similarity_score_threshold", "similarity", "mmr"],
                            value="similarity_score_threshold",
                            label="Searching Method",
                            info="Method used to search vector store",
                            multiselect=False,
                            interactive=True,
                        )
                    with gr.Row():
                        score_threshold = gr.Slider(
                            0.01,
                            0.99,
                            value=0.5,
                            step=0.01,
                            label="Similarity Threshold",
                            info="Only used with the 'similarity_score_threshold' method",
                            interactive=True,
                        )
                    with gr.Row():
                        vector_rerank_top_n = gr.Slider(
                            1,
                            10,
                            value=2,
                            step=1,
                            label="Rerank top n",
                            info="Number of rerank results",
                            interactive=True,
                        )
                    with gr.Row():
                        vector_search_top_k = gr.Slider(
                            1,
                            50,
                            value=10,
                            step=1,
                            label="Search top k",
                            info="Search top k must >= Rerank top n",
                            interactive=True,
                        )
    docs.clear(clear_files, outputs=[langchain_status], queue=False)
    load_docs.click(
        create_vectordb,
        inputs=[docs, spliter, chunk_size, chunk_overlap, vector_search_top_k, vector_rerank_top_n, do_rerank, search_method, score_threshold],
        outputs=[langchain_status],
        queue=False,
    )
    submit_event = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot,
        [chatbot, temperature, top_p, top_k, repetition_penalty, hide_context, do_rag],
        chatbot,
        queue=True,
    )
    submit_click_event = submit.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot,
        [chatbot, temperature, top_p, top_k, repetition_penalty, hide_context, do_rag],
        chatbot,
        queue=True,
    )
    stop.click(
        fn=request_cancel,
        inputs=None,
        outputs=None,
        cancels=[submit_event, submit_click_event],
        queue=False,
    )
    clear.click(lambda: None, None, chatbot, queue=False)
    vector_search_top_k.release(
        update_retriever,
        [vector_search_top_k, vector_rerank_top_n, do_rerank, search_method, score_threshold],
        outputs=[langchain_status],
    )
    vector_rerank_top_n.release(
        update_retriever,
        inputs=[vector_search_top_k, vector_rerank_top_n, do_rerank, search_method, score_threshold],
        outputs=[langchain_status],
    )
    do_rerank.change(
        update_retriever,
        inputs=[vector_search_top_k, vector_rerank_top_n, do_rerank, search_method, score_threshold],
        outputs=[langchain_status],
    )
    search_method.change(
        update_retriever,
        inputs=[vector_search_top_k, vector_rerank_top_n, do_rerank, search_method, score_threshold],
        outputs=[langchain_status],
    )
    score_threshold.change(
        update_retriever,
        inputs=[vector_search_top_k, vector_rerank_top_n, do_rerank, search_method, score_threshold],
        outputs=[langchain_status],
    )


demo.queue()
# if you are launching remotely, specify server_name and server_port
# demo.launch(server_port=8082)
# if you have any issue to launch on your platform, you can pass share=True to launch method:
demo.launch(share=True)
# it creates a publicly shareable link for the interface. Read more in the docs: https://gradio.app/docs/
# demo.launch()