Spaces:
Running
on
Zero
Running
on
Zero
# A mirror to gradio launch stream | |
# By Xuan Phi Nguyen at DAMO Academy, Alibaba Group | |
""" | |
Load FasterLlama with original VLLM codebase, | |
require changing config names to LlamaForCausalLM | |
tensor_parallel must == 1 | |
""" | |
import os | |
import numpy as np | |
import argparse | |
import torch | |
import gradio as gr | |
from typing import Any, Iterator | |
from typing import Iterator, List, Optional, Tuple | |
import filelock | |
import glob | |
import json | |
from gradio_client.documentation import document, set_documentation_group | |
from typing import List, Optional, Union, Dict, Tuple | |
from tqdm.auto import tqdm | |
from huggingface_hub import snapshot_download | |
# DEBUG mode: vLLM is not imported, no model is loaded and launch() wires the
# UI to an echo function instead of the model. Defaults to on; set the
# environment variable DEBUG to 0/false/no/off to serve the real model.
DEBUG = os.environ.get("DEBUG", "1").strip().lower() not in ("0", "false", "no", "off")
if not DEBUG:
    # vllm import
    from vllm import LLM, SamplingParams
    from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
    from vllm.engine.arg_utils import EngineArgs
    from vllm.engine.llm_engine import LLMEngine
    from vllm.outputs import RequestOutput
    from vllm.sampling_params import SamplingParams
    from vllm.utils import Counter
    from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
                               SequenceGroupMetadata, SequenceOutputs,
                               SequenceStatus)
    # ! reconfigure vllm to faster llama: register the architecture name
    # "FasterLlamaForCausalLM" so it resolves to vLLM's Llama implementation.
    from vllm.model_executor.model_loader import _MODEL_REGISTRY
    from vllm.model_executor.models import LlamaForCausalLM
    _MODEL_REGISTRY['FasterLlamaForCausalLM'] = LlamaForCausalLM
def hf_model_weights_iterator(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    use_np_cache: bool = False,
) -> Iterator[Tuple[str, torch.Tensor]]:
    """Yield ``(name, tensor)`` pairs for every weight of a checkpoint.

    Replacement for vLLM's iterator of the same name that falls back to
    ``*model*.safetensors`` shards when no ``*model*.bin`` shards exist.

    Args:
        model_name_or_path: a local directory, or a Hugging Face repo id that
            is resolved from the local HF cache (``local_files_only=True``,
            so nothing is downloaded).
        cache_dir: HF cache directory; also the directory of the lock file
            (``/tmp`` when None).
        use_np_cache: on first use, convert the ``.bin`` weights into one
            numpy file per tensor under ``<folder>/np`` and afterwards load
            from that cache.

    Raises:
        ValueError: if no weight files are found, or if the numpy cache has
            to be built but there are no ``.bin`` files to build it from.
    """
    # File lock that prevents several processes from resolving / converting
    # the same model's weights at the same time. Only created when needed so
    # that plain local loading has no filelock dependency.
    lock_dir = cache_dir if cache_dir is not None else "/tmp"
    lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
    lock_path = os.path.join(lock_dir, lock_file_name)

    if os.path.isdir(model_name_or_path):
        hf_folder = model_name_or_path
    else:
        # vLLM is only needed here, for its silent tqdm class.
        from vllm.model_executor.weight_utils import Disabledtqdm
        with filelock.FileLock(lock_path):
            hf_folder = snapshot_download(
                model_name_or_path,
                # safetensors are a supported fallback below, so allow them.
                allow_patterns=["*.bin", "*.safetensors"],
                cache_dir=cache_dir,
                local_files_only=True,
                tqdm_class=Disabledtqdm)

    hf_bin_files = [
        x for x in glob.glob(os.path.join(hf_folder, "*model*.bin"))
        if not x.endswith("training_args.bin")
    ]
    hf_safetensors_files = glob.glob(os.path.join(hf_folder, "*model*.safetensors"))

    if use_np_cache:
        # Convert the model weights from torch tensors to numpy arrays for
        # faster loading.
        np_folder = os.path.join(hf_folder, "np")
        os.makedirs(np_folder, exist_ok=True)
        weight_names_file = os.path.join(np_folder, "weight_names.json")
        with filelock.FileLock(lock_path):
            if not os.path.exists(weight_names_file):
                # Without this check an empty weight list would be cached and
                # every later load would silently yield nothing.
                if not hf_bin_files:
                    raise ValueError(
                        f'use_np_cache requires *model*.bin files, none found in {hf_folder}')
                weight_names = []
                for bin_file in hf_bin_files:
                    state = torch.load(bin_file, map_location="cpu")
                    for name, param in state.items():
                        param_path = os.path.join(np_folder, name)
                        with open(param_path, "wb") as f:
                            np.save(f, param.cpu().detach().numpy())
                        weight_names.append(name)
                with open(weight_names_file, "w") as f:
                    json.dump(weight_names, f)
        with open(weight_names_file, "r") as f:
            weight_names = json.load(f)
        for name in weight_names:
            param_path = os.path.join(np_folder, name)
            with open(param_path, "rb") as f:
                param = np.load(f)
            yield name, torch.from_numpy(param)
    elif len(hf_bin_files) > 0:
        print(F'Load bin files: {hf_bin_files}')
        for bin_file in hf_bin_files:
            state = torch.load(bin_file, map_location="cpu")
            for name, param in state.items():
                yield name, param
            # Release the shard before loading the next one.
            del state
            torch.cuda.empty_cache()
    elif len(hf_safetensors_files) > 0:
        print(F'Load safetensor files: {hf_safetensors_files}')
        from safetensors.torch import load_file
        for safe_file in hf_safetensors_files:
            state = load_file(safe_file)
            for name, param in state.items():
                yield name, param
            del state
            torch.cuda.empty_cache()
    else:
        raise ValueError(f'no files available either bin or safe in {hf_folder}')
