nxphi47 commited on
Commit
a832036
·
1 Parent(s): f5d291f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +423 -54
app.py CHANGED
@@ -93,15 +93,24 @@ ENABLE_AGREE_POPUP = bool(int(os.environ.get("ENABLE_AGREE_POPUP", "0")))
93
  MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "2048"))
94
  TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.1"))
95
  FREQUENCE_PENALTY = float(os.environ.get("FREQUENCE_PENALTY", "0.4"))
 
96
  gpu_memory_utilization = float(os.environ.get("gpu_memory_utilization", "0.9"))
97
 
98
  # whether to enable quantization, currently not in use
99
  QUANTIZATION = str(os.environ.get("QUANTIZATION", ""))
100
 
 
 
 
 
 
 
 
 
 
101
  DATA_SET_REPO_PATH = str(os.environ.get("DATA_SET_REPO_PATH", ""))
102
  DATA_SET_REPO = None
103
 
104
-
105
  """
106
  Internal instructions of how to configure the DEMO
107
 
@@ -196,6 +205,32 @@ MODEL_TITLE = """
196
  </div>
197
  """
198
  # <a href=''><img src='https://img.shields.io/badge/Paper-PDF-red'></a>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  MODEL_DESC = """
200
  <div style='display:flex; gap: 0.25rem; '>
201
  <a href='https://github.com/SeaLLMs/SeaLLMs'><img src='https://img.shields.io/badge/Github-Code-success'></a>
@@ -207,20 +242,13 @@ This is <a href="https://huggingface.co/SeaLLMs/SeaLLM-Chat-13b" target="_blank"
207
  Explore <a href="https://huggingface.co/SeaLLMs/SeaLLM-Chat-13b" target="_blank">our article</a> for more details.
208
  </span>
209
  <br>
210
- <span >
211
- NOTE: The chatbot may produce inaccurate and harmful information about people, places, or facts.
212
- <span style="color: red">By using our service, you are required to agree to our <a href="https://huggingface.co/SeaLLMs/SeaLLM-Chat-13b/blob/main/LICENSE" target="_blank" style="color: red">SeaLLM Terms Of Use</a>, which include:</span><br>
213
- <ul>
214
- <li >
215
- You must not use our service to generate any harmful, unethical or illegal content that violates locally applicable and international laws or regulations,
216
- including but not limited to hate speech, violence, pornography and deception.</li>
217
- <li >
218
  The service collects user dialogue data for testing and performance improvement, and reserves the right to distribute it under
219
- <a href="https://creativecommons.org/licenses/by/4.0/">Creative Commons Attribution (CC-BY)</a> or similar license. So do not enter any personal information!
220
- </li>
221
- </ul>
222
  </span>
223
-
224
  """.strip()
225
 
226
 
@@ -709,6 +737,7 @@ def llama_chat_multiturn_sys_input_seq_constructor(
709
  sys_prompt=SYSTEM_PROMPT_1,
710
  bos_token=BOS_TOKEN,
711
  eos_token=EOS_TOKEN,
 
712
  ):
713
  """
714
  ```
@@ -718,18 +747,19 @@ def llama_chat_multiturn_sys_input_seq_constructor(
718
  ```
