Spaces: Running on Zero
import os | |
import numpy as np | |
import mlx.core as mx | |
import mlx.nn as nn | |
from huggingface_hub import snapshot_download | |
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer | |
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union | |
import time | |
from mlx_lm import load, generate | |
from mlx_lm.utils import generate_step | |
from .base_engine import BaseEngine | |
from ..configs import ( | |
MODEL_PATH, | |
) | |
def generate_string(
    model: nn.Module,
    tokenizer: PreTrainedTokenizer,
    prompt: str,
    temp: float = 0.0,
    max_tokens: int = 100,
    verbose: bool = False,
    formatter: Optional[Callable] = None,
    repetition_penalty: Optional[float] = None,
    repetition_context_size: Optional[int] = None,
    stop_strings: Optional[Union[str, Tuple[str, ...], List[str]]] = None,
) -> str:
    """Generate a completion for ``prompt`` and return it as a single string.

    Tokens come from ``mlx_lm``'s ``generate_step``. Generation ends at the
    first of:
      * the tokenizer's EOS token is sampled (it is not included),
      * ``max_tokens`` tokens have been produced, or
      * the decoded text ends with one of ``stop_strings``.

    Args:
        model: The MLX language model.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        prompt: The string prompt.
        temp: Sampling temperature (default 0.0).
        max_tokens: Maximum number of tokens to generate (default 100).
        verbose: Accepted for signature parity with
            ``generate_yield_string``; unused here.
        formatter: Accepted for signature parity; unused here.
        repetition_penalty: Penalty factor for repeated tokens. It is
            forwarded to ``generate_step``.
        repetition_context_size: Number of tokens considered for the
            repetition penalty. It is forwarded to ``generate_step``.
        stop_strings: A single stop string or a tuple/list of them.
            Generation stops once the decoded text ends with any of them. The
            matched stop string stays in the returned text. Empty strings are
            ignored.

    Returns:
        The decoded completion, excluding the prompt. U+FFFD replacement
        characters are removed.
    """
    replacement_char = "\ufffd"

    # Normalise stop_strings to a tuple so str.endswith can test all of them
    # in one call. A bare str must be wrapped: tuple("abc") would split it
    # into single characters and stop on any of them.
    if isinstance(stop_strings, str):
        stop_strings = (stop_strings,)
    if stop_strings is not None:
        # An empty stop string matches every text and would end generation
        # after the very first token, so drop it.
        stop_strings = tuple(s for s in stop_strings if s) or None

    prompt_tokens = mx.array(tokenizer.encode(prompt))
    tokens: List[int] = []
    token_stream = generate_step(
        prompt_tokens,
        model,
        temp=temp,
        repetition_penalty=repetition_penalty,
        repetition_context_size=repetition_context_size,
    )
    # range() goes first so that zip stops without asking the generator for
    # one more token (an extra model step) once max_tokens is reached.
    for _, (token, _prob) in zip(range(max_tokens), token_stream):
        token_id = token.item()
        if token_id == tokenizer.eos_token_id:
            break
        tokens.append(token_id)
        if stop_strings is not None:
            text = tokenizer.decode(tokens).replace(replacement_char, "")
            # The right-stripped check tolerates trailing whitespace after a
            # stop string. The raw check lets stop strings that themselves
            # end in whitespace (e.g. "\n\n") match.
            if text.endswith(stop_strings) or text.rstrip().endswith(stop_strings):
                break
    return tokenizer.decode(tokens).replace(replacement_char, "")
