phi committed • commit 3709b60 • 1 parent: 6ded56f • update

app.py CHANGED
@@ -57,68 +57,29 @@ TODO:
 need to upload the model as hugginface/models/seal_13b_a
 # https://huggingface.co/docs/hub/spaces-overview#managing-secrets
 set
-
+HF_TOKEN=???
 
+TRANSFORMERS_CACHE=/data/.huggingface
 # if persistent, then export the following
+
 HF_HOME=/data/.huggingface
-TRANSFORMERS_CACHE=/data/.huggingface
 MODEL_PATH=/data/.huggingface/seal-13b-chat-a
 HF_MODEL_NAME=DAMO-NLP-SG/seal-13b-chat-a
 # if not persistent
 MODEL_PATH=./seal-13b-chat-a
 HF_MODEL_NAME=DAMO-NLP-SG/seal-13b-chat-a
 
-
-
-# download will auto detect and get the most updated one
-if DOWNLOAD_SNAPSHOT:
-    print(f'Download from HF_MODEL_NAME={HF_MODEL_NAME} -> {MODEL_PATH}')
-    snapshot_download(HF_MODEL_NAME, local_dir=MODEL_PATH)
-elif not DEBUG:
-    assert os.path.exists(MODEL_PATH), f'{MODEL_PATH} not found and no snapshot download'
-
 """
 
 
-
-
 # ==============================
 print(f'DEBUG mode: {DEBUG}')
 
-if DTYPE == "bfloat16" and not DEBUG:
-    try:
-        compute_capability = torch.cuda.get_device_capability()
-        if compute_capability[0] < 8:
-            gpu_name = torch.cuda.get_device_name()
-            print(
-                "Bfloat16 is only supported on GPUs with compute capability "
-                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
-                f"{compute_capability[0]}.{compute_capability[1]}. --> Move to FLOAT16")
-            DTYPE = "float16"
-    except Exception as e:
-        print(f'Unable to obtain compute_capability: {e}')
 
 
-# @@ constants ================
-if not DEBUG:
-
-    # vllm import
-    from vllm import LLM, SamplingParams
-    from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
-    from vllm.engine.arg_utils import EngineArgs
-    from vllm.engine.llm_engine import LLMEngine
-    from vllm.outputs import RequestOutput
-    from vllm.sampling_params import SamplingParams
-    from vllm.utils import Counter
-    from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
-                               SequenceGroupMetadata, SequenceOutputs,
-                               SequenceStatus)
-    # ! reconfigure vllm to faster llama
-    from vllm.model_executor.model_loader import _MODEL_REGISTRY
-    from vllm.model_executor.models import LlamaForCausalLM
 
+# @@ constants ================
 
-    _MODEL_REGISTRY['FasterLlamaForCausalLM'] = LlamaForCausalLM
 
 
 def _detect_lang(text):
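For readers wiring up their own Space from the TODO note above: a minimal sketch of the environment-driven setup it describes, assuming the variables are read with `os.environ.get` (the defaults and the `DEBUG`/`DOWNLOAD_SNAPSHOT` flags are illustrative; the `snapshot_download` call mirrors the snippet removed in this hunk):

```python
import os
from huggingface_hub import snapshot_download

# Names follow the docstring above; the defaults are illustrative only.
DEBUG = os.environ.get("DEBUG", "0") == "1"
DOWNLOAD_SNAPSHOT = os.environ.get("DOWNLOAD_SNAPSHOT", "1") == "1"
HF_MODEL_NAME = os.environ.get("HF_MODEL_NAME", "DAMO-NLP-SG/seal-13b-chat-a")
MODEL_PATH = os.environ.get("MODEL_PATH", "./seal-13b-chat-a")

if DOWNLOAD_SNAPSHOT:
    # Fetch (or refresh) the checkpoint into MODEL_PATH; the HF_TOKEN Space secret
    # is typically picked up from the environment by huggingface_hub for gated repos.
    print(f"Download from HF_MODEL_NAME={HF_MODEL_NAME} -> {MODEL_PATH}")
    snapshot_download(HF_MODEL_NAME, local_dir=MODEL_PATH)
elif not DEBUG:
    assert os.path.exists(MODEL_PATH), f"{MODEL_PATH} not found and no snapshot download"
```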
@@ -390,7 +351,6 @@ def llama_load_weights(
                 intermediate_size + shard_size * tensor_model_parallel_rank,
                 intermediate_size + shard_size * (tensor_model_parallel_rank + 1)
             )
-            # print(f'{name} {param.size()} | {g_offsets} | {u_offsets}')
             _loaded_weight = torch.cat(
                 [
                     loaded_weight[g_offsets[0]:g_offsets[1]],
@@ -420,7 +380,33 @@ def llama_load_weights(
 
 # Reassign LlamaForCausalLM.load_weights with llama_load_weights
 if not DEBUG:
-    LlamaForCausalLM.load_weights = llama_load_weights
+
+    # vllm import
+    # from vllm import LLM, SamplingParams
+    # ! reconfigure vllm to faster llama
+    try:
+        import vllm
+        from vllm.model_executor.model_loader import _MODEL_REGISTRY
+        from vllm.model_executor.models import LlamaForCausalLM
+
+        _MODEL_REGISTRY['FasterLlamaForCausalLM'] = LlamaForCausalLM
+        LlamaForCausalLM.load_weights = llama_load_weights
+
+        if DTYPE == "bfloat16":
+            try:
+                compute_capability = torch.cuda.get_device_capability()
+                if compute_capability[0] < 8:
+                    gpu_name = torch.cuda.get_device_name()
+                    print(
+                        "Bfloat16 is only supported on GPUs with compute capability "
+                        f"of at least 8.0. Your {gpu_name} GPU has compute capability "
+                        f"{compute_capability[0]}.{compute_capability[1]}. --> Move to FLOAT16")
+                    DTYPE = "float16"
+            except Exception as e:
+                print(f'Unable to obtain compute_capability: {e}')
+    except Exception as e:
+        print(f'Failing import and reconfigure VLLM: {str(e)}')
+
 
 # ! ==================================================================
 
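The block added above wraps the whole vLLM reconfiguration in a `try`/`except`, so a missing or incompatible vllm install degrades into a log line instead of crashing the Space at import time. A self-contained sketch of the same guard-the-patch idea, with a hypothetical helper name (the vllm import paths are the ones used in the hunk and belong to the older vllm releases this app targets):

```python
def apply_vllm_patch(load_weights_fn) -> bool:
    """Register the custom architecture name and swap in a custom weight loader.

    Returns True when the patch took effect, False when vllm is unavailable.
    """
    try:
        from vllm.model_executor.model_loader import _MODEL_REGISTRY
        from vllm.model_executor.models import LlamaForCausalLM

        _MODEL_REGISTRY['FasterLlamaForCausalLM'] = LlamaForCausalLM
        LlamaForCausalLM.load_weights = load_weights_fn
        return True
    except Exception as e:  # ImportError, missing CUDA, API drift, ...
        print(f'Failing import and reconfigure VLLM: {e}')
        return False
```

Keeping the registry update and the `load_weights` override inside the same `try` means a partial failure cannot leave vllm half-patched.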
@@ -501,11 +487,11 @@ class ChatBot(gr.Chatbot):
         return x
 
 
-# gr.ChatInterface
 from gradio.components import Button
 from gradio.events import Dependency, EventListenerMethod
 
-
+# replace events so that submit button is disabled during generation, if stop_btn not found
+# this prevent weird behavior
 def _setup_stop_events(
     self, event_triggers: list[EventListenerMethod], event_to_cancel: Dependency
 ) -> None:
@@ -571,13 +557,12 @@ def _setup_stop_events(
             queue=False,
         )
 
-
-
 gr.ChatInterface._setup_stop_events = _setup_stop_events
 
 def chat_response(message, history, temperature: float, max_tokens: int, system_prompt: str = '') -> str:
     global llm
     assert llm is not None
+    from vllm import LLM, SamplingParams
     temperature = float(temperature)
     max_tokens = int(max_tokens)
     if system_prompt.strip() != '':
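These two hunks override a private Gradio method, `gr.ChatInterface._setup_stop_events`, at import time. When monkeypatching like this, it can help to keep a handle on the stock implementation, as in this hypothetical variant (the body here only delegates; the real override in app.py re-wires the submit/stop buttons as described in the comment added above):

```python
import gradio as gr

# Keep the original so the patch can delegate to it (or be reverted while debugging).
_original_setup_stop_events = gr.ChatInterface._setup_stop_events

def _patched_setup_stop_events(self, event_triggers, event_to_cancel) -> None:
    # Illustrative only: run Gradio's own wiring first, then adjust the buttons,
    # e.g. disable the submit button during generation when no stop button exists.
    _original_setup_stop_events(self, event_triggers, event_to_cancel)

gr.ChatInterface._setup_stop_events = _patched_setup_stop_events
```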
@@ -594,6 +579,7 @@ def chat_response(message, history, temperature: float, max_tokens: int, system_
 
 
 def vllm_abort(self: Any):
+    from vllm.sequence import SequenceStatus
     scheduler = self.llm_engine.scheduler
     for state_queue in [scheduler.waiting, scheduler.running, scheduler.swapped]:
         for seq_group in state_queue:
@@ -607,6 +593,7 @@ def vllm_abort(self: Any):
 
 # def _vllm_run_engine(self: LLM, use_tqdm: bool = False) -> Dict[str, RequestOutput]:
 def _vllm_run_engine(self: Any, use_tqdm: bool = False) -> Dict[str, Any]:
+    from vllm.outputs import RequestOutput
     # Initialize tqdm.
     if use_tqdm:
         num_requests = self.llm_engine.get_num_unfinished_requests()
@@ -654,6 +641,7 @@ def vllm_generate_stream(
         A list of `RequestOutput` objects containing the generated
         completions in the same order as the input prompts.
     """
+    from vllm import LLM, SamplingParams
     if prompts is None and prompt_token_ids is None:
         raise ValueError("Either prompts or prompt_token_ids must be "
                          "provided.")
@@ -750,6 +738,7 @@ def chat_response_stream_multiturn(
     frequency_penalty: float,
     system_prompt: Optional[str] = SYSTEM_PROMPT_1
 ) -> str:
+    from vllm import LLM, SamplingParams
     """Build multi turn
     <bos>[INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer <eos>
     <bos>[INST] Prompt [/INST] Answer <eos>
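The last several hunks move the remaining vLLM imports inside the functions that use them, so `app.py` can still be imported, and the DEBUG UI rendered, on a machine without a working vllm install. A tiny illustration of the pattern with a hypothetical helper:

```python
def make_sampling_params(temperature: float, max_tokens: int, frequency_penalty: float = 0.0):
    # Deferred import: vllm is only required once a generation request actually arrives.
    from vllm import SamplingParams

    return SamplingParams(
        temperature=temperature,
        max_tokens=max_tokens,
        frequency_penalty=frequency_penalty,
    )
```

The trade-off is that the first request pays the import cost, which is negligible next to loading the model itself.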
@@ -837,7 +826,7 @@ This is a DAMO SeaL-13B chatbot assistant built by DAMO Academy, Alibaba Group.
 
 
 cite_markdown = """
-
+## Citation
 If you find our project useful, hope you can star our repo and cite our paper as follows:
 ```
 @article{damonlpsg2023seallm,
@@ -849,9 +838,8 @@ If you find our project useful, hope you can star our repo and cite our paper as
 """
 
 warning_markdown = """
-
+## Warning:
 <span style="color: red">The chatbot may produce inaccurate and harmful information about people, places, or facts.</span>
-
 <span style="color: red">We strongly advise against misuse of the chatbot to knowingly generate harmful or unethical content, \
 or content that violates locally applicable and international laws or regulations, including hate speech, violence, pornography, deception, etc!</span>
 """
@@ -893,11 +881,12 @@ def launch():
         ckpt_info = "None"
 
     print(
-        f'Launch config: {
+        f'Launch config: {model_title=} / {tensor_parallel=} / {dtype=} / {max_tokens} | {BLOCK_ZH=} '
         f'\n| STREAM_YIELD_MULTIPLE={STREAM_YIELD_MULTIPLE} '
         f'\n| frequence_penalty={frequence_penalty} '
         f'\n| temperature={temperature} '
         f'\n| hf_model_name={hf_model_name} '
+        f'\n| model_path={model_path} '
         f'\n| DOWNLOAD_SNAPSHOT={DOWNLOAD_SNAPSHOT} '
         f'\nsys={SYSTEM_PROMPT_1}'
         f'\ndesc={model_desc}'
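The rewritten first line of this launch-config message uses the f-string `=` specifier (Python 3.8+), which prints both the expression and its value, so each setting does not have to be named twice. A standalone example with made-up values:

```python
model_title = "SeaL-13B"
tensor_parallel = 1
dtype = "bfloat16"
print(f"Launch config: {model_title=} / {tensor_parallel=} / {dtype=}")
# -> Launch config: model_title='SeaL-13B' / tensor_parallel=1 / dtype='bfloat16'
```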
@@ -910,6 +899,8 @@ def launch():
     else:
         # ! load the model
         import vllm
+        from vllm import LLM, SamplingParams
+
         print(F'VLLM: {vllm.__version__}')
 
         if DOWNLOAD_SNAPSHOT:
@@ -962,7 +953,6 @@ def launch():
 
 def main():
 
-    # launch(parser.parse_args())
     launch()
 
 