719
  """
720
  text = ''
 
721
  for i, (prompt, res) in enumerate(history):
722
  if i == 0:
723
- text += f"{bos_token}{B_INST} {B_SYS} {sys_prompt} {E_SYS} {prompt} {E_INST}"
724
  else:
725
- text += f"{bos_token}{B_INST} {prompt} {E_INST}"
726
 
727
  if res is not None:
728
  text += f" {res} {eos_token} "
729
  if len(history) == 0 or text.strip() == '':
730
- text = f"{bos_token}{B_INST} {B_SYS} {sys_prompt} {E_SYS} {message} {E_INST}"
731
  else:
732
- text += f"{bos_token}{B_INST} {message} {E_INST}"
733
  return text
734
 
735
 
@@ -944,6 +974,10 @@ gr.ChatInterface._setup_events = _setup_events
944
 
945
 
946
  def vllm_abort(self: Any):
 
 
 
 
947
  from vllm.sequence import SequenceStatus
948
  scheduler = self.llm_engine.scheduler
949
  for state_queue in [scheduler.waiting, scheduler.running, scheduler.swapped]:
@@ -1093,6 +1127,7 @@ def chat_response_stream_multiturn(
1093
  temperature: float,
1094
  max_tokens: int,
1095
  frequency_penalty: float,
 
1096
  current_time: Optional[float] = None,
1097
  system_prompt: Optional[str] = SYSTEM_PROMPT_1
1098
  ) -> str:
@@ -1144,6 +1179,7 @@ def chat_response_stream_multiturn(
1144
  temperature=temperature,
1145
  max_tokens=max_tokens,
1146
  frequency_penalty=frequency_penalty,
 
1147
  stop=['<s>', '</s>', '<<SYS>>', '<</SYS>>', '[INST]', '[/INST]']
1148
  )
1149
  cur_out = None
@@ -1163,6 +1199,9 @@ def chat_response_stream_multiturn(
1163
  assert len(gen) == 1, f'{gen}'
1164
  item = next(iter(gen.values()))
1165
  cur_out = item.outputs[0].text
 
 
 
1166
 
1167
  # TODO: use current_time to register conversations, accoriding history and cur_out
1168
  history_str = format_conversation(history + [[message, cur_out]])
@@ -1236,7 +1275,7 @@ def maybe_upload_to_dataset():
1236
  )
1237
  except Exception as e:
1238
  print(f'Failed to save to repo: {DATA_SET_REPO_PATH}|{str(e)}')
1239
-
1240
 
1241
  def print_log_file():
1242
  global LOG_FILE, LOG_PATH
@@ -1262,6 +1301,7 @@ def debug_chat_response_echo(
1262
  temperature: float = 0.0,
1263
  max_tokens: int = 4096,
1264
  frequency_penalty: float = 0.4,
 
1265
  current_time: Optional[float] = None,
1266
  system_prompt: str = SYSTEM_PROMPT_1,
1267
  ) -> str:
@@ -1316,6 +1356,256 @@ async () => {
1316
  }
1317
  """
1318
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1319
  def launch():
1320
  global demo, llm, DEBUG, LOG_FILE
1321
  model_desc = MODEL_DESC
@@ -1329,6 +1619,7 @@ def launch():
1329
  max_tokens = MAX_TOKENS
1330
  temperature = TEMPERATURE
1331
  frequence_penalty = FREQUENCE_PENALTY
 
1332
  ckpt_info = "None"
1333
 
1334
  print(
@@ -1344,6 +1635,7 @@ def launch():
1344
  f'\n| DISPLAY_MODEL_PATH={DISPLAY_MODEL_PATH} '
1345
  f'\n| LANG_BLOCK_HISTORY={LANG_BLOCK_HISTORY} '
1346
  f'\n| frequence_penalty={frequence_penalty} '
 
1347
  f'\n| temperature={temperature} '
1348
  f'\n| hf_model_name={hf_model_name} '
1349
  f'\n| model_path={model_path} '
@@ -1409,44 +1701,120 @@ def launch():
1409
  if SAVE_LOGS:
1410
  LOG_FILE = open(LOG_PATH, 'a', encoding='utf-8')
1411
 
1412
- demo = gr.ChatInterface(
1413
- response_fn,
1414
- chatbot=ChatBot(
1415
- label=MODEL_NAME,
1416
- bubble_full_width=False,
1417
- latex_delimiters=[
1418
- { "left": "$", "right": "$", "display": False},
1419
- { "left": "$$", "right": "$$", "display": True},
 
 
 
 
 
1420
  ],
1421
- show_copy_button=True,
1422
- ),
1423
- textbox=gr.Textbox(placeholder='Type message', lines=8, max_lines=128, min_width=200),
1424
- submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
1425
- # ! consider preventing the stop button
1426
- stop_btn=None,
1427
- title=f"{model_title}",
1428
- description=f"{model_desc}",
1429
- additional_inputs=[
1430
- gr.Number(value=temperature, label='Temperature (higher -> more random)'),
1431
- gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
1432
- gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens)'),
1433
- gr.Number(value=0, label='current_time', visible=False),
1434
- # ! Remove the system prompt textbox to avoid jailbreaking
1435
- # gr.Textbox(value=sys_prompt, label='System prompt', lines=8)
1436
- ],
1437
- )
1438
- demo.title = MODEL_NAME
1439
- with demo:
1440
- gr.Markdown(cite_markdown)
1441
- if DISPLAY_MODEL_PATH:
1442
- gr.Markdown(path_markdown.format(model_path=model_path))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1443
 
1444
- if ENABLE_AGREE_POPUP:
1445
- demo.load(None, None, None, _js=AGREE_POP_SCRIPTS)
1446
-
1447
 
1448
- demo.queue()
1449
- demo.launch(server_port=PORT)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1450
 
1451
 
1452
  def main():
@@ -1455,4 +1823,5 @@ def main():
1455
 
1456
 
