import spaces
import os
import numpy as np
import argparse
import torch
import sys
import gradio as gr
from typing import (
    Any, AsyncGenerator, Callable, Dict, Iterator, List, Literal, Optional, Tuple, Union, cast,
)
import filelock
import glob
import json
import time
import types
import random
import warnings
import inspect
import copy
import anyio
from gradio.routes import Request
from gradio.utils import SyncToAsyncIterator, async_iteration
from gradio.helpers import special_args
from gradio.components import Button
from gradio.events import Dependency, EventListenerMethod
from gradio_client.documentation import document, set_documentation_group
from tqdm.auto import tqdm
from huggingface_hub import snapshot_download

from .base_engine import BaseEngine

# ! Remember to use static cache
from transformers import (
    GenerationConfig,
    GenerationMixin,
    LogitsProcessorList,
    StoppingCriteriaList,
    DisjunctiveConstraint,
    BeamSearchScorer,
    PhrasalConstraint,
    ConstrainedBeamSearchScorer,
    PreTrainedModel,
)
from transformers.generation.utils import GenerateOutput, SampleOutput, logger

from torch import nn
import torch.distributed as dist

from ..configs import (
    MODEL_PATH,
    DTYPE,
    DEVICE,
)


def setup_seed(seed):
    if seed == -1:
        return
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


class NewGenerationMixin(GenerationMixin):
    """
    Allow generator sampling.
    """
    # ! Copied from transformers.generation.utils -> GenerationMixin.
    # The sample function is replaced by sample_stream, which yields tokens as they are generated.
    def sample_stream(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_warper: Optional[LogitsProcessorList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        output_logits: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        streamer: Optional["BaseStreamer"] = None,
        **model_kwargs,
    ):
r""" | |
Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and | |
can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. | |
<Tip warning={true}> | |
In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead. | |
For an overview of generation strategies and code examples, check the [following | |
guide](../generation_strategies). | |
</Tip> | |
Parameters: | |
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): | |
The sequence used as a prompt for the generation. | |
logits_processor (`LogitsProcessorList`, *optional*): | |
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] | |
used to modify the prediction scores of the language modeling head applied at each generation step. | |
stopping_criteria (`StoppingCriteriaList`, *optional*): | |
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] | |
used to tell if the generation loop should stop. | |
logits_warper (`LogitsProcessorList`, *optional*): | |
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used | |
to warp the prediction score distribution of the language modeling head applied before multinomial | |
sampling at each generation step. | |
max_length (`int`, *optional*, defaults to 20): | |
**DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated | |
tokens. The maximum length of the sequence to be generated. | |
pad_token_id (`int`, *optional*): | |
The id of the *padding* token. | |
eos_token_id (`Union[int, List[int]]`, *optional*): | |
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. | |
output_attentions (`bool`, *optional*, defaults to `False`): | |
Whether or not to return the attentions tensors of all attention layers. See `attentions` under | |
returned tensors for more details. | |
output_hidden_states (`bool`, *optional*, defaults to `False`): | |
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors | |
for more details. | |
output_scores (`bool`, *optional*, defaults to `False`): | |
Whether or not to return the prediction scores. See `scores` under returned tensors for more details. | |
output_logits (`bool`, *optional*, defaults to `False`): | |
Whether or not to return the raw prediction logit scores. See `logits` under returned tensors for | |
more details. | |
return_dict_in_generate (`bool`, *optional*, defaults to `False`): | |
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. | |
synced_gpus (`bool`, *optional*, defaults to `False`): | |
Whether to continue running the while loop until max_length (needed for ZeRO stage 3) | |
streamer (`BaseStreamer`, *optional*): | |
Streamer object that will be used to stream the generated sequences. Generated tokens are passed | |
through `streamer.put(token_ids)` and the streamer is responsible for any further processing. | |
model_kwargs: | |
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is | |
an encoder-decoder model the kwargs should include `encoder_outputs`. | |
Return: | |
[`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`: | |
A `torch.LongTensor` containing the generated tokens (default behaviour) or a | |
[`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and | |
`return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if | |
`model.config.is_encoder_decoder=True`. | |
Examples: | |
```python | |
>>> from transformers import ( | |
... AutoTokenizer, | |
... AutoModelForCausalLM, | |
... LogitsProcessorList, | |
... MinLengthLogitsProcessor, | |
... TopKLogitsWarper, | |
... TemperatureLogitsWarper, | |
... StoppingCriteriaList, | |
... MaxLengthCriteria, | |
... ) | |
>>> import torch | |
>>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") | |
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") | |
>>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token | |
>>> model.config.pad_token_id = model.config.eos_token_id | |
>>> model.generation_config.pad_token_id = model.config.eos_token_id | |
>>> input_prompt = "Today is a beautiful day, and" | |
>>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids | |
>>> # instantiate logits processors | |
>>> logits_processor = LogitsProcessorList( | |
... [ | |
... MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id), | |
... ] | |
... ) | |
>>> # instantiate logits processors | |
>>> logits_warper = LogitsProcessorList( | |
... [ | |
... TopKLogitsWarper(50), | |
... TemperatureLogitsWarper(0.7), | |
... ] | |
... ) | |
>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) | |
>>> torch.manual_seed(0) # doctest: +IGNORE_RESULT | |
>>> outputs = model.sample( | |
... input_ids, | |
... logits_processor=logits_processor, | |
... logits_warper=logits_warper, | |
... stopping_criteria=stopping_criteria, | |
... ) | |
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True) | |
['Today is a beautiful day, and we must do everything possible to make it a day of celebration.'] | |
```""" | |
        # init values
        print("Streaming tokens...")
        from transformers.generation.utils import (
            validate_stopping_criteria, GenerateEncoderDecoderOutput, GenerateDecoderOnlyOutput
        )
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
        output_logits = output_logits if output_logits is not None else self.generation_config.output_logits
        output_attentions = (
            output_attentions if output_attentions is not None else self.generation_config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate
            if return_dict_in_generate is not None
            else self.generation_config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        raw_logits = () if (return_dict_in_generate and output_logits) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # keep track of which sequences are already finished
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

        this_peer_finished = False  # used by synced_gpus only
        # auto-regressive generation
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_logits:
                    raw_logits += (next_token_logits,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)
                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # sample
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # stream the freshly sampled token back to the caller
            yield next_tokens.cpu()

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())

            next_model_inputs = {}
            if "cache_position" in model_inputs:
                next_model_inputs['cache_position'] = model_inputs['cache_position']
            try:
                model_kwargs = self._update_model_kwargs_for_generation(
                    outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder,
                    # model_inputs=model_inputs
                    model_inputs=next_model_inputs,
                )
            except Exception as e:
                # ! Some transformers versions do not accept `model_inputs` in
                # `_update_model_kwargs_for_generation`, so retry without it.
                model_kwargs = self._update_model_kwargs_for_generation(
                    outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder,
                )

            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id_tensor is not None:
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
                )

                # stop when each sentence is finished
                if unfinished_sequences.max() == 0:
                    this_peer_finished = True

            # stop if we exceed the maximum length
            if stopping_criteria(input_ids, scores):
                this_peer_finished = True

            if this_peer_finished and not synced_gpus:
                break

        if streamer is not None:
            streamer.end()
        # if return_dict_in_generate:
        #     if self.config.is_encoder_decoder:
        #         return GenerateEncoderDecoderOutput(
        #             sequences=input_ids,
        #             scores=scores,
        #             logits=raw_logits,
        #             encoder_attentions=encoder_attentions,
        #             encoder_hidden_states=encoder_hidden_states,
        #             decoder_attentions=decoder_attentions,
        #             cross_attentions=cross_attentions,
        #             decoder_hidden_states=decoder_hidden_states,
        #             past_key_values=model_kwargs.get("past_key_values"),
        #         )
        #     else:
        #         return GenerateDecoderOnlyOutput(
        #             sequences=input_ids,
        #             scores=scores,
        #             logits=raw_logits,
        #             attentions=decoder_attentions,
        #             hidden_states=decoder_hidden_states,
        #             past_key_values=model_kwargs.get("past_key_values"),
        #         )
        # else:
        #     return input_ids
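

# Illustrative sketch (not part of the original engine): once `sample_stream` is patched
# over a decoder-only HF model's sampling method, `model.generate(..., do_sample=True)`
# returns a Python generator that yields one tensor of sampled token ids per step.
# The helper below is hypothetical, shows only the intended call pattern, and is not
# used anywhere in this module.
def _demo_stream_tokens(model, tokenizer, prompt: str, max_new_tokens: int = 16):
    model._sample = types.MethodType(NewGenerationMixin.sample_stream, model)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        for step_tokens in model.generate(
            **inputs,
            do_sample=True,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.pad_token_id,
        ):
            # each yielded item is a 1-D tensor with one sampled token id per batch row
            yield tokenizer.decode(step_tokens.tolist())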


from ..configs import (
    STREAM_CHECK_MULTIPLE,
    STREAM_YIELD_MULTIPLE,
)

BLOCK_LANGS = str(os.environ.get("BLOCK_LANGS", ""))
BLOCK_LANGS = [x.strip() for x in BLOCK_LANGS.strip().split(";")] if len(BLOCK_LANGS.strip()) > 0 else []
LANG_BLOCK_HISTORY = bool(int(os.environ.get("LANG_BLOCK_HISTORY", "0")))
KEYWORDS = os.environ.get("KEYWORDS", "").strip()
KEYWORDS = KEYWORDS.split(";") if len(KEYWORDS) > 0 else []
KEYWORDS = [x.lower() for x in KEYWORDS]

LANG_BLOCK_MESSAGE = """Unsupported language."""
KEYWORD_BLOCK_MESSAGE = "Invalid request."
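
# Example configuration (illustrative values, not shipped defaults):
#   BLOCK_LANGS="zh;ru"    -> block messages whose detected language is zh or ru
#   LANG_BLOCK_HISTORY=1   -> keep blocking a conversation once any prior turn was blocked
#   KEYWORDS="foo;bar"     -> case-insensitive substring filter on user/model text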


def _detect_lang(text):
    # Detect the message language so that unsupported (potentially unsafe) languages can be blocked.
    from langdetect import detect as detect_lang
    dlang = None
    try:
        dlang = detect_lang(text)
    except Exception as e:
        # langdetect raises on empty / feature-less input; treat that as English,
        # and fall back to "zh" for any other detection failure.
        if "No features in text." in str(e):
            return "en"
        else:
            return "zh"
    return dlang


def block_lang(
    message: str,
    history: Optional[List[Tuple[str, str]]] = None,
) -> bool:
    if len(BLOCK_LANGS) == 0:
        return False

    # History-based blocking: once a turn has been blocked, keep blocking the whole conversation.
    if LANG_BLOCK_HISTORY and history is not None and any((LANG_BLOCK_MESSAGE in x[1].strip()) for x in history):
        return True
    else:
        _lang = _detect_lang(message)
        if _lang in BLOCK_LANGS:
            # print(f'Detect blocked {_lang}: {message}')
            return True
        else:
            return False


def safety_check(text, history=None) -> Optional[str]:
    """
    Despite our efforts in safety tuning and red teaming, our models may still generate harmful or illegal content.
    This provides an additional security measure to enhance safety and compliance with local regulations.
    """
    if len(KEYWORDS) > 0 and any(x in text.lower() for x in KEYWORDS):
        return KEYWORD_BLOCK_MESSAGE
    if len(BLOCK_LANGS) > 0:
        if block_lang(text, history):
            return LANG_BLOCK_MESSAGE

    return None
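
# Usage sketch (hypothetical values): with KEYWORDS=["forbidden"] configured,
# safety_check("a forbidden topic") returns KEYWORD_BLOCK_MESSAGE, while
# safety_check("hello") returns None when no filter matches.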


def safety_check_conversation_string(text, delimiter=None) -> Optional[str]:
    if len(KEYWORDS) > 0 and any(x in text.lower() for x in KEYWORDS):
        return KEYWORD_BLOCK_MESSAGE
    if len(BLOCK_LANGS) > 0:
        import re
        delimiter = delimiter or (r"</s><\|im_start\|>user\n", r"</s><\|im_start\|>assistant\n", r"<\|im_start\|>system\n")
        turns = re.split(r"|".join(delimiter), text)
        turns = [t for t in turns if t.strip() != '']
        for t in turns:
            if block_lang(t):
                return LANG_BLOCK_MESSAGE
    return None


def is_check_safety():
    return len(KEYWORDS) > 0 or len(BLOCK_LANGS) > 0


def safety_check_conversation(conversation) -> Optional[str]:
    """
    Despite our efforts in safety tuning and red teaming, our models may still generate harmful or illegal content.
    This provides an additional security measure to enhance safety and compliance with local regulations.
    """
    texts = [c['content'] for c in conversation]
    for text in texts:
        if len(KEYWORDS) > 0 and any(x in text.lower() for x in KEYWORDS):
            return KEYWORD_BLOCK_MESSAGE
        if len(BLOCK_LANGS) > 0:
            if block_lang(text):
                return LANG_BLOCK_MESSAGE
    return None
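
# The conversation is expected to be a list of chat-style dicts with a "content" field,
# e.g. [{"role": "user", "content": "..."}] (the "role" key is assumed here; only
# "content" is actually read by safety_check_conversation).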


class TransformersEngine(BaseEngine):

    @property
    def max_position_embeddings(self) -> int:
        return self._model.config.max_position_embeddings

    @property
    def tokenizer(self):
        return self._tokenizer

    def load_model(self):
        from transformers import AutoTokenizer, AutoModelForCausalLM
        import sys
        # caution: path[0] is reserved for script path (or '' in REPL)
        # sys.path.append(CODE_PATH)

        self.model_path = model_path = MODEL_PATH
        self.torch_dtype = torch.bfloat16 if DTYPE == 'bfloat16' else torch.float16
        self.device_map = DEVICE
        print(f'Loading model from {model_path} on {self.device_map} with {self.torch_dtype}')

        self._tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        assert self._tokenizer.chat_template is not None and self._tokenizer.chat_template != "", f"{self._tokenizer.chat_template=} not found!"
        self._model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=self.torch_dtype, device_map=self.device_map, trust_remote_code=True).eval()
        # keep the original sampler around, then patch in the streaming sampler
        self._model.sample_old = self._model.sample
        self._model._sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
        print(self._model)
        print(f"{self.max_position_embeddings=}")

    def maybe_raise_safety(self, message, gen_index=-1):
        if is_check_safety():
            if gen_index < 0:
                message_safety = safety_check_conversation_string(message)
                if message_safety is not None:
                    raise gr.Error(message_safety)
            else:
                if STREAM_CHECK_MULTIPLE > 0 and gen_index % STREAM_CHECK_MULTIPLE == 0:
                    message_safety = safety_check_conversation_string(message)
                    if message_safety is not None:
                        raise gr.Error(message_safety)

    def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
        # ! Generation MUST run inside torch.no_grad(); otherwise activations are kept for autograd and the GPU runs out of memory.
        import sys
        # self._model._sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
        self._model.sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)

        self.maybe_raise_safety(prompt)

        with torch.no_grad():
            inputs = self.tokenizer(prompt, return_tensors='pt')
            num_tokens = inputs.input_ids.size(1)

            inputs = inputs.to(self._model.device)

            generator = self._model.generate(
                **inputs,
                do_sample=True,
                temperature=temperature,
                max_new_tokens=max_tokens,
                pad_token_id=self.tokenizer.pad_token_id,
            )

            out_tokens = []
            response = None
            for index, token in enumerate(generator):
                out_tokens.extend(token.tolist())
                response = self.tokenizer.decode(out_tokens)
                if "<|im_start|>assistant\n" in response:
                    response = response.split("<|im_start|>assistant\n")[-1]
                num_tokens += 1
                # print(f"{response}", end='\r')
                # sys.stdout.flush()
                self.maybe_raise_safety(response, gen_index=index)
                yield response, num_tokens

            del generator

            if response is not None:
                if "<|im_start|>assistant\n" in response:
                    response = response.split("<|im_start|>assistant\n")[-1]

                self.maybe_raise_safety(response)

                full_text = prompt + response
                num_tokens = len(self.tokenizer.encode(full_text))
                yield response, num_tokens