def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
    """Materialize a safetensors ``PySafeSlice`` into a real ``torch.Tensor``.

    A ``PySafeSlice`` can be indexed lazily (reading only the selected part
    from disk), but it lacks tensor methods such as ``.view()`` or ``.t()``.
    Anything that is not already a ``torch.Tensor`` is therefore turned into
    one by taking the full slice ``x[:]``; real tensors are returned as-is.
    """
    return x if isinstance(x, torch.Tensor) else x[:]
def load_padded_tensor_parallel_vocab(
    param: torch.Tensor,
    loaded_weight: Any,  # `torch.Tensor` or `PySafeSlice`
    tensor_model_parallel_rank: int,
) -> None:
    """Copy this rank's vocabulary shard of ``loaded_weight`` into ``param``.

    ``param`` holds ``param.shape[0]`` rows for this rank; the matching rows
    of ``loaded_weight`` are ``[rank * rows, (rank + 1) * rows)``. When the
    checkpoint has fewer rows than that window (padded vocab), only the
    available rows are copied and the remaining rows of ``param`` are left
    untouched.
    """
    rows_per_rank = param.shape[0]
    first_row = tensor_model_parallel_rank * rows_per_rank
    rank_rows = convert_pyslice_to_tensor(
        loaded_weight[first_row:first_row + rows_per_rank]
    )
    param[:rank_rows.shape[0]].copy_(rank_rows)