1457
  if __name__ == "__main__":
1458
- main()
 
 
93
  MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "2048"))
94
  TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.1"))
95
  FREQUENCE_PENALTY = float(os.environ.get("FREQUENCE_PENALTY", "0.4"))
96
+ PRESENCE_PENALTY = float(os.environ.get("PRESENCE_PENALTY", "0.0"))
97
  gpu_memory_utilization = float(os.environ.get("gpu_memory_utilization", "0.9"))
98
 
99
  # whether to enable quantization, currently not in use
100
  QUANTIZATION = str(os.environ.get("QUANTIZATION", ""))
101
 
102
+
103
+ # Batch inference file upload
104
+ ENABLE_BATCH_INFER = bool(int(os.environ.get("ENABLE_BATCH_INFER", "1")))
105
+ BATCH_INFER_MAX_ITEMS = int(os.environ.get("BATCH_INFER_MAX_ITEMS", "200"))
106
+ BATCH_INFER_MAX_FILE_SIZE = int(os.environ.get("BATCH_INFER_MAX_FILE_SIZE", "500"))
107
+ BATCH_INFER_MAX_PROMPT_TOKENS = int(os.environ.get("BATCH_INFER_MAX_PROMPT_TOKENS", "4000"))
108
+ BATCH_INFER_SAVE_TMP_FILE = os.environ.get("BATCH_INFER_SAVE_TMP_FILE", "./tmp/pred.json")
109
+
110
+ #
111
  DATA_SET_REPO_PATH = str(os.environ.get("DATA_SET_REPO_PATH", ""))
112
  DATA_SET_REPO = None
113
 
 
114
  """
115
  Internal instructions of how to configure the DEMO
116
 
 
205
  </div>
206
  """
207
  # <a href=''><img src='https://img.shields.io/badge/Paper-PDF-red'></a>
208
+ # MODEL_DESC = """
209
+ # <div style='display:flex; gap: 0.25rem; '>
210
+ # <a href='https://github.com/SeaLLMs/SeaLLMs'><img src='https://img.shields.io/badge/Github-Code-success'></a>
211
+ # <a href='https://huggingface.co/spaces/SeaLLMs/SeaLLM-Chat-13b'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
212
+ # <a href='https://huggingface.co/SeaLLMs/SeaLLM-Chat-13b'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a>
213
+ # </div>
214
+ # <span style="font-size: larger">
215
+ # This is <a href="https://huggingface.co/SeaLLMs/SeaLLM-Chat-13b" target="_blank">SeaLLM-13B-Chat</a> - a chatbot assistant optimized for Southeast Asian Languages. It produces helpful responses in English 🇬🇧, Vietnamese 🇻🇳, Indonesian 🇮🇩 and Thai 🇹🇭.
216
+ # Explore <a href="https://huggingface.co/SeaLLMs/SeaLLM-Chat-13b" target="_blank">our article</a> for more details.
217
+ # </span>
218
+ # <br>
219
+ # <span >
220
+ # NOTE: The chatbot may produce inaccurate and harmful information about people, places, or facts.
221
+ # <span style="color: red">By using our service, you are required to agree to our <a href="https://huggingface.co/SeaLLMs/SeaLLM-Chat-13b/blob/main/LICENSE" target="_blank" style="color: red">SeaLLM Terms Of Use</a>, which include:</span><br>
222
+ # <ul>
223
+ # <li >
224
+ # You must not use our service to generate any harmful, unethical or illegal content that violates locally applicable and international laws or regulations,
225
+ # including but not limited to hate speech, violence, pornography and deception.</li>
226
+ # <li >
227
+ # The service collects user dialogue data for testing and performance improvement, and reserves the right to distribute it under
228
+ # <a href="https://creativecommons.org/licenses/by/4.0/">Creative Commons Attribution (CC-BY)</a> or similar license. So do not enter any personal information!
229
+ # </li>
230
+ # </ul>
231
+ # </span>
232
+ # """.strip()
233
+
234
  MODEL_DESC = """
235
  <div style='display:flex; gap: 0.25rem; '>
236
  <a href='https://github.com/SeaLLMs/SeaLLMs'><img src='https://img.shields.io/badge/Github-Code-success'></a>
 
242
  Explore <a href="https://huggingface.co/SeaLLMs/SeaLLM-Chat-13b" target="_blank">our article</a> for more details.
243
  </span>
244
  <br>
245
+ <span>
246
+ <span style="color: red">NOTE:</span> The chatbot may produce inaccurate and harmful information.
247
+ By using our service, you are required to <span style="color: red">agree to our <a href="https://huggingface.co/SeaLLMs/SeaLLM-Chat-13b/blob/main/LICENSE" target="_blank" style="color: red">Terms Of Use</a>,</span> which includes
248
+ not to use our service to generate any harmful, inappropriate or unethical or illegal content that violates locally applicable and international laws and regulations.
 
 
 
 
249
  The service collects user dialogue data for testing and performance improvement, and reserves the right to distribute it under
250
+ <a href="https://creativecommons.org/licenses/by/4.0/">(CC-BY)</a> or similar license. So do not enter any personal information!
 
 
251
  </span>
 
252
  """.strip()
