Spaces:
Running
on
Zero
Running
on
Zero
NGUYEN, Xuan Phi
commited on
Commit
•
203c3cd
1
Parent(s):
1e1d244
initial comit
Browse files- .gitignore +167 -0
- app.py +747 -0
- requirements.txt +26 -0
.gitignore
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# pdm
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
+
#pdm.lock
|
107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
+
# in version control.
|
109 |
+
# https://pdm.fming.dev/#use-with-ide
|
110 |
+
.pdm.toml
|
111 |
+
|
112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
113 |
+
__pypackages__/
|
114 |
+
|
115 |
+
# Celery stuff
|
116 |
+
celerybeat-schedule
|
117 |
+
celerybeat.pid
|
118 |
+
|
119 |
+
# SageMath parsed files
|
120 |
+
*.sage.py
|
121 |
+
|
122 |
+
# Environments
|
123 |
+
.env
|
124 |
+
.venv
|
125 |
+
env/
|
126 |
+
venv/
|
127 |
+
ENV/
|
128 |
+
env.bak/
|
129 |
+
venv.bak/
|
130 |
+
|
131 |
+
# Spyder project settings
|
132 |
+
.spyderproject
|
133 |
+
.spyproject
|
134 |
+
|
135 |
+
# Rope project settings
|
136 |
+
.ropeproject
|
137 |
+
|
138 |
+
# mkdocs documentation
|
139 |
+
/site
|
140 |
+
|
141 |
+
# mypy
|
142 |
+
.mypy_cache/
|
143 |
+
.dmypy.json
|
144 |
+
dmypy.json
|
145 |
+
|
146 |
+
# Pyre type checker
|
147 |
+
.pyre/
|
148 |
+
|
149 |
+
# pytype static type analyzer
|
150 |
+
.pytype/
|
151 |
+
|
152 |
+
# Cython debug symbols
|
153 |
+
cython_debug/
|
154 |
+
|
155 |
+
# PyCharm
|
156 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
157 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
158 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
159 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
160 |
+
#.idea/
|
161 |
+
|
162 |
+
# ruff
|
163 |
+
.ruff_cache
|
164 |
+
|
165 |
+
.vscode
|
166 |
+
|
167 |
+
core*
|
app.py
ADDED
@@ -0,0 +1,747 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# A mirror to gradio launch stream
|
2 |
+
# By Xuan Phi Nguyen at DAMO Academy, Alibaba Group
|
3 |
+
|
4 |
+
"""
|
5 |
+
Load FasterLlama with original VLLM codebase,
|
6 |
+
|
7 |
+
require changing config names to LlamaForCausalLM
|
8 |
+
|
9 |
+
tensor_parallel must == 1
|
10 |
+
|
11 |
+
"""
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import os
|
15 |
+
import numpy as np
|
16 |
+
import argparse
|
17 |
+
from vllm import LLM, SamplingParams
|
18 |
+
import gradio as gr
|
19 |
+
from gradio_client.documentation import document, set_documentation_group
|
20 |
+
|
21 |
+
from typing import List, Optional, Union, Dict, Tuple
|
22 |
+
|
23 |
+
from tqdm import tqdm
|
24 |
+
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
25 |
+
|
26 |
+
from vllm.engine.arg_utils import EngineArgs
|
27 |
+
from vllm.engine.llm_engine import LLMEngine
|
28 |
+
from vllm.outputs import RequestOutput
|
29 |
+
from vllm.sampling_params import SamplingParams
|
30 |
+
from vllm.utils import Counter
|
31 |
+
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
|
32 |
+
SequenceGroupMetadata, SequenceOutputs,
|
33 |
+
SequenceStatus)
|
34 |
+
|
35 |
+
# ! reconfigure vllm to faster llama
|
36 |
+
from typing import Any, Iterator
|
37 |
+
from typing import Iterator, List, Optional, Tuple
|
38 |
+
import filelock
|
39 |
+
import glob
|
40 |
+
import json
|
41 |
+
import os
|
42 |
+
from huggingface_hub import snapshot_download
|
43 |
+
|
44 |
+
from tqdm.auto import tqdm
|
45 |
+
|
46 |
+
from vllm.model_executor.model_loader import _MODEL_REGISTRY
|
47 |
+
from vllm.model_executor.models import LlamaForCausalLM
|
48 |
+
|
49 |
+
_MODEL_REGISTRY['FasterLlamaForCausalLM'] = LlamaForCausalLM
|
50 |
+
|
51 |
+
|
52 |
+
def hf_model_weights_iterator(
|
53 |
+
model_name_or_path: str,
|
54 |
+
cache_dir: Optional[str] = None,
|
55 |
+
use_np_cache: bool = False,
|
56 |
+
) -> Iterator[Tuple[str, torch.Tensor]]:
|
57 |
+
from vllm.model_executor.weight_utils import Disabledtqdm
|
58 |
+
# Prepare file lock directory to prevent multiple processes from
|
59 |
+
# downloading the same model weights at the same time.
|
60 |
+
lock_dir = cache_dir if cache_dir is not None else "/tmp"
|
61 |
+
lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
|
62 |
+
lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
|
63 |
+
|
64 |
+
# Download model weights from huggingface.
|
65 |
+
is_local = os.path.isdir(model_name_or_path)
|
66 |
+
if not is_local:
|
67 |
+
with lock:
|
68 |
+
hf_folder = snapshot_download(model_name_or_path,
|
69 |
+
allow_patterns="*.bin",
|
70 |
+
cache_dir=cache_dir,
|
71 |
+
local_files_only=True,
|
72 |
+
tqdm_class=Disabledtqdm)
|
73 |
+
else:
|
74 |
+
hf_folder = model_name_or_path
|
75 |
+
|
76 |
+
hf_bin_files = [
|
77 |
+
# x for x in glob.glob(os.path.join(hf_folder, "*.bin"))
|
78 |
+
x for x in glob.glob(os.path.join(hf_folder, "*model*.bin"))
|
79 |
+
if not x.endswith("training_args.bin")
|
80 |
+
]
|
81 |
+
hf_safetensors_files = [
|
82 |
+
x for x in glob.glob(os.path.join(hf_folder, "*model*.safetensors"))
|
83 |
+
if not x.endswith("training_args.bin")
|
84 |
+
]
|
85 |
+
# print(F'Load bin files: {hf_bin_files} // safetensors: {hf_safetensors_files}')
|
86 |
+
|
87 |
+
if use_np_cache:
|
88 |
+
# Convert the model weights from torch tensors to numpy arrays for
|
89 |
+
# faster loading.
|
90 |
+
np_folder = os.path.join(hf_folder, "np")
|
91 |
+
os.makedirs(np_folder, exist_ok=True)
|
92 |
+
weight_names_file = os.path.join(np_folder, "weight_names.json")
|
93 |
+
with lock:
|
94 |
+
if not os.path.exists(weight_names_file):
|
95 |
+
weight_names = []
|
96 |
+
for bin_file in hf_bin_files:
|
97 |
+
state = torch.load(bin_file, map_location="cpu")
|
98 |
+
for name, param in state.items():
|
99 |
+
param_path = os.path.join(np_folder, name)
|
100 |
+
with open(param_path, "wb") as f:
|
101 |
+
np.save(f, param.cpu().detach().numpy())
|
102 |
+
weight_names.append(name)
|
103 |
+
with open(weight_names_file, "w") as f:
|
104 |
+
json.dump(weight_names, f)
|
105 |
+
|
106 |
+
with open(weight_names_file, "r") as f:
|
107 |
+
weight_names = json.load(f)
|
108 |
+
|
109 |
+
for name in weight_names:
|
110 |
+
param_path = os.path.join(np_folder, name)
|
111 |
+
with open(param_path, "rb") as f:
|
112 |
+
param = np.load(f)
|
113 |
+
yield name, torch.from_numpy(param)
|
114 |
+
else:
|
115 |
+
if len(hf_bin_files) > 0:
|
116 |
+
print(F'Load bin files: {hf_bin_files}')
|
117 |
+
for bin_file in hf_bin_files:
|
118 |
+
state = torch.load(bin_file, map_location="cpu")
|
119 |
+
for name, param in state.items():
|
120 |
+
yield name, param
|
121 |
+
del state
|
122 |
+
torch.cuda.empty_cache()
|
123 |
+
elif len(hf_safetensors_files) > 0:
|
124 |
+
print(F'Load safetensor files: {hf_safetensors_files}')
|
125 |
+
from safetensors.torch import load_file
|
126 |
+
for safe_file in hf_safetensors_files:
|
127 |
+
# state = torch.load(bin_file, map_location="cpu")
|
128 |
+
state = load_file(safe_file)
|
129 |
+
for name, param in state.items():
|
130 |
+
yield name, param
|
131 |
+
del state
|
132 |
+
torch.cuda.empty_cache()
|
133 |
+
else:
|
134 |
+
raise ValueError(f'no files available either bin or safe')
|
135 |
+
|
136 |
+
|
137 |
+
def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
|
138 |
+
"""convert PySafeSlice object from safetensors to torch.Tensor
|
139 |
+
|
140 |
+
PySafeSlice object supports indexing, which is done before loading the
|
141 |
+
actual tensor and can reduce the amount of memory being read into the
|
142 |
+
memory. However, it does not support more advanced functionalities
|
143 |
+
like `.view()` or `.t()`. Therefore, if we need to modify the loaded
|
144 |
+
tensor with these more complicated operators, we need to convert to
|
145 |
+
tensor first.
|
146 |
+
"""
|
147 |
+
if not isinstance(x, torch.Tensor):
|
148 |
+
x = x[:]
|
149 |
+
return x
|
150 |
+
|
151 |
+
|
152 |
+
def load_padded_tensor_parallel_vocab(
|
153 |
+
param: torch.Tensor,
|
154 |
+
loaded_weight: Any, # `torch.Tensor` or `PySafeSlice`
|
155 |
+
tensor_model_parallel_rank: int,
|
156 |
+
) -> None:
|
157 |
+
shard_size = param.shape[0]
|
158 |
+
start_idx = tensor_model_parallel_rank * shard_size
|
159 |
+
end_idx = (tensor_model_parallel_rank + 1) * shard_size
|
160 |
+
loaded_weight = loaded_weight[start_idx:end_idx]
|
161 |
+
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
162 |
+
param[:loaded_weight.shape[0]].copy_(loaded_weight)
|
163 |
+
|
164 |
+
|
165 |
+
def llama_load_weights(
|
166 |
+
self,
|
167 |
+
model_name_or_path: str,
|
168 |
+
cache_dir: Optional[str] = None,
|
169 |
+
use_np_cache: bool = False,
|
170 |
+
load_format: str = "auto",
|
171 |
+
# load_format: str = "pt",
|
172 |
+
revision: Optional[str] = None
|
173 |
+
):
|
174 |
+
from vllm.model_executor.weight_utils import (
|
175 |
+
load_tensor_parallel_weights
|
176 |
+
)
|
177 |
+
from vllm.model_executor.parallel_utils.parallel_state import (
|
178 |
+
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
179 |
+
tp_size = get_tensor_model_parallel_world_size()
|
180 |
+
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
181 |
+
|
182 |
+
q_proj_shard_size = (self.config.hidden_size // tp_size)
|
183 |
+
kv_proj_shard_size = (self.config.hidden_size //
|
184 |
+
self.config.num_attention_heads *
|
185 |
+
getattr(self.config, "num_key_value_heads", self.config.num_attention_heads) // tp_size)
|
186 |
+
attention_weight_specs = [
|
187 |
+
# (weight_name, shard_size, offset)
|
188 |
+
("q_proj", q_proj_shard_size, 0),
|
189 |
+
("k_proj", kv_proj_shard_size, q_proj_shard_size),
|
190 |
+
("v_proj", kv_proj_shard_size,
|
191 |
+
q_proj_shard_size + kv_proj_shard_size),
|
192 |
+
]
|
193 |
+
state_dict = self.state_dict()
|
194 |
+
need_to_load = len(state_dict)
|
195 |
+
loaded = 0
|
196 |
+
# try:
|
197 |
+
# iterator = hf_model_weights_iterator(model_name_or_path, cache_dir, use_np_cache)
|
198 |
+
# except Exception as e:
|
199 |
+
# iterator = hf_model_weights_iterator(model_name_or_path, cache_dir, load_format, revision)
|
200 |
+
iterator = hf_model_weights_iterator(model_name_or_path, cache_dir, use_np_cache)
|
201 |
+
|
202 |
+
# for name, loaded_weight in hf_model_weights_iterator(
|
203 |
+
# model_name_or_path, cache_dir, load_format, revision):
|
204 |
+
# model_name_or_path, cache_dir, use_np_cache):
|
205 |
+
for name, loaded_weight in iterator:
|
206 |
+
if "rotary_emb.inv_freq" in name:
|
207 |
+
continue
|
208 |
+
|
209 |
+
# if "embed_tokens" in name or "lm_head" in name:
|
210 |
+
# param = state_dict[name]
|
211 |
+
# # Consider padding in the vocab size.
|
212 |
+
# padded_vocab_size = (param.shape[0] * tp_size)
|
213 |
+
# # num_extra_rows = padded_vocab_size - self.config.vocab_size
|
214 |
+
# num_extra_rows = padded_vocab_size - loaded_weight.size(0)
|
215 |
+
# load_size = loaded_weight.size()
|
216 |
+
# extra_rows = torch.empty(num_extra_rows,
|
217 |
+
# loaded_weight.shape[1])
|
218 |
+
# extra_rows = extra_rows.to(loaded_weight)
|
219 |
+
# loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
|
220 |
+
# if num_extra_rows > 0:
|
221 |
+
# print(f'Add empty to {num_extra_rows} extra row for {name}')
|
222 |
+
# print(f'Load: {name} | {padded_vocab_size=} | {self.config.vocab_size=} | {num_extra_rows=} | {param.size()=} | {loaded_weight.size()=} | {load_size=}')
|
223 |
+
|
224 |
+
if "embed_tokens" in name or "lm_head" in name:
|
225 |
+
param = state_dict[name]
|
226 |
+
load_padded_tensor_parallel_vocab(param, loaded_weight, tensor_model_parallel_rank)
|
227 |
+
loaded += 1
|
228 |
+
continue
|
229 |
+
|
230 |
+
is_attention_weight = False
|
231 |
+
for weight_name, shard_size, offset in attention_weight_specs:
|
232 |
+
if weight_name not in name or "qkv_proj" in name:
|
233 |
+
continue
|
234 |
+
param = state_dict[name.replace(weight_name, "qkv_proj")]
|
235 |
+
|
236 |
+
loaded_weight = loaded_weight[
|
237 |
+
shard_size * tensor_model_parallel_rank:shard_size *
|
238 |
+
(tensor_model_parallel_rank + 1)]
|
239 |
+
param_slice = param.data[offset:offset + shard_size]
|
240 |
+
assert param_slice.shape == loaded_weight.shape
|
241 |
+
|
242 |
+
param_slice.copy_(loaded_weight)
|
243 |
+
loaded += 1.0 / 3
|
244 |
+
is_attention_weight = True
|
245 |
+
break
|
246 |
+
if is_attention_weight:
|
247 |
+
continue
|
248 |
+
|
249 |
+
# ! qkv_proj is sharded differently if concatenated into qkv
|
250 |
+
# qkv: qqqq kkkk vvvv
|
251 |
+
# lweight: qq0qq1 kk0kk1 vv0vv1
|
252 |
+
# q_shard_size: hidden_size // tp_size = qq
|
253 |
+
# qkv_s0: qq0_kk0_vv0
|
254 |
+
# qkv_s1: qq1_kk1_vv1
|
255 |
+
if "qkv_proj" in name:
|
256 |
+
param = state_dict[name]
|
257 |
+
# loaded_weight
|
258 |
+
qsize = self.config.hidden_size
|
259 |
+
kvsize = self.config.hidden_size // self.config.num_attention_heads * getattr(self.config, "num_key_value_heads", self.config.num_attention_heads)
|
260 |
+
q_offsets = (
|
261 |
+
q_proj_shard_size * tensor_model_parallel_rank,
|
262 |
+
q_proj_shard_size * (tensor_model_parallel_rank + 1)
|
263 |
+
)
|
264 |
+
k_offsets = (
|
265 |
+
qsize + kv_proj_shard_size * tensor_model_parallel_rank,
|
266 |
+
qsize + kv_proj_shard_size * (tensor_model_parallel_rank + 1)
|
267 |
+
)
|
268 |
+
v_offsets = (
|
269 |
+
qsize + kvsize + kv_proj_shard_size * tensor_model_parallel_rank,
|
270 |
+
qsize + kvsize + kv_proj_shard_size * (tensor_model_parallel_rank + 1)
|
271 |
+
)
|
272 |
+
_loaded_weight = torch.cat(
|
273 |
+
[
|
274 |
+
loaded_weight[q_offsets[0]:q_offsets[1]],
|
275 |
+
loaded_weight[k_offsets[0]:k_offsets[1]],
|
276 |
+
loaded_weight[v_offsets[0]:v_offsets[1]],
|
277 |
+
], 0
|
278 |
+
)
|
279 |
+
# print(f'{name} | {q_offsets} | {k_offsets} | {v_offsets}')
|
280 |
+
assert param.shape == _loaded_weight.shape, f'{param.shape=} != {_loaded_weight.shape=}'
|
281 |
+
param.data.copy_(_loaded_weight)
|
282 |
+
loaded += 1.0
|
283 |
+
is_attention_weight = True
|
284 |
+
if is_attention_weight:
|
285 |
+
continue
|
286 |
+
|
287 |
+
|
288 |
+
is_gate_up_weight = False
|
289 |
+
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
|
290 |
+
if weight_name not in name or "gate_up_proj" in name:
|
291 |
+
continue
|
292 |
+
param = state_dict[name.replace(weight_name, "gate_up_proj")]
|
293 |
+
shard_size = param.shape[0] // 2
|
294 |
+
loaded_weight = loaded_weight[
|
295 |
+
shard_size * tensor_model_parallel_rank:shard_size *
|
296 |
+
(tensor_model_parallel_rank + 1)]
|
297 |
+
param_slice = param.data[shard_size * stride_id:shard_size *
|
298 |
+
(stride_id + 1)]
|
299 |
+
assert param_slice.shape == loaded_weight.shape
|
300 |
+
param_slice.copy_(loaded_weight)
|
301 |
+
loaded += 1.0 / 2
|
302 |
+
is_gate_up_weight = True
|
303 |
+
break
|
304 |
+
if is_gate_up_weight:
|
305 |
+
continue
|
306 |
+
|
307 |
+
if "gate_up_proj" in name:
|
308 |
+
param = state_dict[name]
|
309 |
+
shard_size = param.shape[0] // 2
|
310 |
+
intermediate_size = self.config.intermediate_size
|
311 |
+
g_offsets = (
|
312 |
+
shard_size * tensor_model_parallel_rank,
|
313 |
+
shard_size * (tensor_model_parallel_rank + 1)
|
314 |
+
)
|
315 |
+
u_offsets = (
|
316 |
+
intermediate_size + shard_size * tensor_model_parallel_rank,
|
317 |
+
intermediate_size + shard_size * (tensor_model_parallel_rank + 1)
|
318 |
+
)
|
319 |
+
# print(f'{name} {param.size()} | {g_offsets} | {u_offsets}')
|
320 |
+
_loaded_weight = torch.cat(
|
321 |
+
[
|
322 |
+
loaded_weight[g_offsets[0]:g_offsets[1]],
|
323 |
+
loaded_weight[u_offsets[0]:u_offsets[1]],
|
324 |
+
], 0
|
325 |
+
)
|
326 |
+
assert param.shape == _loaded_weight.shape
|
327 |
+
param.data.copy_(_loaded_weight)
|
328 |
+
loaded += 1.0
|
329 |
+
is_gate_up_weight = True
|
330 |
+
if is_gate_up_weight:
|
331 |
+
continue
|
332 |
+
|
333 |
+
|
334 |
+
param = state_dict[name]
|
335 |
+
load_tensor_parallel_weights(param, loaded_weight, name,
|
336 |
+
self._column_parallel_weights,
|
337 |
+
self._row_parallel_weights,
|
338 |
+
tensor_model_parallel_rank)
|
339 |
+
loaded += 1
|
340 |
+
|
341 |
+
if np.abs(loaded - need_to_load) < 0.01:
|
342 |
+
print(f'WARNING: only {loaded} params loaded out of {need_to_load}')
|
343 |
+
else:
|
344 |
+
print(f'Loaded all {loaded} params loaded out of {need_to_load}')
|
345 |
+
|
346 |
+
|
347 |
+
# Reassign LlamaForCausalLM.load_weights with llama_load_weights
|
348 |
+
LlamaForCausalLM.load_weights = llama_load_weights
|
349 |
+
|
350 |
+
# ! ==================================================================
|
351 |
+
|
352 |
+
set_documentation_group("component")
|
353 |
+
|
354 |
+
DATA_ROOT = os.environ.get("dataroot", "/mnt/workspace/workgroup/phi")
|
355 |
+
MODEL_CACHE_DIR = os.path.join(DATA_ROOT, "pret_models")
|
356 |
+
|
357 |
+
|
358 |
+
DTYPES = {
|
359 |
+
'float16': torch.float16,
|
360 |
+
'bfloat16': torch.bfloat16
|
361 |
+
}
|
362 |
+
|
363 |
+
llm = None
|
364 |
+
demo = None
|
365 |
+
|
366 |
+
RELOAD_SIGNAL = '<<<reload:'
|
367 |
+
|
368 |
+
BOS_TOKEN = '<s>'
|
369 |
+
EOS_TOKEN = '</s>'
|
370 |
+
|
371 |
+
B_INST, E_INST = "[INST]", "[/INST]"
|
372 |
+
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
373 |
+
|
374 |
+
SYSTEM_PROMPT_1 = """You are a multilingual, helpful, respectful and honest assistant. Your name is SeaL and you are built by DAMO Academy, Alibaba Group. Always answer as helpfully as possible, while being safe. Your \
|
375 |
+
answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
|
376 |
+
that your responses are socially unbiased and positive in nature.
|
377 |
+
|
378 |
+
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
|
379 |
+
correct. If you don't know the answer to a question, please don't share false information.
|
380 |
+
|
381 |
+
As a multilingual assistant, you must respond and follow instructions in the native language of the user by default, unless told otherwise. \
|
382 |
+
Your response should adapt to the norms and customs of the respective language and culture.
|
383 |
+
"""
|
384 |
+
|
385 |
+
RES_PRINTED = False
|
386 |
+
|
387 |
+
def llama_chat_sys_input_seq_constructor(text, sys_prompt=SYSTEM_PROMPT_1, bos_token=BOS_TOKEN, eos_token=EOS_TOKEN):
|
388 |
+
return f"{bos_token}{B_INST} {B_SYS} {sys_prompt} {E_SYS} {text} {E_INST}"
|
389 |
+
|
390 |
+
|
391 |
+
def llama_chat_multiturn_sys_input_seq_constructor(
|
392 |
+
message: str,
|
393 |
+
history: List[Tuple[str, str]],
|
394 |
+
sys_prompt=SYSTEM_PROMPT_1,
|
395 |
+
bos_token=BOS_TOKEN,
|
396 |
+
eos_token=EOS_TOKEN,
|
397 |
+
):
|
398 |
+
"""
|
399 |
+
```
|
400 |
+
<bos>[INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer <eos>
|
401 |
+
<bos>[INST] Prompt [/INST] Answer <eos>
|
402 |
+
<bos>[INST] Prompt [/INST]
|
403 |
+
```
|
404 |
+
"""
|
405 |
+
text = ''
|
406 |
+
for i, (prompt, res) in enumerate(history):
|
407 |
+
if i == 0:
|
408 |
+
text += f"{bos_token}{B_INST} {B_SYS} {sys_prompt} {E_SYS} {prompt} {E_INST}"
|
409 |
+
else:
|
410 |
+
text += f"{bos_token}{B_INST} {prompt} {E_INST}"
|
411 |
+
|
412 |
+
if res is not None:
|
413 |
+
text += f" {res} {eos_token} "
|
414 |
+
if len(history) == 0 or text.strip() == '':
|
415 |
+
text = f"{bos_token}{B_INST} {B_SYS} {sys_prompt} {E_SYS} {message} {E_INST}"
|
416 |
+
else:
|
417 |
+
text += f"{bos_token}{B_INST} {message} {E_INST}"
|
418 |
+
return text
|
419 |
+
|
420 |
+
|
421 |
+
@document()
|
422 |
+
class ChatBot(gr.Chatbot):
|
423 |
+
def _postprocess_chat_messages(
|
424 |
+
self, chat_message
|
425 |
+
):
|
426 |
+
x = super()._postprocess_chat_messages(chat_message)
|
427 |
+
if isinstance(x, str):
|
428 |
+
x = x.replace("\n", "<br>")
|
429 |
+
return x
|
430 |
+
|
431 |
+
|
432 |
+
def load_ckpt(ckpt_file: str) -> str:
|
433 |
+
global llm
|
434 |
+
status = "Failed"
|
435 |
+
if not os.path.exists(ckpt_file):
|
436 |
+
status = f"Failed - file not found: {ckpt_file}"
|
437 |
+
elif not ckpt_file.endswith(".bin"):
|
438 |
+
status = f"Failed - file not .bin: {ckpt_file}"
|
439 |
+
else:
|
440 |
+
try:
|
441 |
+
state_dict = torch.load(ckpt_file, map_location='cpu')
|
442 |
+
print(f'loaded state_dict: {ckpt_file}')
|
443 |
+
llm.llm_engine.workers[0].model.load_state_dict(state_dict)
|
444 |
+
status = f'Success. Loaded {ckpt_file}'
|
445 |
+
except Exception as e:
|
446 |
+
status = f'Failed - {str(e)}'
|
447 |
+
return status
|
448 |
+
|
449 |
+
|
450 |
+
|
451 |
+
def chat_response(message, history, temperature: float, max_tokens: int, system_prompt: str = '') -> str:
|
452 |
+
global llm
|
453 |
+
assert llm is not None
|
454 |
+
temperature = float(temperature)
|
455 |
+
max_tokens = int(max_tokens)
|
456 |
+
if system_prompt.strip() != '':
|
457 |
+
# chat version, add system prompt
|
458 |
+
message = llama_chat_sys_input_seq_constructor(
|
459 |
+
message.strip(),
|
460 |
+
sys_prompt=system_prompt
|
461 |
+
)
|
462 |
+
|
463 |
+
sampling_params = SamplingParams(temperature=temperature, max_tokens=max_tokens)
|
464 |
+
gen = llm.generate(message, sampling_params)
|
465 |
+
out = gen[0].outputs[0].text
|
466 |
+
# print(f'{message}<<<{out}>>>')
|
467 |
+
return f'{out}'
|
468 |
+
|
469 |
+
|
470 |
+
def vllm_abort(self: LLM):
|
471 |
+
scheduler = self.llm_engine.scheduler
|
472 |
+
for state_queue in [scheduler.waiting, scheduler.running, scheduler.swapped]:
|
473 |
+
for seq_group in state_queue:
|
474 |
+
# if seq_group.request_id == request_id:
|
475 |
+
# Remove the sequence group from the state queue.
|
476 |
+
state_queue.remove(seq_group)
|
477 |
+
for seq in seq_group.seqs:
|
478 |
+
if seq.is_finished():
|
479 |
+
continue
|
480 |
+
scheduler.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
|
481 |
+
|
482 |
+
def _vllm_run_engine(self: LLM, use_tqdm: bool = False) -> Dict[str, RequestOutput]:
|
483 |
+
# Initialize tqdm.
|
484 |
+
if use_tqdm:
|
485 |
+
num_requests = self.llm_engine.get_num_unfinished_requests()
|
486 |
+
pbar = tqdm(total=num_requests, desc="Processed prompts")
|
487 |
+
# Run the engine.
|
488 |
+
outputs: Dict[str, RequestOutput] = {}
|
489 |
+
while self.llm_engine.has_unfinished_requests():
|
490 |
+
step_outputs = self.llm_engine.step()
|
491 |
+
for output in step_outputs:
|
492 |
+
# if output.finished:
|
493 |
+
# outputs.append(output)
|
494 |
+
# if use_tqdm:
|
495 |
+
# pbar.update(1)
|
496 |
+
outputs[output.request_id] = output
|
497 |
+
# outputs = sorted(outputs, key=lambda x: int(x.request_id))
|
498 |
+
if len(outputs) > 0:
|
499 |
+
yield outputs
|
500 |
+
# if use_tqdm:
|
501 |
+
# pbar.close()
|
502 |
+
# Sort the outputs by request ID.
|
503 |
+
# This is necessary because some requests may be finished earlier than
|
504 |
+
# its previous requests.
|
505 |
+
# outputs = sorted(outputs, key=lambda x: int(x.request_id))
|
506 |
+
# return outputs
|
507 |
+
|
508 |
+
|
509 |
+
def vllm_generate_stream(
|
510 |
+
self: LLM,
|
511 |
+
prompts: Optional[Union[str, List[str]]] = None,
|
512 |
+
sampling_params: Optional[SamplingParams] = None,
|
513 |
+
prompt_token_ids: Optional[List[List[int]]] = None,
|
514 |
+
use_tqdm: bool = False,
|
515 |
+
) -> Dict[str, RequestOutput]:
|
516 |
+
"""Generates the completions for the input prompts.
|
517 |
+
|
518 |
+
NOTE: This class automatically batches the given prompts, considering
|
519 |
+
the memory constraint. For the best performance, put all of your prompts
|
520 |
+
into a single list and pass it to this method.
|
521 |
+
|
522 |
+
Args:
|
523 |
+
prompts: A list of prompts to generate completions for.
|
524 |
+
sampling_params: The sampling parameters for text generation. If
|
525 |
+
None, we use the default sampling parameters.
|
526 |
+
prompt_token_ids: A list of token IDs for the prompts. If None, we
|
527 |
+
use the tokenizer to convert the prompts to token IDs.
|
528 |
+
use_tqdm: Whether to use tqdm to display the progress bar.
|
529 |
+
|
530 |
+
Returns:
|
531 |
+
A list of `RequestOutput` objects containing the generated
|
532 |
+
completions in the same order as the input prompts.
|
533 |
+
"""
|
534 |
+
if prompts is None and prompt_token_ids is None:
|
535 |
+
raise ValueError("Either prompts or prompt_token_ids must be "
|
536 |
+
"provided.")
|
537 |
+
if isinstance(prompts, str):
|
538 |
+
# Convert a single prompt to a list.
|
539 |
+
prompts = [prompts]
|
540 |
+
if prompts is not None and prompt_token_ids is not None:
|
541 |
+
if len(prompts) != len(prompt_token_ids):
|
542 |
+
raise ValueError("The lengths of prompts and prompt_token_ids "
|
543 |
+
"must be the same.")
|
544 |
+
if sampling_params is None:
|
545 |
+
# Use default sampling params.
|
546 |
+
sampling_params = SamplingParams()
|
547 |
+
|
548 |
+
# Add requests to the engine.
|
549 |
+
if prompts is not None:
|
550 |
+
num_requests = len(prompts)
|
551 |
+
else:
|
552 |
+
num_requests = len(prompt_token_ids)
|
553 |
+
for i in range(num_requests):
|
554 |
+
prompt = prompts[i] if prompts is not None else None
|
555 |
+
if prompt_token_ids is None:
|
556 |
+
token_ids = None
|
557 |
+
else:
|
558 |
+
token_ids = prompt_token_ids[i]
|
559 |
+
self._add_request(prompt, sampling_params, token_ids)
|
560 |
+
# return self._run_engine(use_tqdm)
|
561 |
+
yield from _vllm_run_engine(self, use_tqdm)
|
562 |
+
|
563 |
+
|
564 |
+
def chat_response_stream(
|
565 |
+
message: str,
|
566 |
+
history: List[Tuple[str, str]],
|
567 |
+
temperature: float,
|
568 |
+
max_tokens: int,
|
569 |
+
frequency_penalty: float,
|
570 |
+
system_prompt: str
|
571 |
+
) -> str:
|
572 |
+
global llm, RES_PRINTED
|
573 |
+
assert llm is not None
|
574 |
+
# force removing all
|
575 |
+
vllm_abort(llm)
|
576 |
+
|
577 |
+
temperature = float(temperature)
|
578 |
+
frequency_penalty = float(frequency_penalty)
|
579 |
+
max_tokens = int(max_tokens)
|
580 |
+
if system_prompt.strip() != '':
|
581 |
+
# chat version, add system prompt
|
582 |
+
message = llama_chat_sys_input_seq_constructor(
|
583 |
+
message.strip(),
|
584 |
+
sys_prompt=system_prompt
|
585 |
+
)
|
586 |
+
sampling_params = SamplingParams(
|
587 |
+
temperature=temperature, max_tokens=max_tokens,
|
588 |
+
frequency_penalty=frequency_penalty,
|
589 |
+
)
|
590 |
+
cur_out = None
|
591 |
+
for gen in vllm_generate_stream(llm, message, sampling_params):
|
592 |
+
if cur_out is not None:
|
593 |
+
yield cur_out
|
594 |
+
assert len(gen) == 1, f'{gen}'
|
595 |
+
item = next(iter(gen.values()))
|
596 |
+
cur_out = item.outputs[0].text
|
597 |
+
if not RES_PRINTED:
|
598 |
+
print(f'{message}<<<{cur_out}>>>')
|
599 |
+
RES_PRINTED = True
|
600 |
+
if cur_out is not None:
|
601 |
+
yield cur_out
|
602 |
+
|
603 |
+
|
604 |
+
def chat_response_stream_multiturn(
|
605 |
+
message: str,
|
606 |
+
history: List[Tuple[str, str]],
|
607 |
+
temperature: float,
|
608 |
+
max_tokens: int,
|
609 |
+
frequency_penalty: float,
|
610 |
+
system_prompt: str
|
611 |
+
) -> str:
|
612 |
+
"""Build multi turn
|
613 |
+
<bos>[INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer <eos>
|
614 |
+
<bos>[INST] Prompt [/INST] Answer <eos>
|
615 |
+
<bos>[INST] Prompt [/INST]
|
616 |
+
|
617 |
+
message is incoming prompt
|
618 |
+
history don't have the current messauge
|
619 |
+
"""
|
620 |
+
global llm, RES_PRINTED
|
621 |
+
assert llm is not None
|
622 |
+
assert system_prompt.strip() != '', f'system prompt is empty'
|
623 |
+
# force removing all
|
624 |
+
vllm_abort(llm)
|
625 |
+
|
626 |
+
temperature = float(temperature)
|
627 |
+
frequency_penalty = float(frequency_penalty)
|
628 |
+
max_tokens = int(max_tokens)
|
629 |
+
|
630 |
+
# history.append([message, None])
|
631 |
+
# history will be appended with message later on
|
632 |
+
full_prompt = llama_chat_multiturn_sys_input_seq_constructor(
|
633 |
+
message, history, sys_prompt=system_prompt
|
634 |
+
)
|
635 |
+
sampling_params = SamplingParams(
|
636 |
+
temperature=temperature, max_tokens=max_tokens,
|
637 |
+
frequency_penalty=frequency_penalty,
|
638 |
+
)
|
639 |
+
cur_out = None
|
640 |
+
for gen in vllm_generate_stream(llm, full_prompt, sampling_params):
|
641 |
+
if cur_out is not None:
|
642 |
+
yield cur_out
|
643 |
+
assert len(gen) == 1, f'{gen}'
|
644 |
+
item = next(iter(gen.values()))
|
645 |
+
cur_out = item.outputs[0].text
|
646 |
+
if not RES_PRINTED:
|
647 |
+
print(f'{full_prompt}<<<{cur_out}>>>')
|
648 |
+
RES_PRINTED = True
|
649 |
+
if cur_out is not None:
|
650 |
+
yield cur_out
|
651 |
+
|
652 |
+
|
653 |
+
def debug_chat_response_echo(
|
654 |
+
message: str,
|
655 |
+
history: List[Tuple[str, str]],
|
656 |
+
temperature: float = 0.0,
|
657 |
+
max_tokens: int = 4096,
|
658 |
+
frequency_penalty: float = 0.4,
|
659 |
+
system_prompt: str = SYSTEM_PROMPT_1,
|
660 |
+
) -> str:
|
661 |
+
yield message
|
662 |
+
|
663 |
+
|
664 |
+
MODEL_TITLE = "DAMO-SeaL-13B - An Assistant for South East Asian Languages"
|
665 |
+
MODEL_DESC = """
|
666 |
+
This is a 13B DAMO-SeaL-Chat assistant model built by DAMO Academy, Alibaba Group. It can produce helpful responses in English, Vietnamese, Indonesian and Thai.
|
667 |
+
""".strip()
|
668 |
+
|
669 |
+
TENSOR_PARALLEL = int(os.environ.get("TENSOR_PARALLEL", "1"))
|
670 |
+
DTYPE = 'bfloat16'
|
671 |
+
DTYPE = 'float16'
|
672 |
+
|
673 |
+
MODEL_PATH = os.environ.get("MODEL_PATH", "notfound, please set `export MODEL_PATH=`")
|
674 |
+
|
675 |
+
DEBUG = 1
|
676 |
+
|
677 |
+
def launch():
|
678 |
+
global demo, llm, DEBUG
|
679 |
+
if DEBUG:
|
680 |
+
model_desc + "<br>This is in debug mode, responses will be copy original"
|
681 |
+
response_fn = debug_chat_response_echo
|
682 |
+
else:
|
683 |
+
model_desc = MODEL_DESC
|
684 |
+
model_path = MODEL_PATH
|
685 |
+
assert os.path.exists(model_path), f'{model_path} not found'
|
686 |
+
model_title = MODEL_TITLE
|
687 |
+
tensor_parallel = TENSOR_PARALLEL
|
688 |
+
assert tensor_parallel > 0 , f'{tensor_parallel} invalid'
|
689 |
+
dtype = DTYPE
|
690 |
+
|
691 |
+
# ! load the model
|
692 |
+
llm = LLM(model=model_path, dtype=dtype, tensor_parallel_size=tensor_parallel)
|
693 |
+
|
694 |
+
sys_prompt = SYSTEM_PROMPT_1
|
695 |
+
max_tokens = 4096
|
696 |
+
print(f'Use system prompt:\n{sys_prompt}')
|
697 |
+
|
698 |
+
# response_fn = chat_response_stream_multiturn if args.multiturn else chat_response_stream
|
699 |
+
response_fn = chat_response_stream_multiturn
|
700 |
+
print(F'respond: {response_fn}')
|
701 |
+
|
702 |
+
demo = gr.ChatInterface(
|
703 |
+
response_fn,
|
704 |
+
chatbot=ChatBot(
|
705 |
+
bubble_full_width=False,
|
706 |
+
latex_delimiters=[
|
707 |
+
{ "left": "$", "right": "$", "display": False},
|
708 |
+
{ "left": "$$", "right": "$$", "display": True},
|
709 |
+
]
|
710 |
+
),
|
711 |
+
textbox=gr.Textbox(placeholder='Type message', lines=8, max_lines=128, min_width=200),
|
712 |
+
submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
|
713 |
+
# stop_btn=None,
|
714 |
+
title=f"{model_title}",
|
715 |
+
description=f"{model_desc}",
|
716 |
+
# ! decide if can change the system prompt.
|
717 |
+
additional_inputs=[
|
718 |
+
gr.Number(value=0, label='Temperature (higher -> more random)'),
|
719 |
+
gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
|
720 |
+
gr.Number(value=0.4, label='Frequency penalty (> 0 encourage new tokens)'),
|
721 |
+
gr.Textbox(value=sys_prompt, label='System prompt', lines=8)],
|
722 |
+
)
|
723 |
+
demo.queue()
|
724 |
+
# demo.launch(server_port=args.port)
|
725 |
+
demo.launch()
|
726 |
+
|
727 |
+
|
728 |
+
def main():
|
729 |
+
|
730 |
+
# launch(parser.parse_args())
|
731 |
+
launch()
|
732 |
+
|
733 |
+
|
734 |
+
if __name__ == "__main__":
|
735 |
+
main()
|
736 |
+
|
737 |
+
|
738 |
+
"""
|
739 |
+
|
740 |
+
export CUDA_VISIBLE_DEVICES=0
|
741 |
+
export MODEL_PATH=${dataroot}/hf_train/pretrain_lm/swpn/merlion13s108Hi8kPretFlCW8k.LMFromHf.a.gc.t5k0.vizhthid.mean_std.TrainTask.NLNL.Multi.Vi.FSePlCq13M.FSePlCq13M.m4k.b8.lr1e5.linear.wa0k.ms858k.grac1.se1.8g.v4c.zfsdp/step_4000
|
742 |
+
export MODEL_PATH=${dataroot}/llama-2-7b-lxxp-faster
|
743 |
+
export MODEL_PATH=${dataroot}/llama-2-7b-chat-xp
|
744 |
+
python app.py
|
745 |
+
|
746 |
+
|
747 |
+
"""
|
requirements.txt
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
deepspeed
|
2 |
+
sentencepiece
|
3 |
+
accelerate
|
4 |
+
evaluate
|
5 |
+
datasets
|
6 |
+
sacrebleu
|
7 |
+
websockets
|
8 |
+
fire
|
9 |
+
indic-nlp-library
|
10 |
+
omegaconf
|
11 |
+
scikit-learn
|
12 |
+
jiwer
|
13 |
+
tenacity
|
14 |
+
pynvml
|
15 |
+
rouge_score
|
16 |
+
ninja
|
17 |
+
ray
|
18 |
+
psutil
|
19 |
+
xformers >= 0.0.19
|
20 |
+
fastapi
|
21 |
+
tensorboard
|
22 |
+
geomloss
|
23 |
+
einops
|
24 |
+
gdown
|
25 |
+
vllm==0.1.4
|
26 |
+
transformers
|