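"""MLX-backed generation engine.

Provides string and streaming generation helpers built on
``mlx_lm.utils.generate_step``, plus an ``MlxEngine`` wrapper around
``BaseEngine``.
"""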
from typing import Callable, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn
from transformers import PreTrainedTokenizer

from mlx_lm import load
from mlx_lm.utils import generate_step
from .base_engine import BaseEngine
from ..configs import (
MODEL_PATH,
)


def generate_string(
    model: nn.Module,
    tokenizer: PreTrainedTokenizer,
    prompt: str,
    temp: float = 0.0,
    max_tokens: int = 100,
    verbose: bool = False,
    formatter: Optional[Callable] = None,
    repetition_penalty: Optional[float] = None,
    repetition_context_size: Optional[int] = None,
    stop_strings: Optional[Tuple[str, ...]] = None,
) -> str:
    prompt_tokens = mx.array(tokenizer.encode(prompt))
    # Normalize stop_strings to a tuple so that str.endswith accepts it.
    stop_strings = stop_strings if stop_strings is None or isinstance(stop_strings, tuple) else tuple(stop_strings)
    assert stop_strings is None or isinstance(stop_strings, tuple), f"invalid stop_strings: {stop_strings!r}"
    tokens = []
    # Decoding an incomplete UTF-8 sequence yields U+FFFD; strip it from output.
    REPLACEMENT_CHAR = "\ufffd"
    # zip with range caps the number of generated tokens at max_tokens.
    for (token, _prob), _ in zip(
        generate_step(
            prompt_tokens,
            model,
            temp=temp,
            repetition_penalty=repetition_penalty,
            repetition_context_size=repetition_context_size,
        ),
        range(max_tokens),
    ):
        # Stop as soon as the model emits its end-of-sequence token.
        if token == tokenizer.eos_token_id:
            break
        tokens.append(token.item())
        if stop_strings is not None:
            # Re-decode the full sequence each step and stop once it ends
            # with any of the stop strings (str.endswith accepts a tuple).
            token_string = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
            if token_string.strip().endswith(stop_strings):
                break
token_string = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
return token_string
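
# Illustrative usage of generate_string (a sketch, not part of the original
# module; assumes MODEL_PATH points to an MLX-converted checkpoint):
#
#     model, tokenizer = load(MODEL_PATH)
#     text = generate_string(
#         model, tokenizer, "Q: What is MLX?\nA:",
#         temp=0.0, max_tokens=64, stop_strings=("\n\n",),
#     )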


def generate_yield_string(
    model: nn.Module,
    tokenizer: PreTrainedTokenizer,
    prompt: str,
    temp: float = 0.0,
    max_tokens: int = 100,
    verbose: bool = False,
    formatter: Optional[Callable] = None,
    repetition_penalty: Optional[float] = None,
    repetition_context_size: Optional[int] = None,
    stop_strings: Optional[Tuple[str, ...]] = None,
):
"""
Generate text from the model.
Args:
model (nn.Module): The language model.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (str): The string prompt.
temp (float): The temperature for sampling (default 0).
max_tokens (int): The maximum number of tokens (default 100).
verbose (bool): If ``True``, print tokens and timing information
(default ``False``).
formatter (Optional[Callable]): A function which takes a token and a
probability and displays it.
repetition_penalty (float, optional): The penalty factor for repeating tokens.
repetition_context_size (int, optional): The number of tokens to consider for repetition penalty.
"""
    if verbose:
        print("=" * 10)
        print("Prompt:", prompt)
    # Normalize stop_strings to a tuple so that str.endswith accepts it.
    stop_strings = stop_strings if stop_strings is None or isinstance(stop_strings, tuple) else tuple(stop_strings)
    assert stop_strings is None or isinstance(stop_strings, tuple), f"invalid stop_strings: {stop_strings!r}"
    prompt_tokens = mx.array(tokenizer.encode(prompt))
    tokens = []
    REPLACEMENT_CHAR = "\ufffd"
    # zip with range caps the number of generated tokens at max_tokens.
    for (token, _prob), _ in zip(
        generate_step(
            prompt_tokens,
            model,
            temp=temp,
            repetition_penalty=repetition_penalty,
            repetition_context_size=repetition_context_size,
        ),
        range(max_tokens),
    ):
        if token == tokenizer.eos_token_id:
            break
        tokens.append(token.item())
        # Yield the cumulative decoded string, then stop if it ends with a
        # stop string (the chunk containing the stop string is still yielded).
        token_string = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
        yield token_string
        if stop_strings is not None and token_string.strip().endswith(stop_strings):
            break
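
# Illustrative streaming usage (a sketch, not part of the original module):
#
#     model, tokenizer = load(MODEL_PATH)
#     for partial in generate_yield_string(
#         model, tokenizer, "Once upon a time", temp=0.7, max_tokens=128,
#     ):
#         print(partial)  # each yield is the full decoded text so far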


class MlxEngine(BaseEngine):
    """Engine that loads and serves an MLX model via ``mlx_lm``."""

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self._model = None
        self._tokenizer = None

    @property
    def tokenizer(self) -> PreTrainedTokenizer:
        return self._tokenizer

    def load_model(self):
        model_path = MODEL_PATH
        self._model, self._tokenizer = load(model_path)
        self.model_path = model_path
        print(f'Loaded MLX model from {model_path}')

    def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str, ...]] = None, **kwargs):
        # Start from the prompt token count; the exact total is recomputed
        # once generation finishes.
        num_tokens = len(self.tokenizer.encode(prompt))
        response = None
        for response in generate_yield_string(
            self._model, self._tokenizer,
            prompt, temp=temperature, max_tokens=max_tokens,
            repetition_penalty=kwargs.get("repetition_penalty", None),
            stop_strings=stop_strings,
        ):
            yield response, num_tokens
        if response is not None:
            # Yield the final response once more with the exact token count
            # of prompt + completion.
            full_text = prompt + response
            num_tokens = len(self.tokenizer.encode(full_text))
            yield response, num_tokens

    def batch_generate(self, prompts, temperature, max_tokens, stop_strings: Optional[Tuple[str, ...]] = None, **kwargs):
        """
        MLX does not support batched generation here, so prompts are decoded
        sequentially, one at a time.
        """
responses = [
generate_string(
self._model, self._tokenizer,
s, temp=temperature, max_tokens=max_tokens,
repetition_penalty=kwargs.get("repetition_penalty", None),
stop_strings=stop_strings,
)
for s in prompts
]
return responses
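
# Illustrative engine usage (a sketch; assumes this module is imported from
# its package so the relative imports resolve):
#
#     engine = MlxEngine()
#     engine.load_model()
#     for partial, n_tokens in engine.generate_yield_string(
#         "Hello", temperature=0.0, max_tokens=32,
#     ):
#         ...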