253
 
254
 
 
737
  sys_prompt=SYSTEM_PROMPT_1,
738
  bos_token=BOS_TOKEN,
739
  eos_token=EOS_TOKEN,
740
+ include_end_instruct=True,
741
  ):
742
  """
743
  ```
 
747
  ```
748
  """
749
  text = ''
750
+ end_instr = f" {E_INST}" if include_end_instruct else ""
751
  for i, (prompt, res) in enumerate(history):
752
  if i == 0:
753
+ text += f"{bos_token}{B_INST} {B_SYS} {sys_prompt} {E_SYS} {prompt}{end_instr}"
754
  else:
755
+ text += f"{bos_token}{B_INST} {prompt}{end_instr}"
756
 
757
  if res is not None:
758
  text += f" {res} {eos_token} "
759
  if len(history) == 0 or text.strip() == '':
760
+ text = f"{bos_token}{B_INST} {B_SYS} {sys_prompt} {E_SYS} {message}{end_instr}"
761
  else:
762
+ text += f"{bos_token}{B_INST} {message}{end_instr}"
763
  return text
764
 
765
 
 
974
 
975
 
976
  def vllm_abort(self: Any):
977
+ sh = self.llm_engine.scheduler
978
+ for g in (sh.waiting + sh.running + sh.swapped):
979
+ sh.abort_seq_group(g.request_id)
980
+
981
  from vllm.sequence import SequenceStatus
982
  scheduler = self.llm_engine.scheduler
983
  for state_queue in [scheduler.waiting, scheduler.running, scheduler.swapped]:
 
1127
  temperature: float,
1128
  max_tokens: int,
1129
  frequency_penalty: float,
1130
+ presence_penalty: float,
1131
  current_time: Optional[float] = None,
1132
  system_prompt: Optional[str] = SYSTEM_PROMPT_1
1133
  ) -> str:
 
1179
  temperature=temperature,
1180
  max_tokens=max_tokens,
1181
  frequency_penalty=frequency_penalty,
1182
+ presence_penalty=presence_penalty,
1183
  stop=['<s>', '</s>', '<<SYS>>', '<</SYS>>', '[INST]', '[/INST]']
1184
  )
1185
  cur_out = None
 
1199
  assert len(gen) == 1, f'{gen}'
1200
  item = next(iter(gen.values()))
1201
  cur_out = item.outputs[0].text
1202
+
1203
+ if j >= max_tokens - 2:
1204
+ gr.Warning(f'The response hits limit of {max_tokens} tokens. Consider increase the max tokens parameter in the Additional Inputs.')
1205
 
1206
  # TODO: use current_time to register conversations, accoriding history and cur_out
1207
  history_str = format_conversation(history + [[message, cur_out]])
 
1275
  )
1276
  except Exception as e:
1277
  print(f'Failed to save to repo: {DATA_SET_REPO_PATH}|{str(e)}')
1278
+
1279
 
1280
  def print_log_file():
1281
  global LOG_FILE, LOG_PATH
 
1301
  temperature: float = 0.0,
1302
  max_tokens: int = 4096,
1303
  frequency_penalty: float = 0.4,
1304
+ presence_penalty: float = 0.0,
1305
  current_time: Optional[float] = None,
1306
  system_prompt: str = SYSTEM_PROMPT_1,
1307
  ) -> str:
 
1356
  }