def generate_yield_string(
    model: nn.Module,
    tokenizer: PreTrainedTokenizer,
    prompt: str,
    temp: float = 0.0,
    max_tokens: int = 100,
    verbose: bool = False,
    formatter: Optional[Callable] = None,
    repetition_penalty: Optional[float] = None,
    repetition_context_size: Optional[int] = None,
    stop_strings: Optional[Union[str, Tuple[str, ...], List[str]]] = None,
):
    """Stream text generated from the model for ``prompt``.

    After every generated token this yields the *entire* completion decoded
    so far, not just the newly added piece. Generation ends at the first of:
      * the tokenizer's EOS token is sampled (it is not yielded),
      * ``max_tokens`` tokens have been produced, or
      * the decoded text ends with one of ``stop_strings``. The text that
        completes the stop string is still yielded before stopping.

    Args:
        model (nn.Module): The language model.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompt (str): The string prompt.
        temp (float): The temperature for sampling (default 0).
        max_tokens (int): The maximum number of tokens (default 100).
        verbose (bool): If ``True``, print a separator line and the prompt
            when generation starts. Per-token output and timing statistics
            are not implemented.
        formatter (Optional[Callable]): Accepted for signature compatibility;
            currently unused.
        repetition_penalty (float, optional): The penalty factor for repeating
            tokens. It is forwarded to ``generate_step``.
        repetition_context_size (int, optional): The number of tokens to
            consider for the repetition penalty. It is forwarded to
            ``generate_step``.
        stop_strings (str or sequence of str, optional): Stop generation once
            the decoded text ends with any of these. Empty strings are
            ignored.

    Yields:
        str: The cumulative decoded completion, excluding the prompt. U+FFFD
        replacement characters are removed.
    """
    if verbose:
        print("=" * 10)
        print("Prompt:", prompt)

    replacement_char = "\ufffd"

    # Wrap a bare str: tuple("abc") would split it into single characters.
    if isinstance(stop_strings, str):
        stop_strings = (stop_strings,)
    if stop_strings is not None:
        # An empty stop string would match immediately, so drop it.
        stop_strings = tuple(s for s in stop_strings if s) or None

    prompt_tokens = mx.array(tokenizer.encode(prompt))
    tokens: List[int] = []
    token_stream = generate_step(
        prompt_tokens,
        model,
        temp=temp,
        repetition_penalty=repetition_penalty,
        repetition_context_size=repetition_context_size,
    )
    # range() first: once max_tokens is reached, zip stops without pulling
    # (and computing) one more token from the generator.
    for _, (token, _prob) in zip(range(max_tokens), token_stream):
        token_id = token.item()
        if token_id == tokenizer.eos_token_id:
            break
        tokens.append(token_id)
        text = tokenizer.decode(tokens).replace(replacement_char, "")
        yield text
        # Check raw and right-stripped text so that stop strings ending in
        # whitespace (e.g. "\n\n") can match too.
        if stop_strings is not None and (
            text.endswith(stop_strings) or text.rstrip().endswith(stop_strings)
        ):
            break
class MlxEngine(BaseEngine):
    """Generation engine backed by an MLX model loaded with ``mlx_lm.load``.

    Call :meth:`load_model` before generating; until then the model and
    tokenizer are ``None``.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self._model = None
        self._tokenizer = None

    @property
    def tokenizer(self) -> PreTrainedTokenizer:
        """The tokenizer set by :meth:`load_model` (``None`` before loading).

        This is a property because callers access it as an attribute
        (``engine.tokenizer.encode(...)``). As a plain method, that access
        raised ``AttributeError`` on the bound method.
        """
        return self._tokenizer

    def load_model(self):
        """Load the model and tokenizer from ``MODEL_PATH`` via ``mlx_lm.load``."""
        model_path = MODEL_PATH
        self._model, self._tokenizer = load(model_path)
        self.model_path = model_path
        print(f'Load MLX model from {model_path}')

    def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
        """Stream a completion for ``prompt`` as ``(text, num_tokens)`` pairs.

        ``text`` is the cumulative completion so far. While streaming,
        ``num_tokens`` is the token count of the prompt alone. If anything
        was generated, one extra final pair repeats the last ``text`` with
        ``num_tokens`` set to the token count of ``prompt + text``. If nothing
        was generated, nothing is yielded.

        Recognised ``kwargs``: ``repetition_penalty`` and
        ``repetition_context_size``, both forwarded to generation.
        """
        num_tokens = len(self._tokenizer.encode(prompt))
        response = None
        for response in generate_yield_string(
            self._model, self._tokenizer,
            prompt, temp=temperature, max_tokens=max_tokens,
            repetition_penalty=kwargs.get("repetition_penalty", None),
            repetition_context_size=kwargs.get("repetition_context_size", None),
            stop_strings=stop_strings,
        ):
            yield response, num_tokens
        if response is not None:
            # Recount over prompt + completion for the final usage figure.
            full_text = prompt + response
            num_tokens = len(self._tokenizer.encode(full_text))
            yield response, num_tokens

    def batch_generate(self, prompts, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
        """Return one completion string per prompt, in order.

        This engine has no batched decoding for MLX, so prompts are generated
        one at a time with ``generate_string``. Recognised ``kwargs``:
        ``repetition_penalty`` and ``repetition_context_size``.
        """
        repetition_penalty = kwargs.get("repetition_penalty", None)
        repetition_context_size = kwargs.get("repetition_context_size", None)
        return [
            generate_string(
                self._model, self._tokenizer,
                s, temp=temperature, max_tokens=max_tokens,
                repetition_penalty=repetition_penalty,
                repetition_context_size=repetition_context_size,
                stop_strings=stop_strings,
            )
            for s in prompts
        ]