def llama_load_weights(
    self,
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    use_np_cache: bool = False,
    load_format: str = "auto",
    revision: Optional[str] = None
):
    """Load Hugging Face Llama weights into vLLM's fused ``LlamaForCausalLM``.

    Intended to replace ``LlamaForCausalLM.load_weights`` (it is assigned to
    it further down in this file when not in DEBUG mode). Checkpoints may
    store attention/MLP projections either separately (``q_proj``/``k_proj``/
    ``v_proj``, ``gate_proj``/``up_proj``) or already fused (``qkv_proj``,
    ``gate_up_proj``); both layouts are sliced for this tensor-parallel rank.

    Args:
        model_name_or_path: checkpoint location, passed to
            ``hf_model_weights_iterator``.
        cache_dir: forwarded to ``hf_model_weights_iterator``.
        use_np_cache: forwarded to ``hf_model_weights_iterator``.
        load_format: accepted for signature compatibility with vLLM's loader;
            not used here.
        revision: accepted for signature compatibility; not used here.

    Prints a warning when the number of loaded parameters does not match the
    number of parameters in the model's state dict.
    """
    from vllm.model_executor.weight_utils import (
        load_tensor_parallel_weights
    )
    from vllm.model_executor.parallel_utils.parallel_state import (
        get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)

    tp_size = get_tensor_model_parallel_world_size()
    tensor_model_parallel_rank = get_tensor_model_parallel_rank()

    # Grouped-query attention: k/v may have fewer heads than q.
    num_kv_heads = getattr(self.config, "num_key_value_heads",
                           self.config.num_attention_heads)
    head_dim = self.config.hidden_size // self.config.num_attention_heads
    q_proj_shard_size = self.config.hidden_size // tp_size
    kv_proj_shard_size = head_dim * num_kv_heads // tp_size
    attention_weight_specs = [
        # (weight_name, shard_size, offset)
        ("q_proj", q_proj_shard_size, 0),
        ("k_proj", kv_proj_shard_size, q_proj_shard_size),
        ("v_proj", kv_proj_shard_size,
         q_proj_shard_size + kv_proj_shard_size),
    ]
    state_dict = self.state_dict()
    need_to_load = len(state_dict)
    loaded = 0

    for name, loaded_weight in hf_model_weights_iterator(
            model_name_or_path, cache_dir, use_np_cache):
        if "rotary_emb.inv_freq" in name:
            continue

        # Embeddings / LM head: copy this rank's (possibly padded) vocab rows.
        if "embed_tokens" in name or "lm_head" in name:
            param = state_dict[name]
            load_padded_tensor_parallel_vocab(param, loaded_weight,
                                              tensor_model_parallel_rank)
            loaded += 1
            continue

        # Separate q/k/v projections: write this rank's shard into its slot
        # of the fused qkv_proj parameter (each counts as a third of it).
        is_attention_weight = False
        for weight_name, shard_size, offset in attention_weight_specs:
            if weight_name not in name or "qkv_proj" in name:
                continue
            param = state_dict[name.replace(weight_name, "qkv_proj")]
            loaded_weight = loaded_weight[
                shard_size * tensor_model_parallel_rank:shard_size *
                (tensor_model_parallel_rank + 1)]
            param_slice = param.data[offset:offset + shard_size]
            assert param_slice.shape == loaded_weight.shape
            param_slice.copy_(loaded_weight)
            loaded += 1.0 / 3
            is_attention_weight = True
            break
        if is_attention_weight:
            continue

        # ! qkv_proj is sharded differently if concatenated into qkv
        # qkv: qqqq kkkk vvvv
        # lweight: qq0qq1 kk0kk1 vv0vv1
        # q_shard_size: hidden_size // tp_size = qq
        # qkv_s0: qq0_kk0_vv0
        # qkv_s1: qq1_kk1_vv1
        if "qkv_proj" in name:
            param = state_dict[name]
            qsize = self.config.hidden_size
            kvsize = head_dim * num_kv_heads
            q_offsets = (
                q_proj_shard_size * tensor_model_parallel_rank,
                q_proj_shard_size * (tensor_model_parallel_rank + 1)
            )
            k_offsets = (
                qsize + kv_proj_shard_size * tensor_model_parallel_rank,
                qsize + kv_proj_shard_size * (tensor_model_parallel_rank + 1)
            )
            v_offsets = (
                qsize + kvsize + kv_proj_shard_size * tensor_model_parallel_rank,
                qsize + kvsize + kv_proj_shard_size * (tensor_model_parallel_rank + 1)
            )
            _loaded_weight = torch.cat(
                [
                    loaded_weight[q_offsets[0]:q_offsets[1]],
                    loaded_weight[k_offsets[0]:k_offsets[1]],
                    loaded_weight[v_offsets[0]:v_offsets[1]],
                ], 0
            )
            assert param.shape == _loaded_weight.shape, f'{param.shape=} != {_loaded_weight.shape=}'
            param.data.copy_(_loaded_weight)
            loaded += 1.0
            continue

        # Separate gate/up projections: stack into the fused gate_up_proj
        # (gate rows first, then up rows; each counts as half of it).
        is_gate_up_weight = False
        for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
            if weight_name not in name or "gate_up_proj" in name:
                continue
            param = state_dict[name.replace(weight_name, "gate_up_proj")]
            shard_size = param.shape[0] // 2
            loaded_weight = loaded_weight[
                shard_size * tensor_model_parallel_rank:shard_size *
                (tensor_model_parallel_rank + 1)]
            param_slice = param.data[shard_size * stride_id:shard_size *
                                     (stride_id + 1)]
            assert param_slice.shape == loaded_weight.shape
            param_slice.copy_(loaded_weight)
            loaded += 1.0 / 2
            is_gate_up_weight = True
            break
        if is_gate_up_weight:
            continue

        # Already-fused gate_up_proj: the checkpoint holds
        # [gate (intermediate_size rows); up (intermediate_size rows)], so
        # take this rank's slice from each half and re-concatenate.
        if "gate_up_proj" in name:
            param = state_dict[name]
            shard_size = param.shape[0] // 2
            intermediate_size = self.config.intermediate_size
            g_offsets = (
                shard_size * tensor_model_parallel_rank,
                shard_size * (tensor_model_parallel_rank + 1)
            )
            u_offsets = (
                intermediate_size + shard_size * tensor_model_parallel_rank,
                intermediate_size + shard_size * (tensor_model_parallel_rank + 1)
            )
            _loaded_weight = torch.cat(
                [
                    loaded_weight[g_offsets[0]:g_offsets[1]],
                    loaded_weight[u_offsets[0]:u_offsets[1]],
                ], 0
            )
            assert param.shape == _loaded_weight.shape
            param.data.copy_(_loaded_weight)
            loaded += 1.0
            continue

        # Everything else: let vLLM shard it as a column/row-parallel weight.
        param = state_dict[name]
        load_tensor_parallel_weights(param, loaded_weight, name,
                                     self._column_parallel_weights,
                                     self._row_parallel_weights,
                                     tensor_model_parallel_rank)
        loaded += 1

    # Fractional increments (1/3 per q/k/v, 1/2 per gate/up) accumulate float
    # error, so compare with a tolerance. Warn only when the counts DIFFER.
    if np.abs(loaded - need_to_load) >= 0.01:
        print(f'WARNING: only {loaded} params loaded out of {need_to_load}')
    else:
        print(f'Loaded all {loaded} params loaded out of {need_to_load}')
