NGUYEN, Xuan Phi commited on
Commit
203c3cd
·
1 Parent(s): 1e1d244

initial comit

Browse files
Files changed (3) hide show
  1. .gitignore +167 -0
  2. app.py +747 -0
  3. requirements.txt +26 -0
.gitignore ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
161
+
162
+ # ruff
163
+ .ruff_cache
164
+
165
+ .vscode
166
+
167
+ core*
app.py ADDED
@@ -0,0 +1,747 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # A mirror to gradio launch stream
2
+ # By Xuan Phi Nguyen at DAMO Academy, Alibaba Group
3
+
4
+ """
5
+ Load FasterLlama with original VLLM codebase,
6
+
7
+ require changing config names to LlamaForCausalLM
8
+
9
+ tensor_parallel must == 1
10
+
11
+ """
12
+
13
+ import torch
14
+ import os
15
+ import numpy as np
16
+ import argparse
17
+ from vllm import LLM, SamplingParams
18
+ import gradio as gr
19
+ from gradio_client.documentation import document, set_documentation_group
20
+
21
+ from typing import List, Optional, Union, Dict, Tuple
22
+
23
+ from tqdm import tqdm
24
+ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
25
+
26
+ from vllm.engine.arg_utils import EngineArgs
27
+ from vllm.engine.llm_engine import LLMEngine
28
+ from vllm.outputs import RequestOutput
29
+ from vllm.sampling_params import SamplingParams
30
+ from vllm.utils import Counter
31
+ from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
32
+ SequenceGroupMetadata, SequenceOutputs,
33
+ SequenceStatus)
34
+
35
+ # ! reconfigure vllm to faster llama
36
+ from typing import Any, Iterator
37
+ from typing import Iterator, List, Optional, Tuple
38
+ import filelock
39
+ import glob
40
+ import json
41
+ import os
42
+ from huggingface_hub import snapshot_download
43
+
44
+ from tqdm.auto import tqdm
45
+
46
+ from vllm.model_executor.model_loader import _MODEL_REGISTRY
47
+ from vllm.model_executor.models import LlamaForCausalLM
48
+
49
+ _MODEL_REGISTRY['FasterLlamaForCausalLM'] = LlamaForCausalLM
50
+
51
+
52
+ def hf_model_weights_iterator(
53
+ model_name_or_path: str,
54
+ cache_dir: Optional[str] = None,
55
+ use_np_cache: bool = False,
56
+ ) -> Iterator[Tuple[str, torch.Tensor]]:
57
+ from vllm.model_executor.weight_utils import Disabledtqdm
58
+ # Prepare file lock directory to prevent multiple processes from
59
+ # downloading the same model weights at the same time.
60
+ lock_dir = cache_dir if cache_dir is not None else "/tmp"
61
+ lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
62
+ lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
63
+
64
+ # Download model weights from huggingface.
65
+ is_local = os.path.isdir(model_name_or_path)
66
+ if not is_local:
67
+ with lock:
68
+ hf_folder = snapshot_download(model_name_or_path,
69
+ allow_patterns="*.bin",
70
+ cache_dir=cache_dir,
71
+ local_files_only=True,
72
+ tqdm_class=Disabledtqdm)
73
+ else:
74
+ hf_folder = model_name_or_path
75
+
76
+ hf_bin_files = [
77
+ # x for x in glob.glob(os.path.join(hf_folder, "*.bin"))
78
+ x for x in glob.glob(os.path.join(hf_folder, "*model*.bin"))
79
+ if not x.endswith("training_args.bin")
80
+ ]
81
+ hf_safetensors_files = [
82
+ x for x in glob.glob(os.path.join(hf_folder, "*model*.safetensors"))
83
+ if not x.endswith("training_args.bin")
84
+ ]
85
+ # print(F'Load bin files: {hf_bin_files} // safetensors: {hf_safetensors_files}')
86
+
87
+ if use_np_cache:
88
+ # Convert the model weights from torch tensors to numpy arrays for
89
+ # faster loading.
90
+ np_folder = os.path.join(hf_folder, "np")
91
+ os.makedirs(np_folder, exist_ok=True)
92
+ weight_names_file = os.path.join(np_folder, "weight_names.json")
93
+ with lock:
94
+ if not os.path.exists(weight_names_file):
95
+ weight_names = []
96
+ for bin_file in hf_bin_files:
97
+ state = torch.load(bin_file, map_location="cpu")
98
+ for name, param in state.items():
99
+ param_path = os.path.join(np_folder, name)
100
+ with open(param_path, "wb") as f:
101
+ np.save(f, param.cpu().detach().numpy())
102
+ weight_names.append(name)
103
+ with open(weight_names_file, "w") as f:
104
+ json.dump(weight_names, f)
105
+
106
+ with open(weight_names_file, "r") as f:
107
+ weight_names = json.load(f)
108
+
109
+ for name in weight_names:
110
+ param_path = os.path.join(np_folder, name)
111
+ with open(param_path, "rb") as f:
112
+ param = np.load(f)
113
+ yield name, torch.from_numpy(param)
114
+ else:
115
+ if len(hf_bin_files) > 0:
116
+ print(F'Load bin files: {hf_bin_files}')
117
+ for bin_file in hf_bin_files:
118
+ state = torch.load(bin_file, map_location="cpu")
119
+ for name, param in state.items():
120
+ yield name, param
121
+ del state
122
+ torch.cuda.empty_cache()
123
+ elif len(hf_safetensors_files) > 0:
124
+ print(F'Load safetensor files: {hf_safetensors_files}')
125
+ from safetensors.torch import load_file
126
+ for safe_file in hf_safetensors_files:
127
+ # state = torch.load(bin_file, map_location="cpu")
128
+ state = load_file(safe_file)
129
+ for name, param in state.items():
130
+ yield name, param
131
+ del state
132
+ torch.cuda.empty_cache()
133
+ else:
134
+ raise ValueError(f'no files available either bin or safe')
135
+
136
+
137
+ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
138
+ """convert PySafeSlice object from safetensors to torch.Tensor
139
+
140
+ PySafeSlice object supports indexing, which is done before loading the
141
+ actual tensor and can reduce the amount of memory being read into the
142
+ memory. However, it does not support more advanced functionalities
143
+ like `.view()` or `.t()`. Therefore, if we need to modify the loaded
144
+ tensor with these more complicated operators, we need to convert to
145
+ tensor first.
146
+ """
147
+ if not isinstance(x, torch.Tensor):
148
+ x = x[:]
149
+ return x
150
+
151
+
152
+ def load_padded_tensor_parallel_vocab(
153
+ param: torch.Tensor,
154
+ loaded_weight: Any, # `torch.Tensor` or `PySafeSlice`
155
+ tensor_model_parallel_rank: int,
156
+ ) -> None:
157
+ shard_size = param.shape[0]
158
+ start_idx = tensor_model_parallel_rank * shard_size
159
+ end_idx = (tensor_model_parallel_rank + 1) * shard_size
160
+ loaded_weight = loaded_weight[start_idx:end_idx]
161
+ loaded_weight = convert_pyslice_to_tensor(loaded_weight)
162
+ param[:loaded_weight.shape[0]].copy_(loaded_weight)
163
+
164
+
165
+ def llama_load_weights(
166
+ self,
167
+ model_name_or_path: str,
168
+ cache_dir: Optional[str] = None,
169
+ use_np_cache: bool = False,
170
+ load_format: str = "auto",
171
+ # load_format: str = "pt",
172
+ revision: Optional[str] = None
173
+ ):
174
+ from vllm.model_executor.weight_utils import (
175
+ load_tensor_parallel_weights
176
+ )
177
+ from vllm.model_executor.parallel_utils.parallel_state import (
178
+ get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
179
+ tp_size = get_tensor_model_parallel_world_size()
180
+ tensor_model_parallel_rank = get_tensor_model_parallel_rank()
181
+
182
+ q_proj_shard_size = (self.config.hidden_size // tp_size)
183
+ kv_proj_shard_size = (self.config.hidden_size //
184
+ self.config.num_attention_heads *
185
+ getattr(self.config, "num_key_value_heads", self.config.num_attention_heads) // tp_size)
186
+ attention_weight_specs = [
187
+ # (weight_name, shard_size, offset)
188
+ ("q_proj", q_proj_shard_size, 0),
189
+ ("k_proj", kv_proj_shard_size, q_proj_shard_size),
190
+ ("v_proj", kv_proj_shard_size,
191
+ q_proj_shard_size + kv_proj_shard_size),
192
+ ]
193
+ state_dict = self.state_dict()
194
+ need_to_load = len(state_dict)
195
+ loaded = 0
196
+ # try:
197
+ # iterator = hf_model_weights_iterator(model_name_or_path, cache_dir, use_np_cache)
198
+ # except Exception as e:
199
+ # iterator = hf_model_weights_iterator(model_name_or_path, cache_dir, load_format, revision)
200
+ iterator = hf_model_weights_iterator(model_name_or_path, cache_dir, use_np_cache)
201
+
202
+ # for name, loaded_weight in hf_model_weights_iterator(
203
+ # model_name_or_path, cache_dir, load_format, revision):
204
+ # model_name_or_path, cache_dir, use_np_cache):
205
+ for name, loaded_weight in iterator:
206
+ if "rotary_emb.inv_freq" in name:
207
+ continue
208
+
209
+ # if "embed_tokens" in name or "lm_head" in name:
210
+ # param = state_dict[name]
211
+ # # Consider padding in the vocab size.
212
+ # padded_vocab_size = (param.shape[0] * tp_size)
213
+ # # num_extra_rows = padded_vocab_size - self.config.vocab_size
214
+ # num_extra_rows = padded_vocab_size - loaded_weight.size(0)
215
+ # load_size = loaded_weight.size()
216
+ # extra_rows = torch.empty(num_extra_rows,
217
+ # loaded_weight.shape[1])
218
+ # extra_rows = extra_rows.to(loaded_weight)
219
+ # loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
220
+ # if num_extra_rows > 0:
221
+ # print(f'Add empty to {num_extra_rows} extra row for {name}')
222
+ # print(f'Load: {name} | {padded_vocab_size=} | {self.config.vocab_size=} | {num_extra_rows=} | {param.size()=} | {loaded_weight.size()=} | {load_size=}')
223
+
224
+ if "embed_tokens" in name or "lm_head" in name:
225
+ param = state_dict[name]
226
+ load_padded_tensor_parallel_vocab(param, loaded_weight, tensor_model_parallel_rank)
227
+ loaded += 1
228
+ continue
229
+
230
+ is_attention_weight = False
231
+ for weight_name, shard_size, offset in attention_weight_specs:
232
+ if weight_name not in name or "qkv_proj" in name:
233
+ continue
234
+ param = state_dict[name.replace(weight_name, "qkv_proj")]
235
+
236
+ loaded_weight = loaded_weight[
237
+ shard_size * tensor_model_parallel_rank:shard_size *
238
+ (tensor_model_parallel_rank + 1)]
239
+ param_slice = param.data[offset:offset + shard_size]
240
+ assert param_slice.shape == loaded_weight.shape
241
+
242
+ param_slice.copy_(loaded_weight)
243
+ loaded += 1.0 / 3
244
+ is_attention_weight = True
245
+ break
246
+ if is_attention_weight:
247
+ continue
248
+
249
+ # ! qkv_proj is sharded differently if concatenated into qkv
250
+ # qkv: qqqq kkkk vvvv
251
+ # lweight: qq0qq1 kk0kk1 vv0vv1
252
+ # q_shard_size: hidden_size // tp_size = qq
253
+ # qkv_s0: qq0_kk0_vv0
254
+ # qkv_s1: qq1_kk1_vv1
255
+ if "qkv_proj" in name:
256
+ param = state_dict[name]
257
+ # loaded_weight
258
+ qsize = self.config.hidden_size
259
+ kvsize = self.config.hidden_size // self.config.num_attention_heads * getattr(self.config, "num_key_value_heads", self.config.num_attention_heads)
260
+ q_offsets = (
261
+ q_proj_shard_size * tensor_model_parallel_rank,
262
+ q_proj_shard_size * (tensor_model_parallel_rank + 1)
263
+ )
264
+ k_offsets = (
265
+ qsize + kv_proj_shard_size * tensor_model_parallel_rank,
266
+ qsize + kv_proj_shard_size * (tensor_model_parallel_rank + 1)
267
+ )
268
+ v_offsets = (
269
+ qsize + kvsize + kv_proj_shard_size * tensor_model_parallel_rank,
270
+ qsize + kvsize + kv_proj_shard_size * (tensor_model_parallel_rank + 1)
271
+ )
272
+ _loaded_weight = torch.cat(
273
+ [
274
+ loaded_weight[q_offsets[0]:q_offsets[1]],
275
+ loaded_weight[k_offsets[0]:k_offsets[1]],
276
+ loaded_weight[v_offsets[0]:v_offsets[1]],
277
+ ], 0
278
+ )
279
+ # print(f'{name} | {q_offsets} | {k_offsets} | {v_offsets}')
280
+ assert param.shape == _loaded_weight.shape, f'{param.shape=} != {_loaded_weight.shape=}'
281
+ param.data.copy_(_loaded_weight)
282
+ loaded += 1.0
283
+ is_attention_weight = True
284
+ if is_attention_weight:
285
+ continue
286
+
287
+
288
+ is_gate_up_weight = False
289
+ for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
290
+ if weight_name not in name or "gate_up_proj" in name:
291
+ continue
292
+ param = state_dict[name.replace(weight_name, "gate_up_proj")]
293
+ shard_size = param.shape[0] // 2
294
+ loaded_weight = loaded_weight[
295
+ shard_size * tensor_model_parallel_rank:shard_size *
296
+ (tensor_model_parallel_rank + 1)]
297
+ param_slice = param.data[shard_size * stride_id:shard_size *
298
+ (stride_id + 1)]
299
+ assert param_slice.shape == loaded_weight.shape
300
+ param_slice.copy_(loaded_weight)
301
+ loaded += 1.0 / 2
302
+ is_gate_up_weight = True
303
+ break
304
+ if is_gate_up_weight:
305
+ continue
306
+
307
+ if "gate_up_proj" in name:
308
+ param = state_dict[name]
309
+ shard_size = param.shape[0] // 2
310
+ intermediate_size = self.config.intermediate_size
311
+ g_offsets = (
312
+ shard_size * tensor_model_parallel_rank,
313
+ shard_size * (tensor_model_parallel_rank + 1)
314
+ )
315
+ u_offsets = (
316
+ intermediate_size + shard_size * tensor_model_parallel_rank,
317
+ intermediate_size + shard_size * (tensor_model_parallel_rank + 1)
318
+ )
319
+ # print(f'{name} {param.size()} | {g_offsets} | {u_offsets}')
320
+ _loaded_weight = torch.cat(
321
+ [
322
+ loaded_weight[g_offsets[0]:g_offsets[1]],
323
+ loaded_weight[u_offsets[0]:u_offsets[1]],
324
+ ], 0
325
+ )
326
+ assert param.shape == _loaded_weight.shape
327
+ param.data.copy_(_loaded_weight)
328
+ loaded += 1.0
329
+ is_gate_up_weight = True
330
+ if is_gate_up_weight:
331
+ continue
332
+
333
+
334
+ param = state_dict[name]
335
+ load_tensor_parallel_weights(param, loaded_weight, name,
336
+ self._column_parallel_weights,
337
+ self._row_parallel_weights,
338
+ tensor_model_parallel_rank)
339
+ loaded += 1
340
+
341
+ if np.abs(loaded - need_to_load) < 0.01:
342
+ print(f'WARNING: only {loaded} params loaded out of {need_to_load}')
343
+ else:
344
+ print(f'Loaded all {loaded} params loaded out of {need_to_load}')
345
+
346
+
347
+ # Reassign LlamaForCausalLM.load_weights with llama_load_weights
348
+ LlamaForCausalLM.load_weights = llama_load_weights
349
+
350
+ # ! ==================================================================
351
+
352
+ set_documentation_group("component")
353
+
354
+ DATA_ROOT = os.environ.get("dataroot", "/mnt/workspace/workgroup/phi")
355
+ MODEL_CACHE_DIR = os.path.join(DATA_ROOT, "pret_models")
356
+
357
+
358
+ DTYPES = {
359
+ 'float16': torch.float16,
360
+ 'bfloat16': torch.bfloat16
361
+ }
362
+
363
+ llm = None
364
+ demo = None
365
+
366
+ RELOAD_SIGNAL = '<<<reload:'
367
+
368
+ BOS_TOKEN = '<s>'
369
+ EOS_TOKEN = '</s>'
370
+
371
+ B_INST, E_INST = "[INST]", "[/INST]"
372
+ B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
373
+
374
+ SYSTEM_PROMPT_1 = """You are a multilingual, helpful, respectful and honest assistant. Your name is SeaL and you are built by DAMO Academy, Alibaba Group. Always answer as helpfully as possible, while being safe. Your \
375
+ answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
376
+ that your responses are socially unbiased and positive in nature.
377
+
378
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
379
+ correct. If you don't know the answer to a question, please don't share false information.
380
+
381
+ As a multilingual assistant, you must respond and follow instructions in the native language of the user by default, unless told otherwise. \
382
+ Your response should adapt to the norms and customs of the respective language and culture.
383
+ """
384
+
385
+ RES_PRINTED = False
386
+
387
+ def llama_chat_sys_input_seq_constructor(text, sys_prompt=SYSTEM_PROMPT_1, bos_token=BOS_TOKEN, eos_token=EOS_TOKEN):
388
+ return f"{bos_token}{B_INST} {B_SYS} {sys_prompt} {E_SYS} {text} {E_INST}"
389
+
390
+
391
+ def llama_chat_multiturn_sys_input_seq_constructor(
392
+ message: str,
393
+ history: List[Tuple[str, str]],
394
+ sys_prompt=SYSTEM_PROMPT_1,
395
+ bos_token=BOS_TOKEN,
396
+ eos_token=EOS_TOKEN,
397
+ ):
398
+ """
399
+ ```
400
+ <bos>[INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer <eos>
401
+ <bos>[INST] Prompt [/INST] Answer <eos>
402
+ <bos>[INST] Prompt [/INST]
403
+ ```
404
+ """
405
+ text = ''
406
+ for i, (prompt, res) in enumerate(history):
407
+ if i == 0:
408
+ text += f"{bos_token}{B_INST} {B_SYS} {sys_prompt} {E_SYS} {prompt} {E_INST}"
409
+ else:
410
+ text += f"{bos_token}{B_INST} {prompt} {E_INST}"
411
+
412
+ if res is not None:
413
+ text += f" {res} {eos_token} "
414
+ if len(history) == 0 or text.strip() == '':
415
+ text = f"{bos_token}{B_INST} {B_SYS} {sys_prompt} {E_SYS} {message} {E_INST}"
416
+ else:
417
+ text += f"{bos_token}{B_INST} {message} {E_INST}"
418
+ return text
419
+
420
+
421
+ @document()
422
+ class ChatBot(gr.Chatbot):
423
+ def _postprocess_chat_messages(
424
+ self, chat_message
425
+ ):
426
+ x = super()._postprocess_chat_messages(chat_message)
427
+ if isinstance(x, str):
428
+ x = x.replace("\n", "<br>")
429
+ return x
430
+
431
+
432
+ def load_ckpt(ckpt_file: str) -> str:
433
+ global llm
434
+ status = "Failed"
435
+ if not os.path.exists(ckpt_file):
436
+ status = f"Failed - file not found: {ckpt_file}"
437
+ elif not ckpt_file.endswith(".bin"):
438
+ status = f"Failed - file not .bin: {ckpt_file}"
439
+ else:
440
+ try:
441
+ state_dict = torch.load(ckpt_file, map_location='cpu')
442
+ print(f'loaded state_dict: {ckpt_file}')
443
+ llm.llm_engine.workers[0].model.load_state_dict(state_dict)
444
+ status = f'Success. Loaded {ckpt_file}'
445
+ except Exception as e:
446
+ status = f'Failed - {str(e)}'
447
+ return status
448
+
449
+
450
+
451
+ def chat_response(message, history, temperature: float, max_tokens: int, system_prompt: str = '') -> str:
452
+ global llm
453
+ assert llm is not None
454
+ temperature = float(temperature)
455
+ max_tokens = int(max_tokens)
456
+ if system_prompt.strip() != '':
457
+ # chat version, add system prompt
458
+ message = llama_chat_sys_input_seq_constructor(
459
+ message.strip(),
460
+ sys_prompt=system_prompt
461
+ )
462
+
463
+ sampling_params = SamplingParams(temperature=temperature, max_tokens=max_tokens)
464
+ gen = llm.generate(message, sampling_params)
465
+ out = gen[0].outputs[0].text
466
+ # print(f'{message}<<<{out}>>>')
467
+ return f'{out}'
468
+
469
+
470
+ def vllm_abort(self: LLM):
471
+ scheduler = self.llm_engine.scheduler
472
+ for state_queue in [scheduler.waiting, scheduler.running, scheduler.swapped]:
473
+ for seq_group in state_queue:
474
+ # if seq_group.request_id == request_id:
475
+ # Remove the sequence group from the state queue.
476
+ state_queue.remove(seq_group)
477
+ for seq in seq_group.seqs:
478
+ if seq.is_finished():
479
+ continue
480
+ scheduler.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
481
+
482
+ def _vllm_run_engine(self: LLM, use_tqdm: bool = False) -> Dict[str, RequestOutput]:
483
+ # Initialize tqdm.
484
+ if use_tqdm:
485
+ num_requests = self.llm_engine.get_num_unfinished_requests()
486
+ pbar = tqdm(total=num_requests, desc="Processed prompts")
487
+ # Run the engine.
488
+ outputs: Dict[str, RequestOutput] = {}
489
+ while self.llm_engine.has_unfinished_requests():
490
+ step_outputs = self.llm_engine.step()
491
+ for output in step_outputs:
492
+ # if output.finished:
493
+ # outputs.append(output)
494
+ # if use_tqdm:
495
+ # pbar.update(1)
496
+ outputs[output.request_id] = output
497
+ # outputs = sorted(outputs, key=lambda x: int(x.request_id))
498
+ if len(outputs) > 0:
499
+ yield outputs
500
+ # if use_tqdm:
501
+ # pbar.close()
502
+ # Sort the outputs by request ID.
503
+ # This is necessary because some requests may be finished earlier than
504
+ # its previous requests.
505
+ # outputs = sorted(outputs, key=lambda x: int(x.request_id))
506
+ # return outputs
507
+
508
+
509
+ def vllm_generate_stream(
510
+ self: LLM,
511
+ prompts: Optional[Union[str, List[str]]] = None,
512
+ sampling_params: Optional[SamplingParams] = None,
513
+ prompt_token_ids: Optional[List[List[int]]] = None,
514
+ use_tqdm: bool = False,
515
+ ) -> Dict[str, RequestOutput]:
516
+ """Generates the completions for the input prompts.
517
+
518
+ NOTE: This class automatically batches the given prompts, considering
519
+ the memory constraint. For the best performance, put all of your prompts
520
+ into a single list and pass it to this method.
521
+
522
+ Args:
523
+ prompts: A list of prompts to generate completions for.
524
+ sampling_params: The sampling parameters for text generation. If
525
+ None, we use the default sampling parameters.
526
+ prompt_token_ids: A list of token IDs for the prompts. If None, we
527
+ use the tokenizer to convert the prompts to token IDs.
528
+ use_tqdm: Whether to use tqdm to display the progress bar.
529
+
530
+ Returns:
531
+ A list of `RequestOutput` objects containing the generated
532
+ completions in the same order as the input prompts.
533
+ """
534
+ if prompts is None and prompt_token_ids is None:
535
+ raise ValueError("Either prompts or prompt_token_ids must be "
536
+ "provided.")
537
+ if isinstance(prompts, str):
538
+ # Convert a single prompt to a list.
539
+ prompts = [prompts]
540
+ if prompts is not None and prompt_token_ids is not None:
541
+ if len(prompts) != len(prompt_token_ids):
542
+ raise ValueError("The lengths of prompts and prompt_token_ids "
543
+ "must be the same.")
544
+ if sampling_params is None:
545
+ # Use default sampling params.
546
+ sampling_params = SamplingParams()
547
+
548
+ # Add requests to the engine.
549
+ if prompts is not None:
550
+ num_requests = len(prompts)
551
+ else:
552
+ num_requests = len(prompt_token_ids)
553
+ for i in range(num_requests):
554
+ prompt = prompts[i] if prompts is not None else None
555
+ if prompt_token_ids is None:
556
+ token_ids = None
557
+ else:
558
+ token_ids = prompt_token_ids[i]
559
+ self._add_request(prompt, sampling_params, token_ids)
560
+ # return self._run_engine(use_tqdm)
561
+ yield from _vllm_run_engine(self, use_tqdm)
562
+
563
+
564
+ def chat_response_stream(
565
+ message: str,
566
+ history: List[Tuple[str, str]],
567
+ temperature: float,
568
+ max_tokens: int,
569
+ frequency_penalty: float,
570
+ system_prompt: str
571
+ ) -> str:
572
+ global llm, RES_PRINTED
573
+ assert llm is not None
574
+ # force removing all
575
+ vllm_abort(llm)
576
+
577
+ temperature = float(temperature)
578
+ frequency_penalty = float(frequency_penalty)
579
+ max_tokens = int(max_tokens)
580
+ if system_prompt.strip() != '':
581
+ # chat version, add system prompt
582
+ message = llama_chat_sys_input_seq_constructor(
583
+ message.strip(),
584
+ sys_prompt=system_prompt
585
+ )
586
+ sampling_params = SamplingParams(
587
+ temperature=temperature, max_tokens=max_tokens,
588
+ frequency_penalty=frequency_penalty,
589
+ )
590
+ cur_out = None
591
+ for gen in vllm_generate_stream(llm, message, sampling_params):
592
+ if cur_out is not None:
593
+ yield cur_out
594
+ assert len(gen) == 1, f'{gen}'
595
+ item = next(iter(gen.values()))
596
+ cur_out = item.outputs[0].text
597
+ if not RES_PRINTED:
598
+ print(f'{message}<<<{cur_out}>>>')
599
+ RES_PRINTED = True
600
+ if cur_out is not None:
601
+ yield cur_out
602
+
603
+
604
+ def chat_response_stream_multiturn(
605
+ message: str,
606
+ history: List[Tuple[str, str]],
607
+ temperature: float,
608
+ max_tokens: int,
609
+ frequency_penalty: float,
610
+ system_prompt: str
611
+ ) -> str:
612
+ """Build multi turn
613
+ <bos>[INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer <eos>
614
+ <bos>[INST] Prompt [/INST] Answer <eos>
615
+ <bos>[INST] Prompt [/INST]
616
+
617
+ message is incoming prompt
618
+ history don't have the current messauge
619
+ """
620
+ global llm, RES_PRINTED
621
+ assert llm is not None
622
+ assert system_prompt.strip() != '', f'system prompt is empty'
623
+ # force removing all
624
+ vllm_abort(llm)
625
+
626
+ temperature = float(temperature)
627
+ frequency_penalty = float(frequency_penalty)
628
+ max_tokens = int(max_tokens)
629
+
630
+ # history.append([message, None])
631
+ # history will be appended with message later on
632
+ full_prompt = llama_chat_multiturn_sys_input_seq_constructor(
633
+ message, history, sys_prompt=system_prompt
634
+ )
635
+ sampling_params = SamplingParams(
636
+ temperature=temperature, max_tokens=max_tokens,
637
+ frequency_penalty=frequency_penalty,
638
+ )
639
+ cur_out = None
640
+ for gen in vllm_generate_stream(llm, full_prompt, sampling_params):
641
+ if cur_out is not None:
642
+ yield cur_out
643
+ assert len(gen) == 1, f'{gen}'
644
+ item = next(iter(gen.values()))
645
+ cur_out = item.outputs[0].text
646
+ if not RES_PRINTED:
647
+ print(f'{full_prompt}<<<{cur_out}>>>')
648
+ RES_PRINTED = True
649
+ if cur_out is not None:
650
+ yield cur_out
651
+
652
+
653
+ def debug_chat_response_echo(
654
+ message: str,
655
+ history: List[Tuple[str, str]],
656
+ temperature: float = 0.0,
657
+ max_tokens: int = 4096,
658
+ frequency_penalty: float = 0.4,
659
+ system_prompt: str = SYSTEM_PROMPT_1,
660
+ ) -> str:
661
+ yield message
662
+
663
+
664
+ MODEL_TITLE = "DAMO-SeaL-13B - An Assistant for South East Asian Languages"
665
+ MODEL_DESC = """
666
+ This is a 13B DAMO-SeaL-Chat assistant model built by DAMO Academy, Alibaba Group. It can produce helpful responses in English, Vietnamese, Indonesian and Thai.
667
+ """.strip()
668
+
669
+ TENSOR_PARALLEL = int(os.environ.get("TENSOR_PARALLEL", "1"))
670
+ DTYPE = 'bfloat16'
671
+ DTYPE = 'float16'
672
+
673
+ MODEL_PATH = os.environ.get("MODEL_PATH", "notfound, please set `export MODEL_PATH=`")
674
+
675
+ DEBUG = 1
676
+
677
+ def launch():
678
+ global demo, llm, DEBUG
679
+ if DEBUG:
680
+ model_desc + "<br>This is in debug mode, responses will be copy original"
681
+ response_fn = debug_chat_response_echo
682
+ else:
683
+ model_desc = MODEL_DESC
684
+ model_path = MODEL_PATH
685
+ assert os.path.exists(model_path), f'{model_path} not found'
686
+ model_title = MODEL_TITLE
687
+ tensor_parallel = TENSOR_PARALLEL
688
+ assert tensor_parallel > 0 , f'{tensor_parallel} invalid'
689
+ dtype = DTYPE
690
+
691
+ # ! load the model
692
+ llm = LLM(model=model_path, dtype=dtype, tensor_parallel_size=tensor_parallel)
693
+
694
+ sys_prompt = SYSTEM_PROMPT_1
695
+ max_tokens = 4096
696
+ print(f'Use system prompt:\n{sys_prompt}')
697
+
698
+ # response_fn = chat_response_stream_multiturn if args.multiturn else chat_response_stream
699
+ response_fn = chat_response_stream_multiturn
700
+ print(F'respond: {response_fn}')
701
+
702
+ demo = gr.ChatInterface(
703
+ response_fn,
704
+ chatbot=ChatBot(
705
+ bubble_full_width=False,
706
+ latex_delimiters=[
707
+ { "left": "$", "right": "$", "display": False},
708
+ { "left": "$$", "right": "$$", "display": True},
709
+ ]
710
+ ),
711
+ textbox=gr.Textbox(placeholder='Type message', lines=8, max_lines=128, min_width=200),
712
+ submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
713
+ # stop_btn=None,
714
+ title=f"{model_title}",
715
+ description=f"{model_desc}",
716
+ # ! decide if can change the system prompt.
717
+ additional_inputs=[
718
+ gr.Number(value=0, label='Temperature (higher -> more random)'),
719
+ gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
720
+ gr.Number(value=0.4, label='Frequency penalty (> 0 encourage new tokens)'),
721
+ gr.Textbox(value=sys_prompt, label='System prompt', lines=8)],
722
+ )
723
+ demo.queue()
724
+ # demo.launch(server_port=args.port)
725
+ demo.launch()
726
+
727
+
728
+ def main():
729
+
730
+ # launch(parser.parse_args())
731
+ launch()
732
+
733
+
734
+ if __name__ == "__main__":
735
+ main()
736
+
737
+
738
+ """
739
+
740
+ export CUDA_VISIBLE_DEVICES=0
741
+ export MODEL_PATH=${dataroot}/hf_train/pretrain_lm/swpn/merlion13s108Hi8kPretFlCW8k.LMFromHf.a.gc.t5k0.vizhthid.mean_std.TrainTask.NLNL.Multi.Vi.FSePlCq13M.FSePlCq13M.m4k.b8.lr1e5.linear.wa0k.ms858k.grac1.se1.8g.v4c.zfsdp/step_4000
742
+ export MODEL_PATH=${dataroot}/llama-2-7b-lxxp-faster
743
+ export MODEL_PATH=${dataroot}/llama-2-7b-chat-xp
744
+ python app.py
745
+
746
+
747
+ """
requirements.txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ deepspeed
2
+ sentencepiece
3
+ accelerate
4
+ evaluate
5
+ datasets
6
+ sacrebleu
7
+ websockets
8
+ fire
9
+ indic-nlp-library
10
+ omegaconf
11
+ scikit-learn
12
+ jiwer
13
+ tenacity
14
+ pynvml
15
+ rouge_score
16
+ ninja
17
+ ray
18
+ psutil
19
+ xformers >= 0.0.19
20
+ fastapi
21
+ tensorboard
22
+ geomloss
23
+ einops
24
+ gdown
25
+ vllm==0.1.4
26
+ transformers