phi committed on
Commit
3709b60
1 Parent(s): 6ded56f
Files changed (1)
  1. app.py +44 -54
app.py CHANGED
@@ -57,68 +57,29 @@ TODO:
 need to upload the model as hugginface/models/seal_13b_a
 # https://huggingface.co/docs/hub/spaces-overview#managing-secrets
 set
-MODEL_REPO_ID=hugginface/models/seal_13b_a
+HF_TOKEN=???

+TRANSFORMERS_CACHE=/data/.huggingface
 # if persistent, then export the following
+
 HF_HOME=/data/.huggingface
-TRANSFORMERS_CACHE=/data/.huggingface
 MODEL_PATH=/data/.huggingface/seal-13b-chat-a
 HF_MODEL_NAME=DAMO-NLP-SG/seal-13b-chat-a
 # if not persistent
 MODEL_PATH=./seal-13b-chat-a
 HF_MODEL_NAME=DAMO-NLP-SG/seal-13b-chat-a

-
-
-# download will auto detect and get the most updated one
-if DOWNLOAD_SNAPSHOT:
-    print(f'Download from HF_MODEL_NAME={HF_MODEL_NAME} -> {MODEL_PATH}')
-    snapshot_download(HF_MODEL_NAME, local_dir=MODEL_PATH)
-elif not DEBUG:
-    assert os.path.exists(MODEL_PATH), f'{MODEL_PATH} not found and no snapshot download'
-
 """


-
-
 # ==============================
 print(f'DEBUG mode: {DEBUG}')

-if DTYPE == "bfloat16" and not DEBUG:
-    try:
-        compute_capability = torch.cuda.get_device_capability()
-        if compute_capability[0] < 8:
-            gpu_name = torch.cuda.get_device_name()
-            print(
-                "Bfloat16 is only supported on GPUs with compute capability "
-                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
-                f"{compute_capability[0]}.{compute_capability[1]}. --> Move to FLOAT16")
-            DTYPE = "float16"
-    except Exception as e:
-        print(f'Unable to obtain compute_capability: {e}')


-# @@ constants ================
-if not DEBUG:
-
-    # vllm import
-    from vllm import LLM, SamplingParams
-    from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
-    from vllm.engine.arg_utils import EngineArgs
-    from vllm.engine.llm_engine import LLMEngine
-    from vllm.outputs import RequestOutput
-    from vllm.sampling_params import SamplingParams
-    from vllm.utils import Counter
-    from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
-                               SequenceGroupMetadata, SequenceOutputs,
-                               SequenceStatus)
-    # ! reconfigure vllm to faster llama
-    from vllm.model_executor.model_loader import _MODEL_REGISTRY
-    from vllm.model_executor.models import LlamaForCausalLM

+# @@ constants ================

-_MODEL_REGISTRY['FasterLlamaForCausalLM'] = LlamaForCausalLM


 def _detect_lang(text):
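The hunk above describes the Space's runtime configuration: `HF_HOME`, `TRANSFORMERS_CACHE`, `MODEL_PATH`, `HF_MODEL_NAME`, and now an `HF_TOKEN` secret. As a minimal sketch of how that configuration is typically consumed, assuming the same variable names plus illustrative `DEBUG`/`DOWNLOAD_SNAPSHOT` flags and default values that are not part of this commit:

```python
# Sketch only: read the env vars listed in the TODO block and fetch the
# checkpoint with huggingface_hub. Defaults and flag handling here are
# illustrative assumptions, not the app's actual code.
import os
from huggingface_hub import snapshot_download

DEBUG = bool(int(os.environ.get("DEBUG", "0")))
DOWNLOAD_SNAPSHOT = bool(int(os.environ.get("DOWNLOAD_SNAPSHOT", "0")))
MODEL_PATH = os.environ.get("MODEL_PATH", "./seal-13b-chat-a")
HF_MODEL_NAME = os.environ.get("HF_MODEL_NAME", "DAMO-NLP-SG/seal-13b-chat-a")

if DOWNLOAD_SNAPSHOT:
    # HF_TOKEN would be set as a Space secret for a private or gated repo
    print(f"Download from HF_MODEL_NAME={HF_MODEL_NAME} -> {MODEL_PATH}")
    snapshot_download(HF_MODEL_NAME, local_dir=MODEL_PATH,
                      token=os.environ.get("HF_TOKEN"))
elif not DEBUG:
    assert os.path.exists(MODEL_PATH), f"{MODEL_PATH} not found and no snapshot download"
```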
@@ -390,7 +351,6 @@ def llama_load_weights(
                 intermediate_size + shard_size * tensor_model_parallel_rank,
                 intermediate_size + shard_size * (tensor_model_parallel_rank + 1)
             )
-            # print(f'{name} {param.size()} | {g_offsets} | {u_offsets}')
             _loaded_weight = torch.cat(
                 [
                     loaded_weight[g_offsets[0]:g_offsets[1]],
@@ -420,7 +380,33 @@ def llama_load_weights(

 # Reassign LlamaForCausalLM.load_weights with llama_load_weights
 if not DEBUG:
-    LlamaForCausalLM.load_weights = llama_load_weights
+
+    # vllm import
+    # from vllm import LLM, SamplingParams
+    # ! reconfigure vllm to faster llama
+    try:
+        import vllm
+        from vllm.model_executor.model_loader import _MODEL_REGISTRY
+        from vllm.model_executor.models import LlamaForCausalLM
+
+        _MODEL_REGISTRY['FasterLlamaForCausalLM'] = LlamaForCausalLM
+        LlamaForCausalLM.load_weights = llama_load_weights
+
+        if DTYPE == "bfloat16":
+            try:
+                compute_capability = torch.cuda.get_device_capability()
+                if compute_capability[0] < 8:
+                    gpu_name = torch.cuda.get_device_name()
+                    print(
+                        "Bfloat16 is only supported on GPUs with compute capability "
+                        f"of at least 8.0. Your {gpu_name} GPU has compute capability "
+                        f"{compute_capability[0]}.{compute_capability[1]}. --> Move to FLOAT16")
+                    DTYPE = "float16"
+            except Exception as e:
+                print(f'Unable to obtain compute_capability: {e}')
+    except Exception as e:
+        print(f'Failing import and reconfigure VLLM: {str(e)}')
+

 # ! ==================================================================

@@ -501,11 +487,11 @@ class ChatBot(gr.Chatbot):
         return x


-# gr.ChatInterface
 from gradio.components import Button
 from gradio.events import Dependency, EventListenerMethod

-
+# replace events so that submit button is disabled during generation, if stop_btn not found
+# this prevent weird behavior
 def _setup_stop_events(
     self, event_triggers: list[EventListenerMethod], event_to_cancel: Dependency
 ) -> None:
@@ -571,13 +557,12 @@ def _setup_stop_events
             queue=False,
         )

-
-
 gr.ChatInterface._setup_stop_events = _setup_stop_events

 def chat_response(message, history, temperature: float, max_tokens: int, system_prompt: str = '') -> str:
     global llm
     assert llm is not None
+    from vllm import LLM, SamplingParams
     temperature = float(temperature)
     max_tokens = int(max_tokens)
     if system_prompt.strip() != '':
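Here and in the next several hunks, `from vllm import ...` statements move from module scope into the bodies of the functions that need them. The effect is a lazy import: the app (and DEBUG mode) can start on a machine without vLLM, and the import only happens, and can only fail, when generation actually runs. A small illustrative sketch of the pattern, with a hypothetical helper that is not part of app.py:

```python
# Illustrative sketch of the deferred-import pattern used in these hunks.
# The helper below is hypothetical; only the import placement mirrors the diff.
def generate_once(llm, prompt: str, temperature: float = 0.7, max_tokens: int = 256) -> str:
    # imported inside the function: module import succeeds without vLLM installed,
    # and an ImportError only surfaces when generation is requested
    from vllm import SamplingParams

    params = SamplingParams(temperature=temperature, max_tokens=max_tokens)
    outputs = llm.generate([prompt], params)
    return outputs[0].outputs[0].text
```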
@@ -594,6 +579,7 @@ def chat_response(message, history, temperature: float, max_tokens: int, system_


 def vllm_abort(self: Any):
+    from vllm.sequence import SequenceStatus
     scheduler = self.llm_engine.scheduler
     for state_queue in [scheduler.waiting, scheduler.running, scheduler.swapped]:
         for seq_group in state_queue:
@@ -607,6 +593,7 @@ def vllm_abort(self: Any):

 # def _vllm_run_engine(self: LLM, use_tqdm: bool = False) -> Dict[str, RequestOutput]:
 def _vllm_run_engine(self: Any, use_tqdm: bool = False) -> Dict[str, Any]:
+    from vllm.outputs import RequestOutput
     # Initialize tqdm.
     if use_tqdm:
         num_requests = self.llm_engine.get_num_unfinished_requests()
@@ -654,6 +641,7 @@ def vllm_generate_stream(
         A list of `RequestOutput` objects containing the generated
         completions in the same order as the input prompts.
     """
+    from vllm import LLM, SamplingParams
     if prompts is None and prompt_token_ids is None:
         raise ValueError("Either prompts or prompt_token_ids must be "
                          "provided.")
@@ -750,6 +738,7 @@ def chat_response_stream_multiturn(
     frequency_penalty: float,
     system_prompt: Optional[str] = SYSTEM_PROMPT_1
 ) -> str:
+    from vllm import LLM, SamplingParams
     """Build multi turn
     <bos>[INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer <eos>
     <bos>[INST] Prompt [/INST] Answer <eos>
@@ -837,7 +826,7 @@ This is a DAMO SeaL-13B chatbot assistant built by DAMO Academy, Alibaba Group.


 cite_markdown = """
-### Citation
+## Citation
 If you find our project useful, hope you can star our repo and cite our paper as follows:
 ```
 @article{damonlpsg2023seallm,
@@ -849,9 +838,8 @@ If you find our project useful, hope you can star our repo and cite our paper as
 """

 warning_markdown = """
-### Warning:
+## Warning:
 <span style="color: red">The chatbot may produce inaccurate and harmful information about people, places, or facts.</span>
-
 <span style="color: red">We strongly advise against misuse of the chatbot to knowingly generate harmful or unethical content, \
 or content that violates locally applicable and international laws or regulations, including hate speech, violence, pornography, deception, etc!</span>
 """
@@ -893,11 +881,12 @@ def launch():
         ckpt_info = "None"

     print(
-        f'Launch config: {model_path=} / {model_title=} / {tensor_parallel=} / {dtype=} / {max_tokens} | {BLOCK_ZH=} '
+        f'Launch config: {model_title=} / {tensor_parallel=} / {dtype=} / {max_tokens} | {BLOCK_ZH=} '
         f'\n| STREAM_YIELD_MULTIPLE={STREAM_YIELD_MULTIPLE} '
         f'\n| frequence_penalty={frequence_penalty} '
         f'\n| temperature={temperature} '
         f'\n| hf_model_name={hf_model_name} '
+        f'\n| model_path={model_path} '
         f'\n| DOWNLOAD_SNAPSHOT={DOWNLOAD_SNAPSHOT} '
         f'\nsys={SYSTEM_PROMPT_1}'
         f'\ndesc={model_desc}'
@@ -910,6 +899,8 @@ def launch():
     else:
         # ! load the model
         import vllm
+        from vllm import LLM, SamplingParams
+
         print(F'VLLM: {vllm.__version__}')

         if DOWNLOAD_SNAPSHOT:
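Past the snapshot download, the diff only shows the version print, not the engine construction itself. For orientation, a hedged sketch of how an engine is typically built from the `model_path`, `dtype`, and `tensor_parallel` values that appear in the launch-config print above, using vLLM's public `LLM` class; the actual call in app.py is outside this diff, so treat the argument wiring as an assumption:

```python
# Hedged sketch: typical vLLM engine construction from the launch-config values.
# The real call in app.py is not shown in this diff.
from vllm import LLM

def build_llm(model_path: str, dtype: str, tensor_parallel: int) -> LLM:
    return LLM(
        model=model_path,                      # e.g. /data/.huggingface/seal-13b-chat-a
        dtype=dtype,                           # "bfloat16", or the "float16" fallback
        tensor_parallel_size=tensor_parallel,  # number of GPUs to shard across
    )
```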
@@ -962,7 +953,6 @@ def launch():

 def main():

-    # launch(parser.parse_args())
     launch()