# Reassign LlamaForCausalLM.load_weights with llama_load_weights so vLLM uses
# the custom loader above. Skipped in DEBUG mode, where the vLLM imports at the
# top of the file are not executed and LlamaForCausalLM is undefined.
if not DEBUG:
    LlamaForCausalLM.load_weights = llama_load_weights
# ! ==================================================================
# gradio_client documentation helper; presumably tags subsequently documented
# objects with the "component" group for Gradio's docs tooling — verify.
set_documentation_group("component")
DATA_ROOT = os.environ.get("dataroot", "/mnt/workspace/workgroup/phi")
MODEL_CACHE_DIR = os.path.join(DATA_ROOT, "pret_models")
# Supported dtype names -> torch dtypes (not referenced in the visible code;
# DTYPE is passed to vLLM as a string).
DTYPES = {
    'float16': torch.float16,
    'bfloat16': torch.bfloat16
}
# Populated by launch(): the vLLM engine (non-DEBUG only) and the Gradio app.
llm = None
demo = None
RELOAD_SIGNAL = '<<<reload:'
# Llama-2 chat prompt markers.
BOS_TOKEN = '<s>'
EOS_TOKEN = '</s>'
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
# Default system prompt. Trailing backslashes join physical lines, so each
# continued line needs its separating space before the backslash.
SYSTEM_PROMPT_1 = """You are a multilingual, helpful, respectful and honest assistant. Your name is SeaL and you are built by DAMO Academy, Alibaba Group. Always answer as helpfully as possible, while being safe. Your \
answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure \
that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
correct. If you don't know the answer to a question, please don't share false information.
As a multilingual assistant, you must respond and follow instructions in the native language of the user by default, unless told otherwise. \
Your response should adapt to the norms and customs of the respective language and culture.
"""
# Set once the streaming handlers have printed their first prompt/response.
RES_PRINTED = False
def llama_chat_sys_input_seq_constructor(text, sys_prompt=SYSTEM_PROMPT_1, bos_token=BOS_TOKEN, eos_token=EOS_TOKEN):
    """Wrap a single user turn and a system prompt in the Llama-2 chat format.

    Produces ``<bos>[INST] <<SYS>> system <</SYS>> text [/INST]``.
    ``eos_token`` is accepted for signature symmetry but not used.
    """
    system_block = f"{B_SYS} {sys_prompt} {E_SYS}"
    return f"{bos_token}{B_INST} {system_block} {text} {E_INST}"
