Update app.py
app.py CHANGED
@@ -25,7 +25,7 @@ from tqdm.auto import tqdm
|
|
25 |
from huggingface_hub import snapshot_download
|
26 |
|
27 |
|
28 |
-
# @@
|
29 |
|
30 |
DEBUG = bool(int(os.environ.get("DEBUG", "1")))
|
31 |
BLOCK_ZH = bool(int(os.environ.get("BLOCK_ZH", "1")))
|
@@ -34,59 +34,53 @@ DTYPE = os.environ.get("DTYPE", "bfloat16")
|
|
34 |
|
35 |
# ! (no debug) whether to download HF_MODEL_NAME and save to MODEL_PATH
|
36 |
DOWNLOAD_SNAPSHOT = bool(int(os.environ.get("DOWNLOAD_SNAPSHOT", "0")))
|
|
|
|
|
37 |
# ! uploaded model path, will be downloaded to MODEL_PATH
|
38 |
HF_MODEL_NAME = os.environ.get("HF_MODEL_NAME", "DAMO-NLP-SG/seal-13b-chat-a")
|
|
|
39 |
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
|
|
40 |
MODEL_PATH = os.environ.get("MODEL_PATH", "./seal-13b-chat-a")
|
41 |
|
42 |
-
|
|
|
|
|
|
|
43 |
|
44 |
# gradio config
|
45 |
PORT = int(os.environ.get("PORT", "7860"))
|
|
|
46 |
STREAM_YIELD_MULTIPLE = int(os.environ.get("STREAM_YIELD_MULTIPLE", "1"))
|
|
|
|
|
|
|
|
|
47 |
MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "2048"))
|
48 |
TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.1"))
|
49 |
FREQUENCE_PENALTY = float(os.environ.get("FREQUENCE_PENALTY", "0.4"))
|
|
|
50 |
|
|
|
|
|
51 |
|
52 |
-
"""
|
53 |
-
TODO:
|
54 |
-
need to upload the model as hugginface/models/seal_13b_a
|
55 |
-
# https://huggingface.co/docs/hub/spaces-overview#managing-secrets
|
56 |
-
set
|
57 |
-
HF_TOKEN=???
|
58 |
|
59 |
-
|
60 |
-
|
61 |
|
|
|
|
|
|
|
|
|
62 |
HF_HOME=/data/.huggingface
|
63 |
MODEL_PATH=/data/.huggingface/seal-13b-chat-a
|
64 |
-
|
65 |
-
# if not persistent
|
66 |
MODEL_PATH=./seal-13b-chat-a
|
67 |
-
HF_MODEL_NAME=DAMO-NLP-SG/seal-13b-chat-a
|
68 |
-
|
69 |
-
|
70 |
-
===== Application Startup at 2023-10-20 04:03:49 =====
|
71 |
-
|
72 |
-
DEBUG mode: False
|
73 |
-
Torch version: 2.1.0+cu121
|
74 |
-
Torch CUDA version: 12.1
|
75 |
-
/home/user/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/cuda/__init__.py:138: UserWarning: CUDA initialization: The NVIDIA driver on your system is too old (found version 11040). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver. (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:108.)
|
76 |
-
return torch._C._cuda_getDeviceCount() > 0
|
77 |
-
Unable to obtain compute_capability: The NVIDIA driver on your system is too old (found version 11040). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver.
|
78 |
-
Launch config: model_title='SeaL-13B - An Assistant for South East Asian Languages' / tensor_parallel=1 / dtype='bfloat16' / 2048 | BLOCK_ZH=True
|
79 |
-
| STREAM_YIELD_MULTIPLE=1
|
80 |
-
| frequence_penalty=0.4
|
81 |
-
| temperature=0.1
|
82 |
-
| hf_model_name=DAMO-NLP-SG/seal-13b-chat-a
|
83 |
-
| model_path=./seal-13b-chat-a
|
84 |
-
| DOWNLOAD_SNAPSHOT=True
|
85 |
-
sys=You are a multilingual, helpful,
|
86 |
|
87 |
"""
|
88 |
|
89 |
|
|
|
90 |
# ==============================
|
91 |
print(f'DEBUG mode: {DEBUG}')
|
92 |
print(f'Torch version: {torch.__version__}')
|
@@ -95,16 +89,109 @@ try:
|
|
95 |
except Exception as e:
|
96 |
print(f'Failed to print cuda version: {e}')
|
97 |
|
98 |
-
|
|
|
|
|
|
|
|
|
99 |
|
100 |
|
101 |
# @@ constants ================
|
102 |
|
|
103 |
|
104 |
|
105 |
def _detect_lang(text):
|
106 |
from langdetect import detect as detect_lang
|
107 |
-
from langdetect.detector import LangDetectException
|
108 |
dlang = None
|
109 |
try:
|
110 |
dlang = detect_lang(text)
|
@@ -118,11 +205,12 @@ def _detect_lang(text):
|
|
118 |
return dlang
|
119 |
|
120 |
|
121 |
-
def hf_model_weights_iterator(
|
122 |
model_name_or_path: str,
|
123 |
cache_dir: Optional[str] = None,
|
124 |
use_np_cache: bool = False,
|
125 |
) -> Iterator[Tuple[str, torch.Tensor]]:
|
|
|
126 |
from vllm.model_executor.weight_utils import Disabledtqdm
|
127 |
# Prepare file lock directory to prevent multiple processes from
|
128 |
# downloading the same model weights at the same time.
|
@@ -143,7 +231,6 @@ def hf_model_weights_iterator(
|
|
143 |
hf_folder = model_name_or_path
|
144 |
|
145 |
hf_bin_files = [
|
146 |
-
# x for x in glob.glob(os.path.join(hf_folder, "*.bin"))
|
147 |
x for x in glob.glob(os.path.join(hf_folder, "*model*.bin"))
|
148 |
if not x.endswith("training_args.bin")
|
149 |
]
|
@@ -236,9 +323,9 @@ def llama_load_weights(
|
|
236 |
cache_dir: Optional[str] = None,
|
237 |
use_np_cache: bool = False,
|
238 |
load_format: str = "auto",
|
239 |
-
# load_format: str = "pt",
|
240 |
revision: Optional[str] = None
|
241 |
):
|
|
|
242 |
from vllm.model_executor.weight_utils import (
|
243 |
load_tensor_parallel_weights
|
244 |
)
|
@@ -261,7 +348,7 @@ def llama_load_weights(
|
|
261 |
state_dict = self.state_dict()
|
262 |
need_to_load = len(state_dict)
|
263 |
loaded = 0
|
264 |
-
iterator =
|
265 |
|
266 |
for name, loaded_weight in iterator:
|
267 |
if "rotary_emb.inv_freq" in name:
|
@@ -331,7 +418,6 @@ def llama_load_weights(
|
|
331 |
loaded_weight[v_offsets[0]:v_offsets[1]],
|
332 |
], 0
|
333 |
)
|
334 |
-
# print(f'{name} | {q_offsets} | {k_offsets} | {v_offsets}')
|
335 |
assert param.shape == _loaded_weight.shape, f'{param.shape=} != {_loaded_weight.shape=}'
|
336 |
param.data.copy_(_loaded_weight)
|
337 |
loaded += 1.0
|
@@ -398,19 +484,158 @@ def llama_load_weights(
|
|
398 |
print(f'Loaded all {loaded} params loaded out of {need_to_load}')
|
399 |
|
400 |
|
|
|
401 |
# Reassign LlamaForCausalLM.load_weights with llama_load_weights
|
402 |
if not DEBUG:
|
403 |
|
404 |
-
# vllm import
|
405 |
-
# from vllm import LLM, SamplingParams
|
406 |
-
# ! reconfigure vllm to faster llama
|
407 |
try:
|
408 |
import vllm
|
409 |
from vllm.model_executor.model_loader import _MODEL_REGISTRY
|
410 |
from vllm.model_executor.models import LlamaForCausalLM
|
411 |
|
412 |
_MODEL_REGISTRY['FasterLlamaForCausalLM'] = LlamaForCausalLM
|
413 |
-
|
|
|
|
|
|
|
414 |
|
415 |
if DTYPE == "bfloat16":
|
416 |
try:
|
@@ -433,33 +658,6 @@ if not DEBUG:
|
|
433 |
set_documentation_group("component")
|
434 |
|
435 |
|
436 |
-
|
437 |
-
DTYPES = {
|
438 |
-
'float16': torch.float16,
|
439 |
-
'bfloat16': torch.bfloat16
|
440 |
-
}
|
441 |
-
|
442 |
-
llm = None
|
443 |
-
demo = None
|
444 |
-
|
445 |
-
|
446 |
-
BOS_TOKEN = '<s>'
|
447 |
-
EOS_TOKEN = '</s>'
|
448 |
-
|
449 |
-
B_INST, E_INST = "[INST]", "[/INST]"
|
450 |
-
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
451 |
-
|
452 |
-
SYSTEM_PROMPT_1 = """You are a multilingual, helpful, respectful and honest assistant. Your name is SeaL and you are built by DAMO Academy, Alibaba Group. Always answer as helpfully as possible, while being safe. Your \
|
453 |
-
answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
|
454 |
-
that your responses are socially unbiased and positive in nature.
|
455 |
-
|
456 |
-
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
|
457 |
-
correct. If you don't know the answer to a question, please don't share false information.
|
458 |
-
|
459 |
-
As a multilingual assistant, you must respond and follow instructions in the native language of the user by default, unless told otherwise. \
|
460 |
-
Your response should adapt to the norms and customs of the respective language and culture.
|
461 |
-
"""
|
462 |
-
|
463 |
RES_PRINTED = False
|
464 |
|
465 |
def llama_chat_sys_input_seq_constructor(text, sys_prompt=SYSTEM_PROMPT_1, bos_token=BOS_TOKEN, eos_token=EOS_TOKEN):
|
@@ -576,8 +774,117 @@ def _setup_stop_events(
|
|
576 |
api_name=False,
|
577 |
queue=False,
|
578 |
)
|
|
|
579 |
|
|
580 |
gr.ChatInterface._setup_stop_events = _setup_stop_events
|
|
|
581 |
|
582 |
def chat_response(message, history, temperature: float, max_tokens: int, system_prompt: str = '') -> str:
|
583 |
global llm
|
@@ -611,7 +918,6 @@ def vllm_abort(self: Any):
|
|
611 |
continue
|
612 |
scheduler.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
|
613 |
|
614 |
-
# def _vllm_run_engine(self: LLM, use_tqdm: bool = False) -> Dict[str, RequestOutput]:
|
615 |
def _vllm_run_engine(self: Any, use_tqdm: bool = False) -> Dict[str, Any]:
|
616 |
from vllm.outputs import RequestOutput
|
617 |
# Initialize tqdm.
|
@@ -624,16 +930,9 @@ def _vllm_run_engine(self: Any, use_tqdm: bool = False) -> Dict[str, Any]:
|
|
624 |
step_outputs = self.llm_engine.step()
|
625 |
for output in step_outputs:
|
626 |
outputs[output.request_id] = output
|
627 |
-
# outputs = sorted(outputs, key=lambda x: int(x.request_id))
|
628 |
if len(outputs) > 0:
|
629 |
yield outputs
|
630 |
-
|
631 |
-
# pbar.close()
|
632 |
-
# Sort the outputs by request ID.
|
633 |
-
# This is necessary because some requests may be finished earlier than
|
634 |
-
# its previous requests.
|
635 |
-
# outputs = sorted(outputs, key=lambda x: int(x.request_id))
|
636 |
-
# return outputs
|
637 |
|
638 |
|
639 |
def vllm_generate_stream(
|
@@ -692,64 +991,47 @@ def vllm_generate_stream(
|
|
692 |
yield from _vllm_run_engine(self, use_tqdm)
|
693 |
|
694 |
|
695 |
-
# def chat_response_stream(
|
696 |
-
# message: str,
|
697 |
-
# history: List[Tuple[str, str]],
|
698 |
-
# temperature: float,
|
699 |
-
# max_tokens: int,
|
700 |
-
# frequency_penalty: float,
|
701 |
-
# system_prompt: str
|
702 |
-
# ) -> str:
|
703 |
-
# global llm, RES_PRINTED
|
704 |
-
# assert llm is not None
|
705 |
-
# # force removing all
|
706 |
-
# vllm_abort(llm)
|
707 |
-
|
708 |
-
# temperature = float(temperature)
|
709 |
-
# frequency_penalty = float(frequency_penalty)
|
710 |
-
# max_tokens = int(max_tokens)
|
711 |
-
# if system_prompt.strip() != '':
|
712 |
-
# # chat version, add system prompt
|
713 |
-
# message = llama_chat_sys_input_seq_constructor(
|
714 |
-
# message.strip(),
|
715 |
-
# sys_prompt=system_prompt
|
716 |
-
# )
|
717 |
-
# sampling_params = SamplingParams(
|
718 |
-
# temperature=temperature, max_tokens=max_tokens,
|
719 |
-
# frequency_penalty=frequency_penalty,
|
720 |
-
# )
|
721 |
-
# cur_out = None
|
722 |
-
# for j, gen in enumerate(vllm_generate_stream(llm, message, sampling_params)):
|
723 |
-
# if cur_out is not None and (STREAM_YIELD_MULTIPLE < 1 or j % STREAM_YIELD_MULTIPLE == 0) and j > 0:
|
724 |
-
# yield cur_out
|
725 |
-
# assert len(gen) == 1, f'{gen}'
|
726 |
-
# item = next(iter(gen.values()))
|
727 |
-
# cur_out = item.outputs[0].text
|
728 |
-
# if not RES_PRINTED:
|
729 |
-
# print(f'{message}<<<{cur_out}>>>')
|
730 |
-
# RES_PRINTED = True
|
731 |
-
# if cur_out is not None:
|
732 |
-
# yield cur_out
|
733 |
-
|
734 |
-
|
735 |
BLOCK_MESSAGE = """Sorry, Chinese is not currently supported. Please clear the chat box for a new conversation.
|
736 |
抱歉,目前不支持中文。 请清除聊天框以进行新对话。"""
|
737 |
|
|
|
|
|
738 |
def block_zh(
|
739 |
message: str,
|
740 |
history: List[Tuple[str, str]]
|
741 |
) -> str:
|
742 |
-
|
743 |
-
if any((BLOCK_MESSAGE in x[1].strip()) for x in history):
|
744 |
return True
|
745 |
elif 'zh' in _detect_lang(message):
|
746 |
print(f'Detect zh: {message}')
|
747 |
return True
|
748 |
-
# ! optionally detect every responses message
|
749 |
else:
|
750 |
return False
|
751 |
|
752 |
-
|
|
|
753 |
def chat_response_stream_multiturn(
|
754 |
message: str,
|
755 |
history: List[Tuple[str, str]],
|
@@ -779,44 +1061,48 @@ def chat_response_stream_multiturn(
|
|
779 |
|
780 |
message = message.strip()
|
781 |
|
782 |
-
|
783 |
-
|
|
|
|
|
784 |
|
785 |
-
# ! lang detect
|
786 |
-
if BLOCK_ZH:
|
787 |
-
if block_zh(message, history):
|
788 |
-
yield BLOCK_MESSAGE
|
789 |
-
return
|
790 |
-
|
791 |
-
# history.append([message, None])
|
792 |
# history will be appended with message later on
|
793 |
full_prompt = llama_chat_multiturn_sys_input_seq_constructor(
|
794 |
message, history, sys_prompt=system_prompt
|
795 |
)
|
796 |
-
|
797 |
sampling_params = SamplingParams(
|
798 |
temperature=temperature, max_tokens=max_tokens,
|
799 |
frequency_penalty=frequency_penalty,
|
800 |
)
|
801 |
cur_out = None
|
802 |
-
|
803 |
for j, gen in enumerate(vllm_generate_stream(llm, full_prompt, sampling_params)):
|
804 |
if cur_out is not None and (STREAM_YIELD_MULTIPLE < 1 or j % STREAM_YIELD_MULTIPLE == 0) and j > 0:
|
|
|
|
805 |
yield cur_out
|
806 |
assert len(gen) == 1, f'{gen}'
|
807 |
item = next(iter(gen.values()))
|
808 |
cur_out = item.outputs[0].text
|
809 |
|
810 |
-
|
811 |
-
print(f'{full_prompt}<<<{cur_out}>>>\n')
|
812 |
-
# RES_PRINTED = True
|
813 |
if cur_out is not None:
|
814 |
yield cur_out
|
815 |
|
816 |
-
|
817 |
-
if
|
818 |
-
|
819 |
-
|
|
|
|
|
|
|
|
|
820 |
|
821 |
|
822 |
def debug_chat_response_echo(
|
@@ -832,44 +1118,6 @@ def debug_chat_response_echo(
|
|
832 |
yield f"repeat: {message}"
|
833 |
|
834 |
|
835 |
-
# ============ CONSTANT ============
|
836 |
-
# https://github.com/gradio-app/gradio/issues/884
|
837 |
-
MODEL_NAME = "SeaL-13B"
|
838 |
-
MODEL_TITLE = "SeaL-13B - An Assistant for South East Asian Languages"
|
839 |
-
# ! add icon: "<img src='file/lion.jpg' alt='image One'>"
|
840 |
-
MODEL_DESC = """
|
841 |
-
<span style="font-size: larger">
|
842 |
-
This is a DAMO SeaL-13B chatbot assistant built by DAMO Academy, Alibaba Group. It can produce helpful responses in English 🇬🇧, Vietnamese 🇻🇳, Indonesian 🇮🇩 and Thai 🇹🇭.
|
843 |
-
</span>
|
844 |
-
""".strip()
|
845 |
-
# <br>
|
846 |
-
|
847 |
-
|
848 |
-
cite_markdown = """
|
849 |
-
## Citation
|
850 |
-
If you find our project useful, hope you can star our repo and cite our paper as follows:
|
851 |
-
```
|
852 |
-
@article{damonlpsg2023seallm,
|
853 |
-
author = {???},
|
854 |
-
title = {SeaL: A language model for South East Asian Languages},
|
855 |
-
year = 2023,
|
856 |
-
}
|
857 |
-
```
|
858 |
-
"""
|
859 |
-
|
860 |
-
warning_markdown = """
|
861 |
-
## Warning:
|
862 |
-
<span style="color: red">The chatbot may produce inaccurate and harmful information about people, places, or facts.</span>
|
863 |
-
<span style="color: red">We strongly advise against misuse of the chatbot to knowingly generate harmful or unethical content, \
|
864 |
-
or content that violates locally applicable and international laws or regulations, including hate speech, violence, pornography, deception, etc!</span>
|
865 |
-
"""
|
866 |
-
|
867 |
-
|
868 |
-
path_markdown = """
|
869 |
-
#### Model path:
|
870 |
-
{model_path}
|
871 |
-
"""
|
872 |
-
|
873 |
def check_model_path(model_path) -> str:
|
874 |
assert os.path.exists(model_path), f'{model_path} not found'
|
875 |
ckpt_info = "None"
|
@@ -903,11 +1151,14 @@ def launch():
|
|
903 |
print(
|
904 |
f'Launch config: {model_title=} / {tensor_parallel=} / {dtype=} / {max_tokens} | {BLOCK_ZH=} '
|
905 |
f'\n| STREAM_YIELD_MULTIPLE={STREAM_YIELD_MULTIPLE} '
|
|
|
906 |
f'\n| frequence_penalty={frequence_penalty} '
|
907 |
f'\n| temperature={temperature} '
|
908 |
f'\n| hf_model_name={hf_model_name} '
|
909 |
f'\n| model_path={model_path} '
|
910 |
f'\n| DOWNLOAD_SNAPSHOT={DOWNLOAD_SNAPSHOT} '
|
|
|
|
|
911 |
f'\nsys={SYSTEM_PROMPT_1}'
|
912 |
f'\ndesc={model_desc}'
|
913 |
)
|
@@ -928,13 +1179,23 @@ def launch():
|
|
928 |
snapshot_download(hf_model_name, local_dir=model_path)
|
929 |
|
930 |
import vllm
|
931 |
-
from vllm import LLM
|
932 |
|
933 |
print(F'VLLM: {vllm.__version__}')
|
934 |
ckpt_info = check_model_path(model_path)
|
935 |
|
936 |
print(f'Load path: {model_path} | {ckpt_info}')
|
937 |
-
|
|
938 |
|
939 |
print(f'Use system prompt:\n{sys_prompt}')
|
940 |
|
@@ -957,16 +1218,17 @@ def launch():
|
|
957 |
stop_btn=None,
|
958 |
title=f"{model_title}",
|
959 |
description=f"{model_desc}",
|
960 |
-
# ! decide if can change the system prompt.
|
961 |
additional_inputs=[
|
962 |
gr.Number(value=temperature, label='Temperature (higher -> more random)'),
|
963 |
gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
|
964 |
gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens)'),
|
|
|
965 |
# gr.Textbox(value=sys_prompt, label='System prompt', lines=8)
|
966 |
],
|
967 |
)
|
|
|
968 |
with demo:
|
969 |
-
gr.Markdown(warning_markdown)
|
970 |
gr.Markdown(cite_markdown)
|
971 |
gr.Markdown(path_markdown.format(model_path=model_path))
|
972 |
|
@@ -981,30 +1243,3 @@ def main():
|
|
981 |
|
982 |
if __name__ == "__main__":
|
983 |
main()
|
984 |
-
|
985 |
-
|
986 |
-
"""
|
987 |
-
|
988 |
-
export CUDA_VISIBLE_DEVICES=0
|
989 |
-
export MODEL_PATH=${dataroot}/hf_train/pretrain_lm/swpn/merlion13s108Hi8kPretFlCW8k.LMFromHf.a.gc.t5k0.vizhthid.mean_std.TrainTask.NLNL.Multi.Vi.FSePlCq13M.FSePlCq13M.m4k.b8.lr1e5.linear.wa0k.ms858k.grac1.se1.8g.v4c.zfsdp/step_4000
|
990 |
-
export MODEL_PATH=${dataroot}/llama-2-7b-lxxp-faster
|
991 |
-
export MODEL_PATH=${dataroot}/llama-2-7b-chat-xp
|
992 |
-
|
993 |
-
export DEBUG=0
|
994 |
-
export CUDA_VISIBLE_DEVICES=0
|
995 |
-
export MODEL_PATH=seal_13b_a
|
996 |
-
export MODEL_PATH=${dataroot}/hf_train/pretrain_lm/swpn/merlion13s108Hi8kPretFlCW12k.LMFromHf.a.gc.t5k0.vizhthid.mean_std.TrainTask.NLNL.Multi.Vi.SeaV2Cq13M.SeaV2Cq13M.m4k.b8.lr1e5.linear.wa0k.ms858k.grac1.se1.8g.v4c.zfsdp/step_6000
|
997 |
-
|
998 |
-
export MODEL_PATH=${dataroot}/hf_train/pretrain_lm/swpn/mer13s108Hi16kPretFlCWNLP12k_SFT2.LMFromHf.a.gc.t5k0.vizhthid.mean_std.TrainTask.NLNL.Multi.Vi.Sft2Censor.Sft2Censor.m4k.b8.lr1e5.linear.wa0k.ms1144k.grac1.se1.6g.v4c.zfsdp/step_4000
|
999 |
-
# 70-30 model
|
1000 |
-
export MODEL_PATH=${dataroot}/hf_train/pretrain_lm/swpn/mer13s108Hi16kPretFlCWNLP12k_SFT2.LMFromHf.a.gc.t5k0.vizhthid.mean_std.TrainTask.NLNL.Multi.BgSft2aCensor0a.BgSft2Cens.BgSft2Cens.m4k.b2.lr1e5.linear.wa0k.ms4577k.grac1.se1.6g.v4c73.zfsdp/step_500
|
1001 |
-
export PORT=8799
|
1002 |
-
export BLOCK_ZH=1
|
1003 |
-
export DEBUG=0
|
1004 |
-
python app.py
|
1005 |
-
|
1006 |
-
|
1007 |
-
DEBUG=1 python app.py
|
1008 |
-
|
1009 |
-
|
1010 |
-
"""
|
|
|
25 |
from huggingface_hub import snapshot_download
|
26 |
|
27 |
|
28 |
+
# @@ environments ================
|
29 |
|
30 |
DEBUG = bool(int(os.environ.get("DEBUG", "1")))
|
31 |
BLOCK_ZH = bool(int(os.environ.get("BLOCK_ZH", "1")))
|
|
|
34 |
|
35 |
# ! (no debug) whether to download HF_MODEL_NAME and save to MODEL_PATH
|
36 |
DOWNLOAD_SNAPSHOT = bool(int(os.environ.get("DOWNLOAD_SNAPSHOT", "0")))
|
37 |
+
LOG_RESPONSE = bool(int(os.environ.get("LOG_RESPONSE", "0")))
|
38 |
+
|
39 |
# ! uploaded model path, will be downloaded to MODEL_PATH
|
40 |
HF_MODEL_NAME = os.environ.get("HF_MODEL_NAME", "DAMO-NLP-SG/seal-13b-chat-a")
|
41 |
+
# ! if model is private, need HF_TOKEN to access the model
|
42 |
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
43 |
+
# ! path where the model is downloaded, either on ./ or persistent disc
|
44 |
MODEL_PATH = os.environ.get("MODEL_PATH", "./seal-13b-chat-a")
|
45 |
|
46 |
+
# ! list of keywords to disabled as security measures to comply with local regulation
|
47 |
+
KEYWORDS = os.environ.get("KEYWORDS", "").strip()
|
48 |
+
KEYWORDS = KEYWORDS.split(";") if len(KEYWORDS) > 0 else []
|
49 |
+
KEYWORDS = [x.lower() for x in KEYWORDS]
|
50 |
|
51 |
# gradio config
|
52 |
PORT = int(os.environ.get("PORT", "7860"))
|
53 |
+
# how many iterations to yield response
|
54 |
STREAM_YIELD_MULTIPLE = int(os.environ.get("STREAM_YIELD_MULTIPLE", "1"))
|
55 |
+
# how many iterations to perform safety check on response
|
56 |
+
STREAM_CHECK_MULTIPLE = int(os.environ.get("STREAM_CHECK_MULTIPLE", "0"))
|
57 |
+
|
58 |
+
# self explanatory
|
59 |
MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "2048"))
|
60 |
TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.1"))
|
61 |
FREQUENCE_PENALTY = float(os.environ.get("FREQUENCE_PENALTY", "0.4"))
|
62 |
+
gpu_memory_utilization = float(os.environ.get("gpu_memory_utilization", "0.9"))
|
63 |
|
64 |
+
# whether to enable quantization, currently not in use
|
65 |
+
QUANTIZATION = str(os.environ.get("QUANTIZATION", ""))
|
66 |
|
|
67 |
|
68 |
+
"""
|
69 |
+
Internal instructions of how to configure the DEMO
|
70 |
|
71 |
+
1. Upload SFT model as a model to huggingface: hugginface/models/seal_13b_a
|
72 |
+
2. If the model weights is private, set HF_TOKEN=<your private hf token> in https://huggingface.co/spaces/????/?????/settings
|
73 |
+
3. space config env: `HF_MODEL_NAME=DAMO-NLP-SG/seal-13b-chat-a` or the underlining model
|
74 |
+
4. If enable persistent storage: set
|
75 |
HF_HOME=/data/.huggingface
|
76 |
MODEL_PATH=/data/.huggingface/seal-13b-chat-a
|
77 |
+
if not:
|
|
|
78 |
MODEL_PATH=./seal-13b-chat-a
|
|
|
79 |
|
80 |
"""
|
81 |
|
82 |
|
83 |
+
|
84 |
# ==============================
|
85 |
print(f'DEBUG mode: {DEBUG}')
|
86 |
print(f'Torch version: {torch.__version__}')
|
|
|
89 |
except Exception as e:
|
90 |
print(f'Failed to print cuda version: {e}')
|
91 |
|
92 |
+
try:
|
93 |
+
compute_capability = torch.cuda.get_device_capability()
|
94 |
+
print(f'Torch CUDA compute_capability: {compute_capability}')
|
95 |
+
except Exception as e:
|
96 |
+
print(f'Failed to print compute_capability version: {e}')
|
97 |
|
98 |
|
99 |
# @@ constants ================
|
100 |
|
101 |
+
DTYPES = {
|
102 |
+
'float16': torch.float16,
|
103 |
+
'bfloat16': torch.bfloat16
|
104 |
+
}
|
105 |
+
|
106 |
+
llm = None
|
107 |
+
demo = None
|
108 |
+
|
109 |
+
|
110 |
+
BOS_TOKEN = '<s>'
|
111 |
+
EOS_TOKEN = '</s>'
|
112 |
+
|
113 |
+
B_INST, E_INST = "[INST]", "[/INST]"
|
114 |
+
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
115 |
+
|
116 |
+
SYSTEM_PROMPT_1 = """You are a multilingual, helpful, respectful and honest assistant. Your name is SeaL and you are built by DAMO Academy, Alibaba Group. Always answer as helpfully as possible, while being safe. Your \
|
117 |
+
answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
|
118 |
+
that your responses are socially unbiased and positive in nature.
|
119 |
+
|
120 |
+
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
|
121 |
+
correct. If you don't know the answer to a question, please don't share false information.
|
122 |
+
|
123 |
+
As a multilingual assistant, you must respond and follow instructions in the native language of the user by default, unless told otherwise. \
|
124 |
+
Your response should adapt to the norms and customs of the respective language and culture.
|
125 |
+
"""
|
126 |
+
|
127 |
+
# ============ CONSTANT ============
|
128 |
+
# https://github.com/gradio-app/gradio/issues/884
|
129 |
+
MODEL_NAME = "SeaLLM-13B"
|
130 |
+
MODEL_TITLE = "SeaLLM-13B - An Assistant for South East Asian Languages"
|
131 |
+
# ! add icon: "<img src='file/lion.jpg' alt='image One'>"
|
132 |
+
MODEL_TITLE = """
|
133 |
+
<div class="container" style="
|
134 |
+
align-items: center;
|
135 |
+
justify-content: center;
|
136 |
+
display: flex;
|
137 |
+
">
|
138 |
+
<div class="image" >
|
139 |
+
<img src="file/seal_logo.png" style="
|
140 |
+
max-width: 10em;
|
141 |
+
max-height: 5%;
|
142 |
+
height: 5em;
|
143 |
+
width: 5em;
|
144 |
+
float: left;
|
145 |
+
margin-left: auto;
|
146 |
+
">
|
147 |
+
</div>
|
148 |
+
<div class="text" style="
|
149 |
+
padding-left: 20px;
|
150 |
+
padding-top: 2%;
|
151 |
+
float: left;
|
152 |
+
">
|
153 |
+
<h1>SeaLLM-13B - An Assistant for South East Asian Languages</h1>
|
154 |
+
</div>
|
155 |
+
</div>
|
156 |
+
"""
|
157 |
+
MODEL_DESC = """
|
158 |
+
<span style="font-size: larger">
|
159 |
+
This is SeaLLM-13B - a chatbot assistant optimized for South East Asian Languages. It can produce helpful responses in English 🇬🇧, Vietnamese 🇻🇳, Indonesian 🇮🇩 and Thai 🇹🇭.
|
160 |
+
</span>
|
161 |
+
<br>
|
162 |
+
<span style="color: red">NOTICE: The chatbot may produce inaccurate and harmful information about people, places, or facts. \
|
163 |
+
We strongly advise against misuse of the chatbot to knowingly generate harmful or unethical content, \
|
164 |
+
or content that violates locally applicable and international laws or regulations, including hate speech, violence, pornography, deception, etc!</span>
|
165 |
+
""".strip()
|
166 |
+
|
167 |
+
|
168 |
+
cite_markdown = """
|
169 |
+
## Citation
|
170 |
+
If you find our project useful, hope you can star our repo and cite our paper as follows:
|
171 |
+
```
|
172 |
+
@article{damonlpsg2023seallm,
|
173 |
+
author = {???},
|
174 |
+
title = {SeaLLM: A language model for South East Asian Languages},
|
175 |
+
year = 2023,
|
176 |
+
}
|
177 |
+
```
|
178 |
+
"""
|
179 |
+
|
180 |
+
# warning_markdown = """
|
181 |
+
# ## Warning:
|
182 |
+
# <span style="color: red">The chatbot may produce inaccurate and harmful information about people, places, or facts.</span>
|
183 |
+
# <span style="color: red">We strongly advise against misuse of the chatbot to knowingly generate harmful or unethical content, \
|
184 |
+
# or content that violates locally applicable and international laws or regulations, including hate speech, violence, pornography, deception, etc!</span>
|
185 |
+
# """
|
186 |
+
|
187 |
+
path_markdown = """
|
188 |
+
#### Model path:
|
189 |
+
{model_path}
|
190 |
+
"""
|
191 |
|
192 |
|
193 |
def _detect_lang(text):
|
194 |
from langdetect import detect as detect_lang
|
|
|
195 |
dlang = None
|
196 |
try:
|
197 |
dlang = detect_lang(text)
|
|
|
205 |
return dlang
|
206 |
|
207 |
|
208 |
+
def custom_hf_model_weights_iterator(
|
209 |
model_name_or_path: str,
|
210 |
cache_dir: Optional[str] = None,
|
211 |
use_np_cache: bool = False,
|
212 |
) -> Iterator[Tuple[str, torch.Tensor]]:
|
213 |
+
# ! if use vllm==0.1.4, use this to augment hf_model_weights_iterator loader
|
214 |
from vllm.model_executor.weight_utils import Disabledtqdm
|
215 |
# Prepare file lock directory to prevent multiple processes from
|
216 |
# downloading the same model weights at the same time.
|
|
|
231 |
hf_folder = model_name_or_path
|
232 |
|
233 |
hf_bin_files = [
|
|
|
234 |
x for x in glob.glob(os.path.join(hf_folder, "*model*.bin"))
|
235 |
if not x.endswith("training_args.bin")
|
236 |
]
|
|
|
323 |
cache_dir: Optional[str] = None,
|
324 |
use_np_cache: bool = False,
|
325 |
load_format: str = "auto",
|
|
|
326 |
revision: Optional[str] = None
|
327 |
):
|
328 |
+
# if use vllm==0.1.4
|
329 |
from vllm.model_executor.weight_utils import (
|
330 |
load_tensor_parallel_weights
|
331 |
)
|
|
|
348 |
state_dict = self.state_dict()
|
349 |
need_to_load = len(state_dict)
|
350 |
loaded = 0
|
351 |
+
iterator = custom_hf_model_weights_iterator(model_name_or_path, cache_dir, use_np_cache)
|
352 |
|
353 |
for name, loaded_weight in iterator:
|
354 |
if "rotary_emb.inv_freq" in name:
|
|
|
418 |
loaded_weight[v_offsets[0]:v_offsets[1]],
|
419 |
], 0
|
420 |
)
|
|
|
421 |
assert param.shape == _loaded_weight.shape, f'{param.shape=} != {_loaded_weight.shape=}'
|
422 |
param.data.copy_(_loaded_weight)
|
423 |
loaded += 1.0
|
|
|
484 |
print(f'Loaded all {loaded} params loaded out of {need_to_load}')
|
485 |
|
486 |
|
487 |
+
def new_llama_load_weights(
|
488 |
+
self,
|
489 |
+
model_name_or_path: str,
|
490 |
+
cache_dir: Optional[str] = None,
|
491 |
+
load_format: str = "auto",
|
492 |
+
revision: Optional[str] = None
|
493 |
+
):
|
494 |
+
# If use newest vllm
|
495 |
+
from vllm.model_executor.weight_utils import (
|
496 |
+
load_tensor_parallel_weights, hf_model_weights_iterator
|
497 |
+
)
|
498 |
+
from vllm.model_executor.parallel_utils.parallel_state import (
|
499 |
+
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
500 |
+
|
501 |
+
if self.quant_config is None:
|
502 |
+
weight_suffixes = ["weight"]
|
503 |
+
else:
|
504 |
+
weight_suffixes = self.quant_config.get_tp_tensor_names()
|
505 |
+
|
506 |
+
column_parallel_weights: List[str] = []
|
507 |
+
for layer in self._column_parallel_layers:
|
508 |
+
for suffix in weight_suffixes:
|
509 |
+
column_parallel_weights.append(f"{layer}.{suffix}")
|
510 |
+
row_parallel_weights: List[str] = []
|
511 |
+
for layer in self._row_parallel_layers:
|
512 |
+
for suffix in weight_suffixes:
|
513 |
+
row_parallel_weights.append(f"{layer}.{suffix}")
|
514 |
+
|
515 |
+
tp_size = get_tensor_model_parallel_world_size()
|
516 |
+
tp_rank = get_tensor_model_parallel_rank()
|
517 |
+
assert tp_size == 1, f'tensorparallel >=2 not allowed. {tp_size}'
|
518 |
+
q_proj_shard_size = (self.config.hidden_size // tp_size)
|
519 |
+
num_kv_heads_replicas = max(1,
|
520 |
+
tp_size // self.config.num_key_value_heads)
|
521 |
+
num_kv_heads_per_gpu = max(1,
|
522 |
+
self.config.num_key_value_heads // tp_size)
|
523 |
+
kv_proj_shard_size = (self.config.hidden_size //
|
524 |
+
self.config.num_attention_heads *
|
525 |
+
num_kv_heads_per_gpu)
|
526 |
+
attention_weight_specs = [
|
527 |
+
# (weight_name, shard_size, offset)
|
528 |
+
("q_proj", q_proj_shard_size, 0),
|
529 |
+
("k_proj", kv_proj_shard_size, q_proj_shard_size),
|
530 |
+
("v_proj", kv_proj_shard_size,
|
531 |
+
q_proj_shard_size + kv_proj_shard_size),
|
532 |
+
]
|
533 |
+
state_dict = self.state_dict()
|
534 |
+
need_to_load = len(state_dict)
|
535 |
+
loaded = 0
|
536 |
+
|
537 |
+
for name, loaded_weight in hf_model_weights_iterator(
|
538 |
+
model_name_or_path, cache_dir, load_format, revision):
|
539 |
+
if "rotary_emb.inv_freq" in name:
|
540 |
+
continue
|
541 |
+
|
542 |
+
is_packed = False
|
543 |
+
is_transposed = False
|
544 |
+
if self.quant_config is not None:
|
545 |
+
is_packed = self.quant_config.is_packed(name)
|
546 |
+
is_transposed = self.quant_config.is_transposed(name)
|
547 |
+
if is_transposed:
|
548 |
+
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
549 |
+
loaded_weight = loaded_weight.T
|
550 |
+
|
551 |
+
is_attention_weight = False
|
552 |
+
for weight_name, shard_size, offset in attention_weight_specs:
|
553 |
+
if weight_name not in name or "qkv_proj" in name:
|
554 |
+
continue
|
555 |
+
param = state_dict[name.replace(weight_name, "qkv_proj")]
|
556 |
+
if is_transposed:
|
557 |
+
param = param.T
|
558 |
+
|
559 |
+
if is_packed:
|
560 |
+
shard_size //= self.quant_config.pack_factor
|
561 |
+
offset //= self.quant_config.pack_factor
|
562 |
+
|
563 |
+
if weight_name in ["k_proj", "v_proj"]:
|
564 |
+
shard_id = tp_rank // num_kv_heads_replicas
|
565 |
+
else:
|
566 |
+
shard_id = tp_rank
|
567 |
+
loaded_weight = loaded_weight[shard_size *
|
568 |
+
shard_id:shard_size *
|
569 |
+
(shard_id + 1)]
|
570 |
+
param_slice = param.data[offset:offset + shard_size]
|
571 |
+
assert param_slice.shape == loaded_weight.shape
|
572 |
+
|
573 |
+
param_slice.copy_(loaded_weight)
|
574 |
+
loaded += 1.0 / 3
|
575 |
+
is_attention_weight = True
|
576 |
+
break
|
577 |
+
if is_attention_weight:
|
578 |
+
continue
|
579 |
+
|
580 |
+
# TODO: need to figure out to do sharding with qkv_proj fused
|
581 |
+
|
582 |
+
is_gate_up_weight = False
|
583 |
+
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
|
584 |
+
if weight_name not in name or "gate_up_proj" in name:
|
585 |
+
continue
|
586 |
+
param = state_dict[name.replace(weight_name, "gate_up_proj")]
|
587 |
+
if is_transposed:
|
588 |
+
param = param.T
|
589 |
+
|
590 |
+
shard_size = param.shape[0] // 2
|
591 |
+
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
|
592 |
+
(tp_rank + 1)]
|
593 |
+
param_slice = param.data[shard_size * stride_id:shard_size *
|
594 |
+
(stride_id + 1)]
|
595 |
+
assert param_slice.shape == loaded_weight.shape
|
596 |
+
param_slice.copy_(loaded_weight)
|
597 |
+
loaded += 1.0 / 2
|
598 |
+
is_gate_up_weight = True
|
599 |
+
break
|
600 |
+
if is_gate_up_weight:
|
601 |
+
continue
|
602 |
+
|
603 |
+
# TODO: need to figure out to do sharding with gate_up_proj fused
|
604 |
+
|
605 |
+
param = state_dict[name]
|
606 |
+
if is_transposed:
|
607 |
+
param = param.T
|
608 |
+
|
609 |
+
if "embed_tokens" in name or "lm_head" in name:
|
610 |
+
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
611 |
+
tp_rank)
|
612 |
+
loaded += 1
|
613 |
+
continue
|
614 |
+
|
615 |
+
load_tensor_parallel_weights(param, loaded_weight, name,
|
616 |
+
column_parallel_weights,
|
617 |
+
row_parallel_weights, tp_rank)
|
618 |
+
loaded += 1
|
619 |
+
|
620 |
+
if np.abs(loaded - need_to_load) < 0.01:
|
621 |
+
print(f'WARNING: only {loaded} params loaded out of {need_to_load}')
|
622 |
+
else:
|
623 |
+
print(f'Loaded all {loaded} params loaded out of {need_to_load}')
|
624 |
+
|
625 |
+
|
626 |
# Reassign LlamaForCausalLM.load_weights with llama_load_weights
|
627 |
if not DEBUG:
|
628 |
|
|
629 |
try:
|
630 |
import vllm
|
631 |
from vllm.model_executor.model_loader import _MODEL_REGISTRY
|
632 |
from vllm.model_executor.models import LlamaForCausalLM
|
633 |
|
634 |
_MODEL_REGISTRY['FasterLlamaForCausalLM'] = LlamaForCausalLM
|
635 |
+
if vllm.__version__ == "0.1.4":
|
636 |
+
LlamaForCausalLM.load_weights = llama_load_weights
|
637 |
+
else:
|
638 |
+
LlamaForCausalLM.load_weights = new_llama_load_weights
|
639 |
|
640 |
if DTYPE == "bfloat16":
|
641 |
try:
|
|
|
658 |
set_documentation_group("component")
|
659 |
|
660 |
|
|
|
661 |
RES_PRINTED = False
|
662 |
|
663 |
def llama_chat_sys_input_seq_constructor(text, sys_prompt=SYSTEM_PROMPT_1, bos_token=BOS_TOKEN, eos_token=EOS_TOKEN):
|
|
|
774 |
api_name=False,
|
775 |
queue=False,
|
776 |
)
|
777 |
+
# upon clear, cancel the submit event as well
|
778 |
+
if self.clear_btn:
|
779 |
+
self.clear_btn.click(
|
780 |
+
lambda: ([], [], None, Button.update(interactive=True)),
|
781 |
+
None,
|
782 |
+
[self.chatbot, self.chatbot_state, self.saved_input, self.submit_btn],
|
783 |
+
queue=False,
|
784 |
+
api_name=False,
|
785 |
+
cancels=event_to_cancel,
|
786 |
+
)
|
787 |
+
|
788 |
+
# TODO: reconfigure clear button as stop and clear button
|
789 |
+
def _setup_events(self) -> None:
|
790 |
+
has_on = False
|
791 |
+
try:
|
792 |
+
from gradio.events import Dependency, EventListenerMethod, on
|
793 |
+
has_on = True
|
794 |
+
except ImportError as ie:
|
795 |
+
has_on = False
|
796 |
+
submit_fn = self._stream_fn if self.is_generator else self._submit_fn
|
797 |
+
|
798 |
+
|
799 |
+
if has_on:
|
800 |
+
# new version
|
801 |
+
submit_triggers = (
|
802 |
+
[self.textbox.submit, self.submit_btn.click]
|
803 |
+
if self.submit_btn
|
804 |
+
else [self.textbox.submit]
|
805 |
+
)
|
806 |
+
submit_event = (
|
807 |
+
on(
|
808 |
+
submit_triggers,
|
809 |
+
self._clear_and_save_textbox,
|
810 |
+
[self.textbox],
|
811 |
+
[self.textbox, self.saved_input],
|
812 |
+
api_name=False,
|
813 |
+
queue=False,
|
814 |
+
)
|
815 |
+
.then(
|
816 |
+
self._display_input,
|
817 |
+
[self.saved_input, self.chatbot_state],
|
818 |
+
[self.chatbot, self.chatbot_state],
|
819 |
+
api_name=False,
|
820 |
+
queue=False,
|
821 |
+
)
|
822 |
+
.then(
|
823 |
+
submit_fn,
|
824 |
+
[self.saved_input, self.chatbot_state] + self.additional_inputs,
|
825 |
+
[self.chatbot, self.chatbot_state],
|
826 |
+
api_name=False,
|
827 |
+
)
|
828 |
+
)
|
829 |
+
self._setup_stop_events(submit_triggers, submit_event)
|
830 |
+
else:
|
831 |
+
raise ValueError(f'Better install new gradio version than 3.44.0')
|
832 |
+
|
833 |
+
if self.retry_btn:
|
834 |
+
retry_event = (
|
835 |
+
self.retry_btn.click(
|
836 |
+
self._delete_prev_fn,
|
837 |
+
[self.chatbot_state],
|
838 |
+
[self.chatbot, self.saved_input, self.chatbot_state],
|
839 |
+
api_name=False,
|
840 |
+
queue=False,
|
841 |
+
)
|
842 |
+
.then(
|
843 |
+
self._display_input,
|
844 |
+
[self.saved_input, self.chatbot_state],
|
845 |
+
[self.chatbot, self.chatbot_state],
|
846 |
+
api_name=False,
|
847 |
+
queue=False,
|
848 |
+
)
|
849 |
+
.then(
|
850 |
+
submit_fn,
|
851 |
+
[self.saved_input, self.chatbot_state] + self.additional_inputs,
|
852 |
+
[self.chatbot, self.chatbot_state],
|
853 |
+
api_name=False,
|
854 |
+
)
|
855 |
+
)
|
856 |
+
self._setup_stop_events([self.retry_btn.click], retry_event)
|
857 |
+
|
858 |
+
if self.undo_btn:
|
859 |
+
self.undo_btn.click(
|
860 |
+
self._delete_prev_fn,
|
861 |
+
[self.chatbot_state],
|
862 |
+
[self.chatbot, self.saved_input, self.chatbot_state],
|
863 |
+
api_name=False,
|
864 |
+
queue=False,
|
865 |
+
).then(
|
866 |
+
lambda x: x,
|
867 |
+
[self.saved_input],
|
868 |
+
[self.textbox],
|
869 |
+
api_name=False,
|
870 |
+
queue=False,
|
871 |
+
)
|
872 |
|
873 |
+
# Reconfigure clear_btn to stop and clear text box
|
874 |
+
# if self.clear_btn:
|
875 |
+
# self.clear_btn.click(
|
876 |
+
# lambda: ([], [], None),
|
877 |
+
# None,
|
878 |
+
# [self.chatbot, self.chatbot_state, self.saved_input],
|
879 |
+
# queue=False,
|
880 |
+
# api_name=False,
|
881 |
+
# cancels=submit_event,
|
882 |
+
# )
|
883 |
+
|
884 |
+
|
885 |
+
# replace
|
886 |
gr.ChatInterface._setup_stop_events = _setup_stop_events
|
887 |
+
gr.ChatInterface._setup_events = _setup_events
|
888 |
|
889 |
def chat_response(message, history, temperature: float, max_tokens: int, system_prompt: str = '') -> str:
|
890 |
global llm
|
|
|
918 |
continue
|
919 |
scheduler.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
|
920 |
|
|
|
921 |
def _vllm_run_engine(self: Any, use_tqdm: bool = False) -> Dict[str, Any]:
|
922 |
from vllm.outputs import RequestOutput
|
923 |
# Initialize tqdm.
|
|
|
930 |
step_outputs = self.llm_engine.step()
|
931 |
for output in step_outputs:
|
932 |
outputs[output.request_id] = output
|
|
|
933 |
if len(outputs) > 0:
|
934 |
yield outputs
|
935 |
+
|
|
|
936 |
|
937 |
|
938 |
def vllm_generate_stream(
|
|
|
991 |
yield from _vllm_run_engine(self, use_tqdm)
|
992 |
|
993 |
|
|
|
994 |
BLOCK_MESSAGE = """Sorry, Chinese is not currently supported. Please clear the chat box for a new conversation.
|
995 |
抱歉,目前不支持中文。 请清除聊天框以进行新对话。"""
|
996 |
|
997 |
+
KEYWORD_BLOCK_MESSAGE = "Sorry, I cannot fulfill your request. If you have any unrelated questions, I'll be glad to help."
|
998 |
+
|
999 |
def block_zh(
|
1000 |
message: str,
|
1001 |
history: List[Tuple[str, str]]
|
1002 |
) -> str:
|
1003 |
+
if history is not None and any((BLOCK_MESSAGE in x[1].strip()) for x in history):
|
|
|
1004 |
return True
|
1005 |
elif 'zh' in _detect_lang(message):
|
1006 |
print(f'Detect zh: {message}')
|
1007 |
return True
|
|
|
1008 |
else:
|
1009 |
return False
|
1010 |
|
1011 |
+
|
1012 |
+
def log_responses(history, message, response):
|
1013 |
+
pass
|
1014 |
+
|
1015 |
+
|
1016 |
+
def safety_check(text, history=None, ) -> Optional[str]:
|
1017 |
+
"""
|
1018 |
+
Despite our effort in safety tuning and red teaming, our models may still generate harmful or illegal content.
|
1019 |
+
This provides an additional security measure to enhance safety and compliance with local regulations.
|
1020 |
+
"""
|
1021 |
+
if BLOCK_ZH:
|
1022 |
+
if history is not None:
|
1023 |
+
if block_zh(text, history):
|
1024 |
+
return BLOCK_MESSAGE
|
1025 |
+
else:
|
1026 |
+
if "zh" in _detect_lang(text):
|
1027 |
+
return BLOCK_MESSAGE
|
1028 |
+
|
1029 |
+
if len(KEYWORDS) > 0 and any(x in text.lower() for x in KEYWORDS):
|
1030 |
+
return KEYWORD_BLOCK_MESSAGE
|
1031 |
+
|
1032 |
+
return None
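For illustration only (a hedged sketch, not part of the commit): assuming BLOCK_ZH=1 and a hypothetical KEYWORDS list of ["badword"], safety_check would behave roughly as follows.
safety_check("hello world")              # -> None, the message is allowed through
safety_check("this has badword in it")   # -> KEYWORD_BLOCK_MESSAGE (blocklist hit)
safety_check("你好,世界")                 # -> BLOCK_MESSAGE when langdetect reports 'zh'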
|
1033 |
+
|
1034 |
+
|
1035 |
def chat_response_stream_multiturn(
|
1036 |
message: str,
|
1037 |
history: List[Tuple[str, str]],
|
|
|
1061 |
|
1062 |
message = message.strip()
|
1063 |
|
1064 |
+
message_safety = safety_check(message, history=history)
|
1065 |
+
if message_safety is not None:
|
1066 |
+
yield message_safety
|
1067 |
+
return
|
1068 |
|
|
|
1069 |
# history will be appended with message later on
|
1070 |
full_prompt = llama_chat_multiturn_sys_input_seq_constructor(
|
1071 |
message, history, sys_prompt=system_prompt
|
1072 |
)
|
1073 |
+
|
1074 |
sampling_params = SamplingParams(
|
1075 |
temperature=temperature, max_tokens=max_tokens,
|
1076 |
frequency_penalty=frequency_penalty,
|
1077 |
)
|
1078 |
cur_out = None
|
1079 |
+
|
1080 |
for j, gen in enumerate(vllm_generate_stream(llm, full_prompt, sampling_params)):
|
1081 |
if cur_out is not None and (STREAM_YIELD_MULTIPLE < 1 or j % STREAM_YIELD_MULTIPLE == 0) and j > 0:
|
1082 |
+
# optionally check safety, and respond
|
1083 |
+
if STREAM_CHECK_MULTIPLE > 0 and j % STREAM_CHECK_MULTIPLE == 0:
|
1084 |
+
message_safety = safety_check(cur_out, history=None)
|
1085 |
+
if message_safety is not None:
|
1086 |
+
yield message_safety
|
1087 |
+
return
|
1088 |
+
|
1089 |
yield cur_out
|
1090 |
assert len(gen) == 1, f'{gen}'
|
1091 |
item = next(iter(gen.values()))
|
1092 |
cur_out = item.outputs[0].text
|
1093 |
|
1094 |
+
print(f'{full_prompt}<<<{cur_out}>>>\n\n')
|
|
|
|
|
1095 |
if cur_out is not None:
|
1096 |
yield cur_out
|
1097 |
|
1098 |
+
message_safety = safety_check(cur_out, history=None)
|
1099 |
+
if message_safety is not None:
|
1100 |
+
yield message_safety
|
1101 |
+
return
|
1102 |
+
|
1103 |
+
if LOG_RESPONSE:
|
1104 |
+
log_responses(history, message, cur_out)
|
1105 |
+
|
1106 |
|
1107 |
|
1108 |
def debug_chat_response_echo(
|
|
|
1118 |
yield f"repeat: {message}"
|
1119 |
|
1120 |
|
|
|
1121 |
def check_model_path(model_path) -> str:
|
1122 |
assert os.path.exists(model_path), f'{model_path} not found'
|
1123 |
ckpt_info = "None"
|
|
|
1151 |
print(
|
1152 |
f'Launch config: {model_title=} / {tensor_parallel=} / {dtype=} / {max_tokens} | {BLOCK_ZH=} '
|
1153 |
f'\n| STREAM_YIELD_MULTIPLE={STREAM_YIELD_MULTIPLE} '
|
1154 |
+
f'\n| STREAM_CHECK_MULTIPLE={STREAM_CHECK_MULTIPLE} '
|
1155 |
f'\n| frequence_penalty={frequence_penalty} '
|
1156 |
f'\n| temperature={temperature} '
|
1157 |
f'\n| hf_model_name={hf_model_name} '
|
1158 |
f'\n| model_path={model_path} '
|
1159 |
f'\n| DOWNLOAD_SNAPSHOT={DOWNLOAD_SNAPSHOT} '
|
1160 |
+
f'\n| gpu_memory_utilization={gpu_memory_utilization} '
|
1161 |
+
f'\n| KEYWORDS={KEYWORDS} '
|
1162 |
f'\nsys={SYSTEM_PROMPT_1}'
|
1163 |
f'\ndesc={model_desc}'
|
1164 |
)
|
|
|
1179 |
snapshot_download(hf_model_name, local_dir=model_path)
|
1180 |
|
1181 |
import vllm
|
1182 |
+
from vllm import LLM
|
1183 |
|
1184 |
print(F'VLLM: {vllm.__version__}')
|
1185 |
ckpt_info = check_model_path(model_path)
|
1186 |
|
1187 |
print(f'Load path: {model_path} | {ckpt_info}')
|
1188 |
+
|
1189 |
+
if QUANTIZATION == 'awq':
|
1190 |
+
print(F'Load model in int4 quantization')
|
1191 |
+
llm = LLM(model=model_path, dtype=dtype, tensor_parallel_size=tensor_parallel, gpu_memory_utilization=gpu_memory_utilization, quantization="awq")
|
1192 |
+
else:
|
1193 |
+
llm = LLM(model=model_path, dtype=dtype, tensor_parallel_size=tensor_parallel, gpu_memory_utilization=gpu_memory_utilization)
|
1194 |
+
|
1195 |
+
try:
|
1196 |
+
print(llm.llm_engine.workers[0].model)
|
1197 |
+
except Exception as e:
|
1198 |
+
print(f'Cannot print model worker: {e}')
|
1199 |
|
1200 |
print(f'Use system prompt:\n{sys_prompt}')
|
1201 |
|
|
|
1218 |
stop_btn=None,
|
1219 |
title=f"{model_title}",
|
1220 |
description=f"{model_desc}",
|
|
|
1221 |
additional_inputs=[
|
1222 |
gr.Number(value=temperature, label='Temperature (higher -> more random)'),
|
1223 |
gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
|
1224 |
gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens)'),
|
1225 |
+
# ! Remove the system prompt textbox to avoid jailbreaking
|
1226 |
# gr.Textbox(value=sys_prompt, label='System prompt', lines=8)
|
1227 |
],
|
1228 |
)
|
1229 |
+
demo.title = MODEL_NAME
|
1230 |
with demo:
|
1231 |
+
# gr.Markdown(warning_markdown)
|
1232 |
gr.Markdown(cite_markdown)
|
1233 |
gr.Markdown(path_markdown.format(model_path=model_path))
|
1234 |
|
|
|
1243 |
|
1244 |
if __name__ == "__main__":
|
1245 |
main()
|
|