1357
  """
1358
 
1359
+ def debug_file_function(
1360
+ files: Union[str, List[str]],
1361
+ prompt_mode: str,
1362
+ temperature: float,
1363
+ max_tokens: int,
1364
+ frequency_penalty: float,
1365
+ presence_penalty: float,
1366
+ stop_strings: str = "[STOP],<s>,</s>",
1367
+ current_time: Optional[float] = None,
1368
+ ):
1369
+ files = files if isinstance(files, list) else [files]
1370
+ print(files)
1371
+ filenames = [f.name for f in files]
1372
+ all_items = []
1373
+ for fname in filenames:
1374
+ print(f'Reading {fname}')
1375
+ with open(fname, 'r', encoding='utf-8') as f:
1376
+ items = json.load(f)
1377
+ assert isinstance(items, list), f'invalid items from {fname} not list'
1378
+ all_items.extend(items)
1379
+ print(all_items)
1380
+ print(f'{prompt_mode} / {temperature} / {max_tokens}, {frequency_penalty}, {presence_penalty}')
1381
+ save_path = "./test.json"
1382
+ with open(save_path, 'w', encoding='utf-8') as f:
1383
+ json.dump(all_items, f, indent=4, ensure_ascii=False)
1384
+
1385
+ for x in all_items:
1386
+ x['response'] = "Return response"
1387
+
1388
+ print_items = all_items[:1]
1389
+ # print_json = json.dumps(print_items, indent=4, ensure_ascii=False)
1390
+ return save_path, print_items
1391
+
1392
+
1393
+ def validate_file_item(filename, index, item: Dict[str, str]):
1394
+ # BATCH_INFER_MAX_PROMPT_TOKENS
1395
+ message = item['prompt'].strip()
1396
+
1397
+ if len(message) == 0:
1398
+ raise gr.Error(f'Prompt {index} empty')
1399
+
1400
+ message_safety = safety_check(message, history=None)
1401
+ if message_safety is not None:
1402
+ raise gr.Error(f'Prompt {index} unsafe or supported: {message_safety}')
1403
+
1404
+ tokenizer = llm.get_tokenizer() if llm is not None else None
1405
+ if tokenizer is None or len(tokenizer.encode(message, add_special_tokens=False)) >= BATCH_INFER_MAX_PROMPT_TOKENS:
1406
+ raise gr.Error(f"Prompt {index} too long, should be less than {BATCH_INFER_MAX_PROMPT_TOKENS} tokens")
1407
+
1408
+
1409
+ def read_validate_json_files(files: Union[str, List[str]]):
1410
+ files = files if isinstance(files, list) else [files]
1411
+ filenames = [f.name for f in files]
1412
+ all_items = []
1413
+ for fname in filenames:
1414
+ # check each files
1415
+ print(f'Reading {fname}')
1416
+ with open(fname, 'r', encoding='utf-8') as f:
1417
+ items = json.load(f)
1418
+ assert isinstance(items, list), f'Data {fname} not list'
1419
+ assert all(isinstance(x, dict) for x in items), f'item in input file not list'
1420
+ assert all("prompt" in x for x in items), f'key prompt should be in dict item of input file'
1421
+
1422
+ for i, x in enumerate(items):
1423
+ validate_file_item(fname, i, x)
1424
+
1425
+ all_items.extend(items)
1426
+ if len(all_items) > BATCH_INFER_MAX_ITEMS:
1427
+ raise gr.Error(f"Num samples {len(all_items)} > {BATCH_INFER_MAX_ITEMS} allowed.")
1428
+
1429
+ return all_items
1430
+
1431
+
1432
+ def remove_gradio_cache():
1433
+ import shutil
1434
+ for root, dirs, files in os.walk('/tmp/gradio/'):
1435
+ for f in files:
1436
+ os.unlink(os.path.join(root, f))
1437
+ for d in dirs:
1438
+ shutil.rmtree(os.path.join(root, d))
1439
+
1440
+
1441
+ def maybe_upload_batch_set(pred_json_path):
1442
+ global LOG_FILE, DATA_SET_REPO_PATH, SAVE_LOGS
1443
+
1444
+ if SAVE_LOGS and DATA_SET_REPO_PATH is not "":
1445
+ try:
1446
+ from huggingface_hub import upload_file
1447
+ path_in_repo = "misc/" + os.path.basename(pred_json_path).replace(".json", f'.{time.time()}.json')
1448
+ print(f'upload {pred_json_path} to {DATA_SET_REPO_PATH}//{path_in_repo}')
1449
+ upload_file(
1450
+ path_or_fileobj=pred_json_path,
1451
+ path_in_repo=path_in_repo,
1452
+ repo_id=DATA_SET_REPO_PATH,
1453
+ token=HF_TOKEN,
1454
+ repo_type="dataset",
1455
+ create_pr=True
1456
+ )
1457
+ except Exception as e:
1458
+ print(f'Failed to save to repo: {DATA_SET_REPO_PATH}|{str(e)}')
1459
+
1460
+
1461
+ def batch_inference(
1462
+ files: Union[str, List[str]],
1463
+ prompt_mode: str,
1464
+ temperature: float,
1465
+ max_tokens: int,
1466
+ frequency_penalty: float,
1467
+ presence_penalty: float,
1468
+ stop_strings: str = "[STOP],<s>,</s>",
1469
+ current_time: Optional[float] = None,
1470
+ system_prompt: Optional[str] = SYSTEM_PROMPT_1
1471
+ ):
1472
+ """
1473
+ Must handle
1474
+
1475
+ """
1476
+ global LOG_FILE, LOG_PATH, DEBUG, llm, RES_PRINTED
1477
+ if DEBUG:
1478
+ return debug_file_function(
1479
+ files, prompt_mode, temperature, max_tokens,
1480
+ presence_penalty, stop_strings, current_time)
1481
+
1482
+ from vllm import LLM, SamplingParams
1483
+ assert llm is not None
1484
+ # assert system_prompt.strip() != '', f'system prompt is empty'
1485
+
1486
+ stop_strings = [x.strip() for x in stop_strings.strip().split(",")]
1487
+ tokenizer = llm.get_tokenizer()
1488
+ # force removing all
1489
+ # NOTE: need to make sure all cached items are removed!!!!!!!!!
1490
+ vllm_abort(llm)
1491
+
1492
+ temperature = float(temperature)
1493
+ frequency_penalty = float(frequency_penalty)
1494
+ max_tokens = int(max_tokens)
1495
+
1496
+ all_items = read_validate_json_files(files)
1497
+
1498
+ # remove all items in /tmp/gradio/
1499
+ remove_gradio_cache()
1500
+
1501
+
1502
+ if prompt_mode == 'chat':
1503
+ prompt_format_fn = llama_chat_multiturn_sys_input_seq_constructor
1504
+ elif prompt_mode == 'few-shot':
1505
+ from functools import partial
1506
+ prompt_format_fn = partial(
1507
+ llama_chat_multiturn_sys_input_seq_constructor, include_end_instruct=False
1508
+ )
1509
+ else:
1510
+ raise gr.Error(f'Wrong mode {prompt_mode}')
1511
+
1512
+ full_prompts = [
1513
+ prompt_format_fn(
1514
+ x['prompt'], [], sys_prompt=system_prompt
1515
+ )
1516
+ for i, x in enumerate(all_items)
1517
+ ]
1518
+ print(f'{full_prompts[0]}\n')
1519
+
1520
+ if any(len(tokenizer.encode(x, add_special_tokens=False)) >= 4090 for x in full_prompts):
1521
+ raise gr.Error(f"Some prompt is too long!")
1522
+
1523
+ stop_seq = list(set(['<s>', '</s>', '<<SYS>>', '<</SYS>>', '[INST]', '[/INST]'] + stop_strings))
1524
+ sampling_params = SamplingParams(
1525
+ temperature=temperature,
1526
+ max_tokens=max_tokens,
1527
+ frequency_penalty=frequency_penalty,
1528
+ presence_penalty=presence_penalty,
1529
+ stop=stop_seq
1530
+ )
1531
+
1532
+ generated = llm.generate(full_prompts, sampling_params, use_tqdm=False)
1533
+ responses = [g.outputs[0].text for g in generated]
1534
+ if len(responses) != len(all_items):
1535
+ raise gr.Error(f'inconsistent lengths {len(responses)} != {len(all_items)}')
1536
+
1537
+ for res, item in zip(responses, all_items):
1538
+ item['response'] = res
1539
+
1540
+ # save_path = "/mnt/workspace/workgroup/phi/test.json"
1541
+ save_path = BATCH_INFER_SAVE_TMP_FILE
1542
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
1543
+ with open(save_path, 'w', encoding='utf-8') as f:
1544
+ json.dump(all_items, f, indent=4, ensure_ascii=False)
1545
+
1546
+ # You need to upload save_path as a new timestamp file.
1547
+ maybe_upload_batch_set(save_path)
1548
+
1549
+ print_items = all_items[:2]
1550
+ # print_json = json.dumps(print_items, indent=4, ensure_ascii=False)
1551
+ return save_path, print_items
1552
+
1553
+
1554
+ # BATCH_INFER_MAX_ITEMS
1555
+ FILE_UPLOAD_DESC = f"""File upload json format, with JSON object as list of dict with < {BATCH_INFER_MAX_ITEMS} items"""
1556
+ FILE_UPLOAD_DESCRIPTION = FILE_UPLOAD_DESC + """
1557
+ ```
1558
+ [ {\"id\": 0, \"prompt\": \"Hello world\"} , {\"id\": 1, \"prompt\": \"Hi there?\"}]
1559
+ ```
1560
+ """
1561
+
1562
+
1563
+ # https://huggingface.co/spaces/yuntian-deng/ChatGPT4Turbo/blob/main/app.py
1564
+ @document()
1565
+ class CusTabbedInterface(gr.Blocks):
1566
+ def __init__(
1567
+ self,
1568
+ interface_list: list[gr.Interface],
1569
+ tab_names: Optional[list[str]] = None,
1570
+ title: Optional[str] = None,
1571
+ description: Optional[str] = None,
1572
+ theme: Optional[gr.Theme] = None,
1573
+ analytics_enabled: Optional[bool] = None,
1574
+ css: Optional[str] = None,
1575
+ ):
1576
+ """
1577
+ Parameters:
1578
+ interface_list: a list of interfaces to be rendered in tabs.
1579
+ tab_names: a list of tab names. If None, the tab names will be "Tab 1", "Tab 2", etc.
1580
+ title: a title for the interface; if provided, appears above the input and output components in large font. Also used as the tab title when opened in a browser window.
1581
+ analytics_enabled: whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
1582
+ css: custom css or path to custom css file to apply to entire Blocks
1583
+ Returns:
1584
+ a Gradio Tabbed Interface for the given interfaces
1585
+ """
1586
+ super().__init__(
1587
+ title=title or "Gradio",
1588
+ theme=theme,
1589
+ analytics_enabled=analytics_enabled,
1590
+ mode="tabbed_interface",
1591
+ css=css,
1592
+ )
1593
+ self.description = description
1594
+ if tab_names is None:
1595
+ tab_names = [f"Tab {i}" for i in range(len(interface_list))]
1596
+ with self:
1597
+ if title:
1598
+ gr.Markdown(
1599
+ f"<h1 style='text-align: center; margin-bottom: 1rem'>{title}</h1>"
1600
+ )
1601
+ if description:
1602
+ gr.Markdown(description)
1603
+ with gr.Tabs():
1604
+ for interface, tab_name in zip(interface_list, tab_names):
1605
+ with gr.Tab(label=tab_name):
1606
+ interface.render()
1607
+
1608
+
1609
  def launch():
1610
  global demo, llm, DEBUG, LOG_FILE
1611
  model_desc = MODEL_DESC
 
1619
  max_tokens = MAX_TOKENS
1620
  temperature = TEMPERATURE
1621
  frequence_penalty = FREQUENCE_PENALTY
1622
+ presence_penalty = PRESENCE_PENALTY
1623
  ckpt_info = "None"
1624
 
1625
  print(
 
1635
  f'\n| DISPLAY_MODEL_PATH={DISPLAY_MODEL_PATH} '
1636
  f'\n| LANG_BLOCK_HISTORY={LANG_BLOCK_HISTORY} '
1637
  f'\n| frequence_penalty={frequence_penalty} '
1638
+ f'\n| presence_penalty={presence_penalty} '
1639
  f'\n| temperature={temperature} '
1640
  f'\n| hf_model_name={hf_model_name} '
1641
  f'\n| model_path={model_path} '
 
1701
  if SAVE_LOGS:
1702
  LOG_FILE = open(LOG_PATH, 'a', encoding='utf-8')
1703
 
1704
+ if ENABLE_BATCH_INFER:
1705
+
1706
+ demo_file = gr.Interface(
1707
+ batch_inference,
1708
+ inputs=[
1709
+ gr.File(file_count='single', file_types=['json']),
1710
+ gr.Radio(["chat", "few-shot"], value='chat', label="Chat or Few-shot mode", info="Chat's output more user-friendly, Few-shot's output more consistent with few-shot patterns."),
1711
+ gr.Number(value=temperature, label='Temperature (higher -> more random)'),
1712
+ gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
1713
+ gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens over repeated tokens)'),
1714
+ gr.Number(value=presence_penalty, label='Presence penalty (> 0 encourage new tokens, < 0 encourage existing tokens)'),
1715
+ gr.Textbox(value="[STOP],[END],<s>,</s>", label='Comma-separated STOP string to stop generation only in few-shot mode', lines=1),
1716
+ gr.Number(value=0, label='current_time', visible=False),
1717
  ],
1718
+ outputs=[
1719
+ # "file",
1720
+ gr.File(label="Generated file"),
1721
+ # gr.Textbox(),
1722
+ # "json"
1723
+ gr.JSON(label='Example outputs (max 2 samples)')
1724
+ ],
1725
+ # examples=[[[os.path.join(os.path.dirname(__file__),"files/titanic.csv"),
1726
+ # os.path.join(os.path.dirname(__file__),"files/titanic.csv"),
1727
+ # os.path.join(os.path.dirname(__file__),"files/titanic.csv")]]],
1728
+ # cache_examples=True
1729
+ description=FILE_UPLOAD_DESCRIPTION
1730
+ )
1731
+
1732
+
1733
+ demo_chat = gr.ChatInterface(
1734
+ response_fn,
1735
+ chatbot=ChatBot(
1736
+ label=MODEL_NAME,
1737
+ bubble_full_width=False,
1738
+ latex_delimiters=[
1739
+ { "left": "$", "right": "$", "display": False},
1740
+ { "left": "$$", "right": "$$", "display": True},
1741
+ ],
1742
+ show_copy_button=True,
1743
+ ),
1744
+ textbox=gr.Textbox(placeholder='Type message', lines=8, max_lines=128, min_width=200),
1745
+ submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
1746
+ # ! consider preventing the stop button
1747
+ # stop_btn=None,
1748
+ # title=f"{model_title}",
1749
+ # description=f"{model_desc}",
1750
+ additional_inputs=[
1751
+ gr.Number(value=temperature, label='Temperature (higher -> more random)'),
1752
+ gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
1753
+ gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens over repeated tokens)'),
1754
+ gr.Number(value=presence_penalty, label='Presence penalty (> 0 encourage new tokens, < 0 encourage existing tokens)'),
1755
+ gr.Number(value=0, label='current_time', visible=False),
1756
+ # ! Remove the system prompt textbox to avoid jailbreaking
1757
+ # gr.Textbox(value=sys_prompt, label='System prompt', lines=8)
1758
+ ],
1759
+ )
1760
+ demo = CusTabbedInterface(
1761
+ interface_list=[demo_chat, demo_file],
1762
+ tab_names=["Chat Interface", "Batch Inference"],
1763
+ title=f"{model_title}",
1764
+ description=f"{model_desc}",
1765
+ )
1766
+ demo.title = MODEL_NAME
1767
+ with demo:
1768
+ gr.Markdown(cite_markdown)
1769
+ if DISPLAY_MODEL_PATH:
1770
+ gr.Markdown(path_markdown.format(model_path=model_path))
1771
+
1772
+ if ENABLE_AGREE_POPUP:
1773
+ demo.load(None, None, None, _js=AGREE_POP_SCRIPTS)
1774
 
 
 
 
1775
 
1776
+ demo.queue()
1777
+ demo.launch(server_port=PORT)
1778
+ else:
1779
+ demo = gr.ChatInterface(
1780
+ response_fn,
1781
+ chatbot=ChatBot(
1782
+ label=MODEL_NAME,
1783
+ bubble_full_width=False,
1784
+ latex_delimiters=[
1785
+ { "left": "$", "right": "$", "display": False},
1786
+ { "left": "$$", "right": "$$", "display": True},
1787
+ ],
1788
+ show_copy_button=True,
1789
+ ),
1790
+ textbox=gr.Textbox(placeholder='Type message', lines=8, max_lines=128, min_width=200),
1791
+ submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
1792
+ # ! consider preventing the stop button
1793
+ # stop_btn=None,
1794
+ title=f"{model_title}",
1795
+ description=f"{model_desc}",
1796
+ additional_inputs=[
1797
+ gr.Number(value=temperature, label='Temperature (higher -> more random)'),
1798
+ gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
1799
+ gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens over repeated tokens)'),
1800
+ gr.Number(value=presence_penalty, label='Presence penalty (> 0 encourage new tokens, < 0 encourage existing tokens)'),
1801
+ gr.Number(value=0, label='current_time', visible=False),
1802
+ # ! Remove the system prompt textbox to avoid jailbreaking
1803
+ # gr.Textbox(value=sys_prompt, label='System prompt', lines=8)
1804
+ ],
1805
+ )
1806
+ demo.title = MODEL_NAME
1807
+ with demo:
1808
+ gr.Markdown(cite_markdown)
1809
+ if DISPLAY_MODEL_PATH:
1810
+ gr.Markdown(path_markdown.format(model_path=model_path))
1811
+
1812
+ if ENABLE_AGREE_POPUP:
1813
+ demo.load(None, None, None, _js=AGREE_POP_SCRIPTS)
1814
+
1815
+
1816
+ demo.queue()
1817
+ demo.launch(server_port=PORT)
1818
 
1819
 
1820
  def main():
 
1823
 
1824
 
1825
  if __name__ == "__main__":
1826
+ main()
1827
+