def llama_chat_multiturn_sys_input_seq_constructor(
    message: str,
    history: List[Tuple[str, str]],
    sys_prompt=SYSTEM_PROMPT_1,
    bos_token=BOS_TOKEN,
    eos_token=EOS_TOKEN,
):
    """Build a multi-turn Llama-2 chat prompt ending with ``message``.

    ```
    <bos>[INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer <eos>
    <bos>[INST] Prompt [/INST] Answer <eos>
    <bos>[INST] Prompt [/INST]
    ```

    The system prompt is attached to the first turn only. A history entry
    whose answer is None contributes its prompt without an answer.
    """
    if len(history) == 0:
        return f"{bos_token}{B_INST} {B_SYS} {sys_prompt} {E_SYS} {message} {E_INST}"

    pieces = []
    for turn_idx, (prompt, res) in enumerate(history):
        system_block = f"{B_SYS} {sys_prompt} {E_SYS} " if turn_idx == 0 else ""
        pieces.append(f"{bos_token}{B_INST} {system_block}{prompt} {E_INST}")
        if res is not None:
            pieces.append(f" {res} {eos_token} ")
    pieces.append(f"{bos_token}{B_INST} {message} {E_INST}")
    return "".join(pieces)
class ChatBot(gr.Chatbot):
    """``gr.Chatbot`` that renders newlines in text messages as ``<br>``."""

    def _postprocess_chat_messages(
        self, chat_message
    ):
        processed = super()._postprocess_chat_messages(chat_message)
        if not isinstance(processed, str):
            return processed
        return processed.replace("\n", "<br>")
def load_ckpt(ckpt_file: str) -> str:
    """Load a ``.bin`` state dict into the running vLLM model.

    Args:
        ckpt_file: path to a ``torch.save``-d state dict ending in ``.bin``.

    Returns:
        A status string: ``'Success. Loaded <file>'`` on success, otherwise a
        message starting with ``'Failed'``. Errors raised while loading are
        reported in the status instead of propagating.
    """
    if not os.path.exists(ckpt_file):
        return f"Failed - file not found: {ckpt_file}"
    if not ckpt_file.endswith(".bin"):
        return f"Failed - file not .bin: {ckpt_file}"
    if llm is None:
        # Previously this surfaced as an opaque AttributeError message.
        return "Failed - model not loaded"
    try:
        # NOTE(security): torch.load unpickles the file and can execute
        # arbitrary code; only load trusted checkpoints.
        state_dict = torch.load(ckpt_file, map_location='cpu')
        print(f'loaded state_dict: {ckpt_file}')
        llm.llm_engine.workers[0].model.load_state_dict(state_dict)
    except Exception as e:
        return f'Failed - {e}'
    return f'Success. Loaded {ckpt_file}'
def chat_response(message, history, temperature: float, max_tokens: int, system_prompt: str = '') -> str:
    """Return one complete (non-streamed) model reply for ``message``.

    ``history`` is ignored. When ``system_prompt`` is non-blank the stripped
    message is wrapped in the Llama-2 chat format with that system prompt;
    otherwise the raw message is sent as the prompt.
    """
    assert llm is not None
    temp_value = float(temperature)
    token_budget = int(max_tokens)
    prompt = message
    if system_prompt.strip() != '':
        # chat version, add system prompt
        prompt = llama_chat_sys_input_seq_constructor(message.strip(), sys_prompt=system_prompt)
    params = SamplingParams(temperature=temp_value, max_tokens=token_budget)
    first_result = llm.generate(prompt, params)[0]
    reply = first_result.outputs[0].text
    return f'{reply}'
