JiakaiDu committed on
Commit
c27d36b
1 Parent(s): 078c925

Upload folder using huggingface_hub

Files changed (1)
  1. app.py +873 -0
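For reference, an upload like the one in this commit can be performed with the huggingface_hub API; the snippet below is only an illustration, and the repo id and folder path are placeholders rather than values taken from this commit:

    from huggingface_hub import upload_folder

    # hypothetical values, for illustration only
    upload_folder(folder_path=".", repo_id="<user>/<space-name>", repo_type="space")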
app.py ADDED
@@ -0,0 +1,873 @@
+ import os
+ os.environ["GIT_CLONE_PROTECTION_ACTIVE"] = "false"
+ from pathlib import Path
+ import requests
+ import shutil
+ import io
+ import subprocess  # used below to run the optimum-cli export commands
+ import openvino as ov
+ import torch
+ from transformers import (
+     TextIteratorStreamer,
+     StoppingCriteria,
+     StoppingCriteriaList,
+ )
+ from llm_config import (
+     SUPPORTED_EMBEDDING_MODELS,
+     SUPPORTED_RERANK_MODELS,
+     SUPPORTED_LLM_MODELS,
+ )
+ from huggingface_hub import login
+
+
+ config_shared_path = Path("../../utils/llm_config.py")
+ config_dst_path = Path("llm_config.py")
+ text_example_en_path = Path("text_example_en.pdf")
+ text_example_cn_path = Path("text_example_cn.pdf")
+ text_example_en = "https://github.com/openvinotoolkit/openvino_notebooks/files/15039728/Platform.Brief_Intel.vPro.with.Intel.Core.Ultra_Final.pdf"
+ text_example_cn = "https://github.com/openvinotoolkit/openvino_notebooks/files/15039713/Platform.Brief_Intel.vPro.with.Intel.Core.Ultra_Final_CH.pdf"
+
+ if not config_dst_path.exists():
+     if config_shared_path.exists():
+         try:
+             os.symlink(config_shared_path, config_dst_path)
+         except Exception:
+             shutil.copy(config_shared_path, config_dst_path)
+     else:
+         r = requests.get(url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/llm_config.py")
+         with open("llm_config.py", "w", encoding="utf-8") as f:
+             f.write(r.text)
+ elif not os.path.islink(config_dst_path):
+     print("LLM config will be updated")
+     if config_shared_path.exists():
+         shutil.copy(config_shared_path, config_dst_path)
+     else:
+         r = requests.get(url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/llm_config.py")
+         with open("llm_config.py", "w", encoding="utf-8") as f:
+             f.write(r.text)
+
+
+ if not text_example_en_path.exists():
+     r = requests.get(url=text_example_en)
+     content = io.BytesIO(r.content)
+     with open("text_example_en.pdf", "wb") as f:
+         f.write(content.read())
+
+ if not text_example_cn_path.exists():
+     r = requests.get(url=text_example_cn)
+     content = io.BytesIO(r.content)
+     with open("text_example_cn.pdf", "wb") as f:
+         f.write(content.read())
+
+ model_language = "English"
+ llm_model_id = "llama-3-8b-instruct"
+ llm_model_configuration = SUPPORTED_LLM_MODELS[model_language][llm_model_id]
+ print(f"Selected LLM model {llm_model_id}")
+ prepare_int4_model = True  # Prepare INT4 model
+ prepare_int8_model = False  # Do not prepare INT8 model
+ prepare_fp16_model = False  # Do not prepare FP16 model
+ enable_awq = False
+ # Get the token from the environment variable
+ hf_token = os.getenv("HUGGINGFACE_TOKEN")
+
+ if hf_token is None:
+     raise ValueError(
+         "HUGGINGFACE_TOKEN environment variable not set. "
+         "Please set it in your environment variables or repository secrets."
+     )
+
+ # Log in to Hugging Face Hub
+ login(token=hf_token)
+ pt_model_id = llm_model_configuration["model_id"]
+ # pt_model_name = llm_model_id.value.split("-")[0]
+ fp16_model_dir = Path(llm_model_id) / "FP16"
+ int8_model_dir = Path(llm_model_id) / "INT8_compressed_weights"
+ int4_model_dir = Path(llm_model_id) / "INT4_compressed_weights"
+
+
+ def convert_to_fp16():
+     if (fp16_model_dir / "openvino_model.xml").exists():
+         return
+     remote_code = llm_model_configuration.get("remote_code", False)
+     export_command_base = "optimum-cli export openvino --model {} --task text-generation-with-past --weight-format fp16".format(pt_model_id)
+     if remote_code:
+         export_command_base += " --trust-remote-code"
+     export_command = export_command_base + " " + str(fp16_model_dir)
+     # run the export command (assumes optimum-cli is installed and on PATH)
+     subprocess.run(export_command.split(), check=True)
+
+
+
+ def convert_to_int8():
+     if (int8_model_dir / "openvino_model.xml").exists():
+         return
+     int8_model_dir.mkdir(parents=True, exist_ok=True)
+     remote_code = llm_model_configuration.get("remote_code", False)
+     export_command_base = "optimum-cli export openvino --model {} --task text-generation-with-past --weight-format int8".format(pt_model_id)
+     if remote_code:
+         export_command_base += " --trust-remote-code"
+     export_command = export_command_base + " " + str(int8_model_dir)
+     # run the export command (assumes optimum-cli is installed and on PATH)
+     subprocess.run(export_command.split(), check=True)
+
+
+
+ def convert_to_int4():
+     compression_configs = {
+         "zephyr-7b-beta": {
+             "sym": True,
+             "group_size": 64,
+             "ratio": 0.6,
+         },
+         "mistral-7b": {
+             "sym": True,
+             "group_size": 64,
+             "ratio": 0.6,
+         },
+         "minicpm-2b-dpo": {
+             "sym": True,
+             "group_size": 64,
+             "ratio": 0.6,
+         },
+         "gemma-2b-it": {
+             "sym": True,
+             "group_size": 64,
+             "ratio": 0.6,
+         },
+         "notus-7b-v1": {
+             "sym": True,
+             "group_size": 64,
+             "ratio": 0.6,
+         },
+         "neural-chat-7b-v3-1": {
+             "sym": True,
+             "group_size": 64,
+             "ratio": 0.6,
+         },
+         "llama-2-chat-7b": {
+             "sym": True,
+             "group_size": 128,
+             "ratio": 0.8,
+         },
+         "llama-3-8b-instruct": {
+             "sym": True,
+             "group_size": 128,
+             "ratio": 0.8,
+         },
+         "gemma-7b-it": {
+             "sym": True,
+             "group_size": 128,
+             "ratio": 0.8,
+         },
+         "chatglm2-6b": {
+             "sym": True,
+             "group_size": 128,
+             "ratio": 0.72,
+         },
+         "qwen-7b-chat": {"sym": True, "group_size": 128, "ratio": 0.6},
+         "red-pajama-3b-chat": {
+             "sym": False,
+             "group_size": 128,
+             "ratio": 0.5,
+         },
+         "default": {
+             "sym": False,
+             "group_size": 128,
+             "ratio": 0.8,
+         },
+     }
+
+     model_compression_params = compression_configs.get(llm_model_id, compression_configs["default"])
+     if (int4_model_dir / "openvino_model.xml").exists():
+         return
+     remote_code = llm_model_configuration.get("remote_code", False)
+     export_command_base = "optimum-cli export openvino --model {} --task text-generation-with-past --weight-format int4".format(pt_model_id)
+     int4_compression_args = " --group-size {} --ratio {}".format(model_compression_params["group_size"], model_compression_params["ratio"])
+     if model_compression_params["sym"]:
+         int4_compression_args += " --sym"
+     if enable_awq:  # enable_awq is a plain bool in this script, so no ".value"
+         int4_compression_args += " --awq --dataset wikitext2 --num-samples 128"
+     export_command_base += int4_compression_args
+     if remote_code:
+         export_command_base += " --trust-remote-code"
+     export_command = export_command_base + " " + str(int4_model_dir)
+     # run the export command (assumes optimum-cli is installed and on PATH)
+     subprocess.run(export_command.split(), check=True)
+
+
+
+ if prepare_fp16_model:
+     convert_to_fp16()
+ if prepare_int8_model:
+     convert_to_int8()
+ if prepare_int4_model:
+     convert_to_int4()
+ fp16_weights = fp16_model_dir / "openvino_model.bin"
+ int8_weights = int8_model_dir / "openvino_model.bin"
+ int4_weights = int4_model_dir / "openvino_model.bin"
+
+ if fp16_weights.exists():
+     print(f"Size of FP16 model is {fp16_weights.stat().st_size / 1024 / 1024:.2f} MB")
+ for precision, compressed_weights in zip([8, 4], [int8_weights, int4_weights]):
+     if compressed_weights.exists():
+         print(f"Size of model with INT{precision} compressed weights is {compressed_weights.stat().st_size / 1024 / 1024:.2f} MB")
+     if compressed_weights.exists() and fp16_weights.exists():
+         print(f"Compression rate for INT{precision} model: {fp16_weights.stat().st_size / compressed_weights.stat().st_size:.3f}")
+ embedding_model_id = "bge-small-en-v1.5"  # other options: 'bge-large-en-v1.5', 'bge-m3'
+ embedding_model_configuration = SUPPORTED_EMBEDDING_MODELS[model_language][embedding_model_id]
+ print(f"Selected {embedding_model_id} model")
+ export_command_base = "optimum-cli export openvino --model {} --task feature-extraction".format(embedding_model_configuration["model_id"])
+ export_command = export_command_base + " " + str(embedding_model_id)
+ if not Path(embedding_model_id).exists():
+     # export the embedding model once (assumes optimum-cli is installed and on PATH)
+     subprocess.run(export_command.split(), check=True)
+ rerank_model_id = "bge-reranker-v2-m3"  # other options: 'bge-reranker-large', 'bge-reranker-base'
+ rerank_model_configuration = SUPPORTED_RERANK_MODELS[rerank_model_id]
+ print(f"Selected {rerank_model_id} model")
+ export_command_base = "optimum-cli export openvino --model {} --task text-classification".format(rerank_model_configuration["model_id"])
+ export_command = export_command_base + " " + str(rerank_model_id)
+ if not Path(rerank_model_id).exists():
+     # export the reranker model once (assumes optimum-cli is installed and on PATH)
+     subprocess.run(export_command.split(), check=True)
+ core = ov.Core()  # used below to query available devices
+ embedding_device = "CPU"
+ USING_NPU = embedding_device == "NPU"
+
+ npu_embedding_dir = embedding_model_id + "-npu"
+ npu_embedding_path = Path(npu_embedding_dir) / "openvino_model.xml"
+ if USING_NPU and not Path(npu_embedding_dir).exists():
+     r = requests.get(
+         url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/notebook_utils.py",
+     )
+     with open("notebook_utils.py", "w") as f:
+         f.write(r.text)
+     import notebook_utils as utils
+
+     shutil.copytree(embedding_model_id, npu_embedding_dir)
+     utils.optimize_bge_embedding(Path(embedding_model_id) / "openvino_model.xml", npu_embedding_path)
+ rerank_device = "CPU"
+ llm_device = "CPU"
+ from langchain_community.embeddings import OpenVINOBgeEmbeddings
+
+ embedding_model_name = npu_embedding_dir if USING_NPU else embedding_model_id
+ batch_size = 1 if USING_NPU else 4
+ embedding_model_kwargs = {"device": embedding_device, "compile": False}
+ encode_kwargs = {
+     "mean_pooling": embedding_model_configuration["mean_pooling"],
+     "normalize_embeddings": embedding_model_configuration["normalize_embeddings"],
+     "batch_size": batch_size,
+ }
+
+ embedding = OpenVINOBgeEmbeddings(
+     model_name_or_path=embedding_model_name,
+     model_kwargs=embedding_model_kwargs,
+     encode_kwargs=encode_kwargs,
+ )
+ if USING_NPU:
+     embedding.ov_model.reshape(1, 512)
+     embedding.ov_model.compile()
+
+ text = "This is a test document."
+ embedding_result = embedding.embed_query(text)
+ print(embedding_result[:3])  # quick sanity check of the embedding model
+ from langchain_community.document_compressors.openvino_rerank import OpenVINOReranker
+
+ rerank_model_name = rerank_model_id
+ rerank_model_kwargs = {"device": rerank_device}
+ rerank_top_n = 2
+
+ reranker = OpenVINOReranker(
+     model_name_or_path=rerank_model_name,
+     model_kwargs=rerank_model_kwargs,
+     top_n=rerank_top_n,
+ )
+ model_to_run = "INT4"
+ from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
+
+ if model_to_run == "INT4":
+     model_dir = int4_model_dir
+ elif model_to_run == "INT8":
+     model_dir = int8_model_dir
+ else:
+     model_dir = fp16_model_dir
+ print(f"Loading model from {model_dir}")
+
+ ov_config = {"PERFORMANCE_HINT": "LATENCY", "NUM_STREAMS": "1", "CACHE_DIR": ""}
+
+ if "GPU" in llm_device and "qwen2-7b-instruct" in llm_model_id:
+     ov_config["GPU_ENABLE_SDPA_OPTIMIZATION"] = "NO"
+
+ # On a GPU device a model is executed in FP16 precision. For the red-pajama-3b-chat model there are known accuracy
+ # issues caused by this, which we avoid by setting the precision hint to "f32".
+ if llm_model_id == "red-pajama-3b-chat" and "GPU" in core.available_devices and llm_device in ["GPU", "AUTO"]:
+     ov_config["INFERENCE_PRECISION_HINT"] = "f32"
+
+ llm = HuggingFacePipeline.from_model_id(
+     model_id=str(model_dir),
+     task="text-generation",
+     backend="openvino",
+     model_kwargs={
+         "device": llm_device,
+         "ov_config": ov_config,
+         "trust_remote_code": True,
+     },
+     pipeline_kwargs={"max_new_tokens": 2},
+ )
+
+ llm.invoke("2 + 2 =")
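+ # Note: the call above is only a quick smoke test of the compiled pipeline; per-request
+ # generation settings (max_new_tokens=512, sampling parameters, streamer) are set in bot() below.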
+ import re
+ from typing import List
+ from langchain.text_splitter import (
+     CharacterTextSplitter,
+     RecursiveCharacterTextSplitter,
+     MarkdownTextSplitter,
+ )
+ from langchain.document_loaders import (
+     CSVLoader,
+     EverNoteLoader,
+     PyPDFLoader,
+     TextLoader,
+     UnstructuredEPubLoader,
+     UnstructuredHTMLLoader,
+     UnstructuredMarkdownLoader,
+     UnstructuredODTLoader,
+     UnstructuredPowerPointLoader,
+     UnstructuredWordDocumentLoader,
+ )
+
+
+ class ChineseTextSplitter(CharacterTextSplitter):
+     def __init__(self, pdf: bool = False, **kwargs):
+         super().__init__(**kwargs)
+         self.pdf = pdf
+
+     def split_text(self, text: str) -> List[str]:
+         if self.pdf:
+             text = re.sub(r"\n{3,}", "\n", text)
+             text = text.replace("\n\n", "")
+         sent_sep_pattern = re.compile('([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))')
+         sent_list = []
+         for ele in sent_sep_pattern.split(text):
+             if sent_sep_pattern.match(ele) and sent_list:
+                 sent_list[-1] += ele
+             elif ele:
+                 sent_list.append(ele)
+         return sent_list
+
+
+ TEXT_SPLITERS = {
+     "Character": CharacterTextSplitter,
+     "RecursiveCharacter": RecursiveCharacterTextSplitter,
+     "Markdown": MarkdownTextSplitter,
+     "Chinese": ChineseTextSplitter,
+ }
+
+
+ LOADERS = {
+     ".csv": (CSVLoader, {}),
+     ".doc": (UnstructuredWordDocumentLoader, {}),
+     ".docx": (UnstructuredWordDocumentLoader, {}),
+     ".enex": (EverNoteLoader, {}),
+     ".epub": (UnstructuredEPubLoader, {}),
+     ".html": (UnstructuredHTMLLoader, {}),
+     ".md": (UnstructuredMarkdownLoader, {}),
+     ".odt": (UnstructuredODTLoader, {}),
+     ".pdf": (PyPDFLoader, {}),
+     ".ppt": (UnstructuredPowerPointLoader, {}),
+     ".pptx": (UnstructuredPowerPointLoader, {}),
+     ".txt": (TextLoader, {"encoding": "utf8"}),
+ }
+
+ chinese_examples = [
+     ["英特尔®酷睿™ Ultra处理器可以降低多少功耗?"],
+     ["相比英特尔之前的移动处理器产品,英特尔®酷睿™ Ultra处理器的AI推理性能提升了多少?"],
+     ["英特尔博锐® Enterprise系统提供哪些功能?"],
+ ]
+
+ english_examples = [
+     ["How much power consumption can Intel® Core™ Ultra Processors help save?"],
+     ["Compared to Intel’s previous mobile processor, what is the advantage of Intel® Core™ Ultra Processors for Artificial Intelligence?"],
+     ["What can Intel vPro® Enterprise systems offer?"],
+ ]
+
+ if model_language == "English":
+     # text_example_path = "text_example_en.pdf"
+     text_example_path = [
+         "Supervisors-Guide-Accurate-Timekeeping_AH edits.docx",
+         "Salary-vs-Hourly-Guide_AH edits.docx",
+         "Employee-Guide-Accurate-Timekeeping_AH edits.docx",
+         "Eller Overtime Guidelines.docx",
+         "Eller FLSA information 9.2024_AH edits.docx",
+         "Accurate Timekeeping Supervisors 12.2.20_AH edits.docx",
+     ]
+ else:
+     text_example_path = "text_example_cn.pdf"
+
+ examples = chinese_examples if (model_language == "Chinese") else english_examples
+ from langchain.prompts import PromptTemplate
+ from langchain_community.vectorstores import FAISS
+ from langchain.chains.retrieval import create_retrieval_chain
+ from langchain.chains.combine_documents import create_stuff_documents_chain
+ from langchain.docstore.document import Document
+ from langchain.retrievers import ContextualCompressionRetriever
+ from threading import Thread
+ import gradio as gr
+
+ stop_tokens = llm_model_configuration.get("stop_tokens")
+ rag_prompt_template = llm_model_configuration["rag_prompt_template"]
+
+
+ class StopOnTokens(StoppingCriteria):
+     def __init__(self, token_ids):
+         self.token_ids = token_ids
+
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         for stop_id in self.token_ids:
+             if input_ids[0][-1] == stop_id:
+                 return True
+         return False
+
+
+ if stop_tokens is not None:
+     if isinstance(stop_tokens[0], str):
+         stop_tokens = llm.pipeline.tokenizer.convert_tokens_to_ids(stop_tokens)
+
+     stop_tokens = [StopOnTokens(stop_tokens)]
+
+
+ def load_single_document(file_path: str) -> List[Document]:
+     """
+     Helper for loading a single document.
+
+     Params:
+       file_path: document path
+     Returns:
+       documents loaded
+
+     """
+     ext = "." + file_path.rsplit(".", 1)[-1]
+     if ext in LOADERS:
+         loader_class, loader_args = LOADERS[ext]
+         loader = loader_class(file_path, **loader_args)
+         return loader.load()
+
+     raise ValueError(f"Unsupported file extension '{ext}'")
+
+
+ def default_partial_text_processor(partial_text: str, new_text: str):
+     """
+     Helper for updating the partially generated answer, used by default.
+
+     Params:
+       partial_text: text buffer for storing previously generated text
+       new_text: text update for the current step
+     Returns:
+       updated text string
+
+     """
+     partial_text += new_text
+     return partial_text
+
+
+ text_processor = llm_model_configuration.get("partial_text_processor", default_partial_text_processor)
+
+
+ def create_vectordb(
+     docs, spliter_name, chunk_size, chunk_overlap, vector_search_top_k, vector_rerank_top_n, run_rerank, search_method, score_threshold, progress=gr.Progress()
+ ):
+     """
+     Initialize a vector database
+
+     Params:
+       docs: original documents provided by the user
+       spliter_name: splitter method
+       chunk_size: size of a single sentence chunk
+       chunk_overlap: overlap size between 2 chunks
+       vector_search_top_k: Vector search top k
+       vector_rerank_top_n: Search rerank top n
+       run_rerank: whether to run the reranker
+       search_method: top k search method
+       score_threshold: score threshold when selecting the 'similarity_score_threshold' method
+
+     """
+     global db
+     global retriever
+     global combine_docs_chain
+     global rag_chain
+
+     if vector_rerank_top_n > vector_search_top_k:
+         gr.Warning("Search top k must >= Rerank top n")
+
+     documents = []
+     for doc in docs:
+         if type(doc) is not str:
+             doc = doc.name
+         documents.extend(load_single_document(doc))
+
+     text_splitter = TEXT_SPLITERS[spliter_name](chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+
+     texts = text_splitter.split_documents(documents)
+     db = FAISS.from_documents(texts, embedding)
+     if search_method == "similarity_score_threshold":
+         search_kwargs = {"k": vector_search_top_k, "score_threshold": score_threshold}
+     else:
+         search_kwargs = {"k": vector_search_top_k}
+     retriever = db.as_retriever(search_kwargs=search_kwargs, search_type=search_method)
+     if run_rerank:
+         reranker.top_n = vector_rerank_top_n
+         retriever = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=retriever)
+     prompt = PromptTemplate.from_template(rag_prompt_template)
+     combine_docs_chain = create_stuff_documents_chain(llm, prompt)
+
+     rag_chain = create_retrieval_chain(retriever, combine_docs_chain)
+
+     return "Vector database is Ready"
+
+
+ def update_retriever(vector_search_top_k, vector_rerank_top_n, run_rerank, search_method, score_threshold):
+     """
+     Update retriever
+
+     Params:
+       vector_search_top_k: Vector search top k
+       vector_rerank_top_n: Search rerank top n
+       run_rerank: whether to run the reranker
+       search_method: top k search method
+       score_threshold: score threshold when selecting the 'similarity_score_threshold' method
+
+     """
+     global db
+     global retriever
+     global combine_docs_chain
+     global rag_chain
+
+     if vector_rerank_top_n > vector_search_top_k:
+         gr.Warning("Search top k must >= Rerank top n")
+
+     if search_method == "similarity_score_threshold":
+         search_kwargs = {"k": vector_search_top_k, "score_threshold": score_threshold}
+     else:
+         search_kwargs = {"k": vector_search_top_k}
+     retriever = db.as_retriever(search_kwargs=search_kwargs, search_type=search_method)
+     if run_rerank:
+         retriever = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=retriever)
+         reranker.top_n = vector_rerank_top_n
+     rag_chain = create_retrieval_chain(retriever, combine_docs_chain)
+
+     return "Vector database is Ready"
+
+
+ def user(message, history):
+     """
+     Callback function for updating user messages in the interface on submit button click
+
+     Params:
+       message: current message
+       history: conversation history
+     Returns:
+       cleared message box and updated conversation history
+     """
+     # Append the user's message to the conversation history
+     return "", history + [[message, ""]]
+
+
+ def bot(history, temperature, top_p, top_k, repetition_penalty, hide_full_prompt, do_rag):
+     """
+     Callback function for running the chatbot on submit button click
+
+     Params:
+       history: conversation history
+       temperature: parameter for controlling the level of creativity in AI-generated text.
+                    By adjusting `temperature`, you can influence the AI model's probability distribution, making the text more focused or diverse.
+       top_p: parameter for controlling the range of tokens considered by the AI model based on their cumulative probability.
+       top_k: parameter for controlling the number of highest-probability tokens considered by the AI model.
+       repetition_penalty: parameter for penalizing tokens based on how frequently they occur in the text.
+       hide_full_prompt: whether to hide the search results in the prompt.
+       do_rag: whether to do RAG when generating text.
+
+     """
+     streamer = TextIteratorStreamer(
+         llm.pipeline.tokenizer,
+         timeout=60.0,
+         skip_prompt=hide_full_prompt,
+         skip_special_tokens=True,
+     )
+     llm.pipeline._forward_params = dict(
+         max_new_tokens=512,
+         temperature=temperature,
+         do_sample=temperature > 0.0,
+         top_p=top_p,
+         top_k=top_k,
+         repetition_penalty=repetition_penalty,
+         streamer=streamer,
+     )
+     if stop_tokens is not None:
+         llm.pipeline._forward_params["stopping_criteria"] = StoppingCriteriaList(stop_tokens)
+
+     if do_rag:
+         t1 = Thread(target=rag_chain.invoke, args=({"input": history[-1][0]},))
+     else:
+         input_text = rag_prompt_template.format(input=history[-1][0], context="")
+         t1 = Thread(target=llm.invoke, args=(input_text,))
+     t1.start()
+
+     # Initialize an empty string to store the generated text
+     partial_text = ""
+     for new_text in streamer:
+         partial_text = text_processor(partial_text, new_text)
+         history[-1][1] = partial_text
+         yield history
+
+
+ def request_cancel():
+     llm.pipeline.model.request.cancel()
+
+
+ def clear_files():
+     return "Vector Store is Not ready"
+
+
+ # initialize the vector store with the example documents
+ create_vectordb(
+     text_example_path,  # changed
+     "RecursiveCharacter",
+     chunk_size=400,
+     chunk_overlap=50,
+     vector_search_top_k=10,
+     vector_rerank_top_n=2,
+     run_rerank=True,
+     search_method="similarity_score_threshold",
+     score_threshold=0.5,
+ )
+ with gr.Blocks(
+     theme=gr.themes.Soft(),
+     css=".disclaimer {font-variant-caps: all-small-caps;}",
+ ) as demo:
+     gr.Markdown("""<h1><center>QA over Document</center></h1>""")
+     gr.Markdown(f"""<center>Powered by OpenVINO and {llm_model_id} </center>""")
+     with gr.Row():
+         with gr.Column(scale=1):
+             docs = gr.File(
+                 label="Step 1: Load text files",
+                 value=text_example_path,  # changed
+                 file_count="multiple",
+                 file_types=[
+                     ".csv",
+                     ".doc",
+                     ".docx",
+                     ".enex",
+                     ".epub",
+                     ".html",
+                     ".md",
+                     ".odt",
+                     ".pdf",
+                     ".ppt",
+                     ".pptx",
+                     ".txt",
+                 ],
+             )
+             load_docs = gr.Button("Step 2: Build Vector Store", variant="primary")
+             db_argument = gr.Accordion("Vector Store Configuration", open=False)
+             with db_argument:
+                 spliter = gr.Dropdown(
+                     ["Character", "RecursiveCharacter", "Markdown", "Chinese"],
+                     value="RecursiveCharacter",
+                     label="Text Splitter",
+                     info="Method used to split the documents",
+                     multiselect=False,
+                 )
+
+                 chunk_size = gr.Slider(
+                     label="Chunk size",
+                     value=400,
+                     minimum=50,
+                     maximum=2000,
+                     step=50,
+                     interactive=True,
+                     info="Size of sentence chunk",
+                 )
+
+                 chunk_overlap = gr.Slider(
+                     label="Chunk overlap",
+                     value=50,
+                     minimum=0,
+                     maximum=400,
+                     step=10,
+                     interactive=True,
+                     info=("Overlap between 2 chunks"),
+                 )
+
+             langchain_status = gr.Textbox(
+                 label="Vector Store Status",
+                 value="Vector Store is Ready",
+                 interactive=False,
+             )
+             do_rag = gr.Checkbox(
+                 value=True,
+                 label="RAG is ON",
+                 interactive=True,
+                 info="Whether to do RAG for generation",
+             )
+             with gr.Accordion("Generation Configuration", open=False):
+                 with gr.Row():
+                     with gr.Column():
+                         with gr.Row():
+                             temperature = gr.Slider(
+                                 label="Temperature",
+                                 value=0.1,
+                                 minimum=0.0,
+                                 maximum=1.0,
+                                 step=0.1,
+                                 interactive=True,
+                                 info="Higher values produce more diverse outputs",
+                             )
+                     with gr.Column():
+                         with gr.Row():
+                             top_p = gr.Slider(
+                                 label="Top-p (nucleus sampling)",
+                                 value=1.0,
+                                 minimum=0.0,
+                                 maximum=1,
+                                 step=0.01,
+                                 interactive=True,
+                                 info=(
+                                     "Sample from the smallest possible set of tokens whose cumulative probability "
+                                     "exceeds top_p. Set to 1 to disable and sample from all tokens."
+                                 ),
+                             )
+                     with gr.Column():
+                         with gr.Row():
+                             top_k = gr.Slider(
+                                 label="Top-k",
+                                 value=50,
+                                 minimum=0.0,
+                                 maximum=200,
+                                 step=1,
+                                 interactive=True,
+                                 info="Sample from a shortlist of top-k tokens — 0 to disable and sample from all tokens.",
+                             )
+                     with gr.Column():
+                         with gr.Row():
+                             repetition_penalty = gr.Slider(
+                                 label="Repetition Penalty",
+                                 value=1.1,
+                                 minimum=1.0,
+                                 maximum=2.0,
+                                 step=0.1,
+                                 interactive=True,
+                                 info="Penalize repetition — 1.0 to disable.",
+                             )
+         with gr.Column(scale=4):
+             chatbot = gr.Chatbot(
+                 height=800,
+                 label="Step 3: Input Query",
+             )
+             with gr.Row():
+                 with gr.Column():
+                     with gr.Row():
+                         msg = gr.Textbox(
+                             label="QA Message Box",
+                             placeholder="Chat Message Box",
+                             show_label=False,
+                             container=False,
+                         )
+                 with gr.Column():
+                     with gr.Row():
+                         submit = gr.Button("Submit", variant="primary")
+                         stop = gr.Button("Stop")
+                         clear = gr.Button("Clear")
+             gr.Examples(examples, inputs=msg, label="Click on any example and press the 'Submit' button")
+             retriever_argument = gr.Accordion("Retriever Configuration", open=True)
+             with retriever_argument:
+                 with gr.Row():
+                     with gr.Row():
+                         do_rerank = gr.Checkbox(
+                             value=True,
+                             label="Rerank search results",
+                             interactive=True,
+                         )
+                         hide_context = gr.Checkbox(
+                             value=True,
+                             label="Hide search results in prompt",
+                             interactive=True,
+                         )
+                     with gr.Row():
+                         search_method = gr.Dropdown(
+                             ["similarity_score_threshold", "similarity", "mmr"],
+                             value="similarity_score_threshold",
+                             label="Searching Method",
+                             info="Method used to search the vector store",
+                             multiselect=False,
+                             interactive=True,
+                         )
+                     with gr.Row():
+                         score_threshold = gr.Slider(
+                             0.01,
+                             0.99,
+                             value=0.5,
+                             step=0.01,
+                             label="Similarity Threshold",
+                             info="Only used by the 'similarity_score_threshold' method",
+                             interactive=True,
+                         )
+                     with gr.Row():
+                         vector_rerank_top_n = gr.Slider(
+                             1,
+                             10,
+                             value=2,
+                             step=1,
+                             label="Rerank top n",
+                             info="Number of rerank results",
+                             interactive=True,
+                         )
+                     with gr.Row():
+                         vector_search_top_k = gr.Slider(
+                             1,
+                             50,
+                             value=10,
+                             step=1,
+                             label="Search top k",
+                             info="Search top k must >= Rerank top n",
+                             interactive=True,
+                         )
+     docs.clear(clear_files, outputs=[langchain_status], queue=False)
+     load_docs.click(
+         create_vectordb,
+         inputs=[docs, spliter, chunk_size, chunk_overlap, vector_search_top_k, vector_rerank_top_n, do_rerank, search_method, score_threshold],
+         outputs=[langchain_status],
+         queue=False,
+     )
+     submit_event = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+         bot,
+         [chatbot, temperature, top_p, top_k, repetition_penalty, hide_context, do_rag],
+         chatbot,
+         queue=True,
+     )
+     submit_click_event = submit.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+         bot,
+         [chatbot, temperature, top_p, top_k, repetition_penalty, hide_context, do_rag],
+         chatbot,
+         queue=True,
+     )
+     stop.click(
+         fn=request_cancel,
+         inputs=None,
+         outputs=None,
+         cancels=[submit_event, submit_click_event],
+         queue=False,
+     )
+     clear.click(lambda: None, None, chatbot, queue=False)
+     vector_search_top_k.release(
+         update_retriever,
+         [vector_search_top_k, vector_rerank_top_n, do_rerank, search_method, score_threshold],
+         outputs=[langchain_status],
+     )
+     vector_rerank_top_n.release(
+         update_retriever,
+         inputs=[vector_search_top_k, vector_rerank_top_n, do_rerank, search_method, score_threshold],
+         outputs=[langchain_status],
+     )
+     do_rerank.change(
+         update_retriever,
+         inputs=[vector_search_top_k, vector_rerank_top_n, do_rerank, search_method, score_threshold],
+         outputs=[langchain_status],
+     )
+     search_method.change(
+         update_retriever,
+         inputs=[vector_search_top_k, vector_rerank_top_n, do_rerank, search_method, score_threshold],
+         outputs=[langchain_status],
+     )
+     score_threshold.change(
+         update_retriever,
+         inputs=[vector_search_top_k, vector_rerank_top_n, do_rerank, search_method, score_threshold],
+         outputs=[langchain_status],
+     )
+
+
+ demo.queue()
+ # if you are launching remotely, specify server_name and server_port
+ # demo.launch(server_port=8082)
+ # if you have any issues launching on your platform, you can pass share=True to the launch method:
+ demo.launch(share=True)
+ # it creates a publicly shareable link for the interface. Read more in the docs: https://gradio.app/docs/
+ # demo.launch()
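+ # For example, a remote/LAN deployment might use explicit host and port values (illustrative only):
+ # demo.launch(server_name="0.0.0.0", server_port=8082)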