def vllm_abort(self: Any):
    """Abort every request currently held by the vLLM scheduler of ``self``.

    Empties the ``waiting``, ``running`` and ``swapped`` queues of
    ``self.llm_engine.scheduler`` and frees each not-yet-finished sequence
    with status ``FINISHED_ABORTED``. Called before a new chat request so
    that leftover generations do not keep running.
    """
    scheduler = self.llm_engine.scheduler
    for state_queue in [scheduler.waiting, scheduler.running, scheduler.swapped]:
        # Iterate over a snapshot: removing from the list being iterated
        # would skip every other sequence group.
        for seq_group in list(state_queue):
            state_queue.remove(seq_group)
            for seq in seq_group.seqs:
                if seq.is_finished():
                    continue
                scheduler.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
def _vllm_run_engine(self: Any, use_tqdm: bool = False) -> Iterator[Dict[str, Any]]:
    """Step the engine of ``self`` (a vLLM ``LLM``) until no requests remain.

    Streaming replacement for ``LLM._run_engine``. After every engine step,
    once at least one output has been seen, yields a dict mapping request id
    to that request's latest output. The SAME dict object is updated in place
    and re-yielded each time, so copy it if a snapshot is needed. Outputs are
    not sorted by request id.

    Args:
        self: object exposing ``llm_engine`` (a vLLM ``LLM``).
        use_tqdm: show a progress bar counting finished requests.
    """
    pbar = None
    if use_tqdm:
        num_requests = self.llm_engine.get_num_unfinished_requests()
        pbar = tqdm(total=num_requests, desc="Processed prompts")
    outputs: Dict[str, Any] = {}
    finished_ids = set()
    try:
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                outputs[output.request_id] = output
                # Count each request once, when it first reports finished.
                if (pbar is not None and output.finished
                        and output.request_id not in finished_ids):
                    finished_ids.add(output.request_id)
                    pbar.update(1)
            if len(outputs) > 0:
                yield outputs
    finally:
        # Runs on normal exhaustion and when the consumer stops early.
        if pbar is not None:
            pbar.close()
def vllm_generate_stream(
    self: Any,
    prompts: Optional[Union[str, List[str]]] = None,
    sampling_params: Optional[Any] = None,
    prompt_token_ids: Optional[List[List[int]]] = None,
    use_tqdm: bool = False,
) -> Iterator[Dict[str, Any]]:
    """Streaming counterpart of vLLM's ``LLM.generate``.

    Queues one request per prompt on ``self`` (a vLLM ``LLM``) and then
    delegates to ``_vllm_run_engine``, which yields, after each engine step,
    a dict mapping request id to that request's latest ``RequestOutput``
    (the same dict object, updated in place). Results are keyed by request
    id and are NOT ordered like the input prompts.

    Args:
        prompts: a prompt or a list of prompts.
        sampling_params: sampling parameters; ``SamplingParams()`` when None.
        prompt_token_ids: optional pre-tokenized prompts. If None, the
            tokenizer converts ``prompts`` to token IDs.
        use_tqdm: forwarded to ``_vllm_run_engine``.

    Raises:
        ValueError: if neither ``prompts`` nor ``prompt_token_ids`` is given,
            or both are given with different lengths. As this is a generator,
            the check runs on the first ``next()``, not at call time.
    """
    if prompts is None and prompt_token_ids is None:
        raise ValueError("Either prompts or prompt_token_ids must be "
                         "provided.")
    if isinstance(prompts, str):
        # Convert a single prompt to a list.
        prompts = [prompts]
    if prompts is not None and prompt_token_ids is not None:
        if len(prompts) != len(prompt_token_ids):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")
    if sampling_params is None:
        # Use default sampling params.
        sampling_params = SamplingParams()

    # Add requests to the engine.
    num_requests = len(prompts) if prompts is not None else len(prompt_token_ids)
    for i in range(num_requests):
        prompt = prompts[i] if prompts is not None else None
        token_ids = None if prompt_token_ids is None else prompt_token_ids[i]
        self._add_request(prompt, sampling_params, token_ids)
    yield from _vllm_run_engine(self, use_tqdm)
def chat_response_stream(
    message: str,
    history: List[Tuple[str, str]],
    temperature: float,
    max_tokens: int,
    frequency_penalty: float,
    system_prompt: str
) -> str:
    """Stream a single-turn reply to ``message`` (``history`` is ignored).

    Any request still held by the engine is aborted first. When
    ``system_prompt`` is non-blank the stripped message is wrapped in the
    Llama-2 chat format with it. Yields the cumulative generated text; the
    first completed prompt/response pair of the process is printed once.
    """
    global RES_PRINTED
    assert llm is not None
    # force removing all
    vllm_abort(llm)
    temperature, frequency_penalty, max_tokens = (
        float(temperature), float(frequency_penalty), int(max_tokens))
    prompt = message
    if system_prompt.strip() != '':
        # chat version, add system prompt
        prompt = llama_chat_sys_input_seq_constructor(
            message.strip(),
            sys_prompt=system_prompt,
        )
    params = SamplingParams(
        temperature=temperature,
        max_tokens=max_tokens,
        frequency_penalty=frequency_penalty,
    )
    latest_text = None
    for step_outputs in vllm_generate_stream(llm, prompt, params):
        # Emit the previous snapshot only once a newer one has arrived.
        if latest_text is not None:
            yield latest_text
        assert len(step_outputs) == 1, f'{step_outputs}'
        request_output = next(iter(step_outputs.values()))
        latest_text = request_output.outputs[0].text
    if not RES_PRINTED:
        print(f'{prompt}<<<{latest_text}>>>')
        RES_PRINTED = True
    if latest_text is not None:
        yield latest_text
def chat_response_stream_multiturn(
    message: str,
    history: List[Tuple[str, str]],
    temperature: float,
    max_tokens: int,
    frequency_penalty: float,
    system_prompt: str
) -> str:
    """Stream the model's reply to ``message`` given the prior chat turns.

    The prompt is laid out as::

        <bos>[INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer <eos>
        <bos>[INST] Prompt [/INST] Answer <eos>
        <bos>[INST] Prompt [/INST]

    Args:
        message: the incoming user prompt.
        history: earlier (prompt, answer) turns; it does not yet contain
            ``message``.
        temperature, max_tokens, frequency_penalty: sampling settings,
            coerced to float / int / float.
        system_prompt: must be non-blank (asserted).

    Yields:
        The cumulative generated text. The first completed prompt/response
        pair of the process is printed once.
    """
    global RES_PRINTED
    assert llm is not None
    assert system_prompt.strip() != '', f'system prompt is empty'
    # Cancel whatever the engine is still working on from an earlier request.
    vllm_abort(llm)
    temperature, frequency_penalty, max_tokens = (
        float(temperature), float(frequency_penalty), int(max_tokens))
    full_prompt = llama_chat_multiturn_sys_input_seq_constructor(
        message, history, sys_prompt=system_prompt
    )
    params = SamplingParams(
        temperature=temperature,
        max_tokens=max_tokens,
        frequency_penalty=frequency_penalty,
    )
    latest_text = None
    for step_outputs in vllm_generate_stream(llm, full_prompt, params):
        # Emit the previous snapshot only once a newer one has arrived.
        if latest_text is not None:
            yield latest_text
        assert len(step_outputs) == 1, f'{step_outputs}'
        latest_text = next(iter(step_outputs.values())).outputs[0].text
    if not RES_PRINTED:
        print(f'{full_prompt}<<<{latest_text}>>>')
        RES_PRINTED = True
    if latest_text is not None:
        yield latest_text
def debug_chat_response_echo(
    message: str,
    history: List[Tuple[str, str]],
    temperature: float = 0.0,
    max_tokens: int = 4096,
    frequency_penalty: float = 0.4,
    system_prompt: str = SYSTEM_PROMPT_1,
) -> str:
    """DEBUG-mode stand-in for the model: stream back the user's message.

    Takes the same arguments as ``chat_response_stream_multiturn`` so it can
    be plugged into ``gr.ChatInterface`` unchanged; every argument except
    ``message`` is ignored.
    """
    echoed = "repeat: {}".format(message)
    yield echoed
# ============ CONSTANT ============
MODEL_TITLE = "DAMO-SeaL-13B - An Assistant for South East Asian Languages"
MODEL_DESC = """
This is a 13B DAMO-SeaL-Chat assistant model built by DAMO Academy, Alibaba Group. It can produce helpful responses in English, Vietnamese, Indonesian and Thai.
""".strip()
# Citation markdown; the BibTeX entry sits in a fenced code block, which must
# be closed or everything after it renders as code.
cite_markdown = """
## Citation
If you find our project useful, hope you can star our repo and cite our paper as follows:
```
@article{damonlpsg2023seallm,
author = {???},
title = {SeaL: A language model for South East Asian Languages},
year = 2023,
}
```
"""
# Not yet included in the citation:
# journal = {arXiv preprint arXiv:2306.02858}
# url = {https://arxiv.org/abs/2306.02858}

# Runtime configuration (TENSOR_PARALLEL and MODEL_PATH come from the environment).
TENSOR_PARALLEL = int(os.environ.get("TENSOR_PARALLEL", "1"))
# DTYPE = 'bfloat16'  # alternative precision
DTYPE = 'float16'
MODEL_PATH = os.environ.get("MODEL_PATH", "notfound, please set `export MODEL_PATH=`")
def launch():
    """Build the Gradio chat UI and start serving it.

    In DEBUG mode no model is loaded and the UI echoes the user's message via
    ``debug_chat_response_echo``. Otherwise a vLLM ``LLM`` is created from
    ``MODEL_PATH`` and replies are streamed by
    ``chat_response_stream_multiturn``. Sets the module globals ``llm``
    (non-DEBUG only) and ``demo``.

    Raises:
        ValueError: if ``TENSOR_PARALLEL`` is not positive.
        FileNotFoundError: if not in DEBUG mode and ``MODEL_PATH`` does not exist.
    """
    global demo, llm
    model_desc = MODEL_DESC
    model_path = MODEL_PATH
    model_title = MODEL_TITLE
    tensor_parallel = TENSOR_PARALLEL
    # Explicit checks instead of asserts, which are stripped under `python -O`.
    if tensor_parallel <= 0:
        raise ValueError(f'{tensor_parallel} invalid')
    dtype = DTYPE
    sys_prompt = SYSTEM_PROMPT_1
    max_tokens = 4096

    if DEBUG:
        model_desc += "<br>!!!!! This is in debug mode, responses will be copy original"
        response_fn = debug_chat_response_echo
    else:
        # The model path is only needed when the real model is loaded, so
        # DEBUG mode no longer requires it to exist.
        if not os.path.exists(model_path):
            raise FileNotFoundError(f'{model_path} not found')
        # ! load the model
        llm = LLM(model=model_path, dtype=dtype, tensor_parallel_size=tensor_parallel)
        print(f'Use system prompt:\n{sys_prompt}')
        response_fn = chat_response_stream_multiturn
    print(F'respond: {response_fn}')

    demo = gr.ChatInterface(
        response_fn,
        chatbot=ChatBot(
            bubble_full_width=False,
            latex_delimiters=[
                {"left": "$", "right": "$", "display": False},
                {"left": "$$", "right": "$$", "display": True},
            ]
        ),
        textbox=gr.Textbox(placeholder='Type message', lines=8, max_lines=128, min_width=200),
        submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
        title=f"{model_title}",
        description=f"{model_desc}",
        # ! decide if can change the system prompt.
        additional_inputs=[
            gr.Number(value=0, label='Temperature (higher -> more random)'),
            gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
            gr.Number(value=0.4, label='Frequency penalty (> 0 encourage new tokens)'),
            gr.Textbox(value=sys_prompt, label='System prompt', lines=8)],
    )
    # Add the citation inside the ChatInterface's Blocks context; a component
    # created outside any Blocks context is not part of the rendered app.
    with demo:
        gr.Markdown(cite_markdown)
    demo.queue()
    demo.launch()
def main():
    """Script entry point: build the Gradio demo via launch() and serve it."""
    return launch()


if __name__ == "__main__":
    main()

# Example shell usage, kept for reference as an inert string literal.
"""
export CUDA_VISIBLE_DEVICES=0
export MODEL_PATH=${dataroot}/hf_train/pretrain_lm/swpn/merlion13s108Hi8kPretFlCW8k.LMFromHf.a.gc.t5k0.vizhthid.mean_std.TrainTask.NLNL.Multi.Vi.FSePlCq13M.FSePlCq13M.m4k.b8.lr1e5.linear.wa0k.ms858k.grac1.se1.8g.v4c.zfsdp/step_4000
export MODEL_PATH=${dataroot}/llama-2-7b-lxxp-faster
export MODEL_PATH=${dataroot}/llama-2-7b-chat-xp
python app.py
"""