# from transformers_stream_generator import init_stream_support
# init_stream_support()

import os
import argparse
import glob
import json
import sys
import time
import types

import anyio
import filelock
import numpy as np
import spaces
import torch
import gradio as gr
from typing import (
    Any,
    AsyncGenerator,
    Callable,
    Dict,
    Iterator,
    List,
    Literal,
    Optional,
    Tuple,
    Union,
    cast,
)
from gradio.routes import Request
from gradio.utils import SyncToAsyncIterator, async_iteration
from gradio.helpers import special_args
from gradio.components import Button
from gradio.events import Dependency, EventListenerMethod
from gradio_client.documentation import document, set_documentation_group
from tqdm.auto import tqdm
from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer

from .base_engine import BaseEngine
from .transformers_engine import TransformersEngine, NewGenerationMixin
from ..configs import (
    STREAM_CHECK_MULTIPLE,
    STREAM_YIELD_MULTIPLE,
)

CODE_PATH = os.environ.get("CODE_PATH", "")
MODEL_PATH = os.environ.get("MODEL_PATH", "")

IMAGE_TOKEN = "[IMAGE]<|image|>[/IMAGE]"
IMAGE_LENGTH = 576
MAX_PACHES = 1

BLOCK_LANGS = str(os.environ.get("BLOCK_LANGS", ""))
BLOCK_LANGS = [x.strip() for x in BLOCK_LANGS.strip().split(";")] if len(BLOCK_LANGS.strip()) > 0 else []
LANG_BLOCK_HISTORY = bool(int(os.environ.get("LANG_BLOCK_HISTORY", "0")))
KEYWORDS = os.environ.get("KEYWORDS", "").strip()
KEYWORDS = KEYWORDS.split(";") if len(KEYWORDS) > 0 else []
KEYWORDS = [x.lower() for x in KEYWORDS]

LANG_BLOCK_MESSAGE = """Unsupported language."""
KEYWORD_BLOCK_MESSAGE = "Invalid request."
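
# Illustrative (hypothetical) environment configuration for this engine. The variable
# names come from the block above; the values below are examples only:
#
#   export MODEL_PATH=/path/to/seallmm/checkpoint
#   export BLOCK_LANGS="ru;ar"            # semicolon-separated language codes to refuse
#   export LANG_BLOCK_HISTORY=1           # also refuse a turn if an earlier turn was blocked
#   export KEYWORDS="badword1;badword2"   # semicolon-separated keywords, matched case-insensitively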


def _detect_lang(text):
    # Language identification, used to refuse languages that may carry safety risk.
    from langdetect import detect as detect_lang
    dlang = None
    try:
        dlang = detect_lang(text)
    except Exception as e:
        # langdetect raises "No features in text." for empty or symbol-only input;
        # treat that as English, and fall back to "zh" for any other failure.
        if "No features in text." in str(e):
            return "en"
        else:
            return "zh"
    return dlang
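
# Note: langdetect returns ISO 639-1-style codes (e.g. "en", "vi", "th", "zh-cn"),
# so entries in BLOCK_LANGS are expected to use the same codes.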


def block_lang(
    message: str,
    history: Optional[List[Tuple[str, str]]] = None,
) -> bool:
    # When LANG_BLOCK_HISTORY is set, a turn is also blocked if a previous
    # response in the history was already blocked.
    if len(BLOCK_LANGS) == 0:
        return False
    if LANG_BLOCK_HISTORY and history is not None and any((LANG_BLOCK_MESSAGE in x[1].strip()) for x in history):
        return True
    else:
        _lang = _detect_lang(message)
        if _lang in BLOCK_LANGS:
            # print(f'Detect blocked {_lang}: {message}')
            return True
        else:
            return False


def safety_check(text, history=None) -> Optional[str]:
    """
    Despite our effort in safety tuning and red teaming, our models may still generate harmful or illegal content.
    This provides an additional security measure to enhance safety and compliance with local regulations.
    """
    if len(KEYWORDS) > 0 and any(x in text.lower() for x in KEYWORDS):
        return KEYWORD_BLOCK_MESSAGE
    if len(BLOCK_LANGS) > 0:
        if block_lang(text, history):
            return LANG_BLOCK_MESSAGE
    return None


def safety_check_conversation_string(text, delimiter=None) -> Optional[str]:
    if len(KEYWORDS) > 0 and any(x in text.lower() for x in KEYWORDS):
        return KEYWORD_BLOCK_MESSAGE
    if len(BLOCK_LANGS) > 0:
        import re
        delimiter = delimiter or (r"</s><\|im_start\|>user\n", r"</s><\|im_start\|>assistant\n", r"<\|im_start\|>system\n")
        turns = re.split(r"|".join(delimiter), text)
        turns = [t for t in turns if t.strip() != '']
        for t in turns:
            if block_lang(t):
                return LANG_BLOCK_MESSAGE
    return None
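
# For reference, safety_check_conversation_string expects a flattened ChatML-style
# conversation; an illustrative input such as
#   "<|im_start|>system\nYou are helpful.</s><|im_start|>user\nXin chào</s><|im_start|>assistant\n"
# is split on the delimiters above so that each turn is language-checked independently.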


def is_check_safety():
    return len(KEYWORDS) > 0 or len(BLOCK_LANGS) > 0


def safety_check_conversation(conversation) -> Optional[str]:
    """
    Despite our effort in safety tuning and red teaming, our models may still generate harmful or illegal content.
    This provides an additional security measure to enhance safety and compliance with local regulations.
    """
    texts = [c['content'] for c in conversation]
    for text in texts:
        if len(KEYWORDS) > 0 and any(x in text.lower() for x in KEYWORDS):
            return KEYWORD_BLOCK_MESSAGE
        if len(BLOCK_LANGS) > 0:
            if block_lang(text):
                return LANG_BLOCK_MESSAGE
    return None
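
# Minimal usage sketch for the helpers above (illustrative only; `user_msg` and
# `chat_history` are hypothetical variables from the calling Gradio app):
#
#   if is_check_safety():
#       err = safety_check(user_msg, history=chat_history)       # chat_history: List[Tuple[str, str]]
#       if err is None:
#           err = safety_check_conversation(
#               [{"role": "user", "content": user_msg}])          # list of {"role", "content"} dicts
#       if err is not None:
#           raise gr.Error(err)                                   # surface the block message in the UI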


class SeaLMMMv0Engine(TransformersEngine):

    @property
    def image_token(self):
        return IMAGE_TOKEN

    @property
    def max_position_embeddings(self) -> int:
        return self._model.config.max_position_embeddings

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def processor(self):
        return self._processor

    def load_model(self):
        from transformers import AutoProcessor
        import sys
        # caution: path[0] is reserved for script path (or '' in REPL)
        # sys.path.append(CODE_PATH)
        # from examples.llm.src.models.sealmm.modeling_sealmm import (
        #     SeaLMMForCausalLM
        # )
        from .modeling_sealmm import SeaLMMForCausalLM
        model_path = MODEL_PATH
        print(f'Loading model from {model_path}')
        # FSDP checkpoints are saved as pytorch_model_fsdp.bin; expose them under the
        # filename that from_pretrained expects.
        if os.path.exists(f"{model_path}/pytorch_model_fsdp.bin") and not os.path.exists(f"{model_path}/pytorch_model.bin"):
            os.symlink("pytorch_model_fsdp.bin", f"{model_path}/pytorch_model.bin")
        self._processor = AutoProcessor.from_pretrained(model_path)
        self._model = SeaLMMForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="cuda").eval()
        # Monkey-patch sample() with the streaming variant so that generate() yields
        # tokens one by one instead of returning the full sequence at once.
        self._model.sample_old = self._model.sample
        self._model.sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
        self._tokenizer = self._processor.tokenizer
        print(self._model)
        print(f"{self.max_position_embeddings=}")

    def get_multimodal_tokens(self, full_prompt, image_paths=None):
        # Text tokens plus a fixed budget of IMAGE_LENGTH * MAX_PACHES tokens per image.
        image_paths = image_paths or []
        num_tokens = len(self.tokenizer.encode(full_prompt))
        for image_path in image_paths:
            num_tokens += IMAGE_LENGTH * MAX_PACHES
        return num_tokens

    def maybe_raise_safety(self, message, gen_index=-1):
        # gen_index < 0: check a full prompt/response once.
        # gen_index >= 0: called from the streaming loop, so only check every
        # STREAM_CHECK_MULTIPLE generated tokens to keep overhead low.
        if is_check_safety():
            if gen_index < 0:
                message_safety = safety_check_conversation_string(message)
                if message_safety is not None:
                    raise gr.Error(message_safety)
            else:
                if STREAM_CHECK_MULTIPLE > 0 and gen_index % STREAM_CHECK_MULTIPLE == 0:
                    message_safety = safety_check_conversation_string(message)
                    if message_safety is not None:
                        raise gr.Error(message_safety)

    def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str, ...]] = None, **kwargs):
        from transformers.generation.utils import GenerationConfig
        from PIL import Image
        image_paths = kwargs.get("image_paths", None)
        image_paths = image_paths or []
        images = [Image.open(x) for x in image_paths] if len(image_paths) > 0 else None
        with torch.no_grad():
            inputs = self.processor(prompt, images, return_tensors='pt')
            # inputs = {k: v.to("cuda", torch.bfloat16) for k, v in inputs.items() if v is not None}
            # model.device
            inputs = {k: v.to(self._model.device) for k, v in inputs.items() if v is not None}
            num_tokens = self.get_multimodal_tokens(prompt, image_paths)

            # non-streaming generation
            # output = self._model.generate(
            #     **inputs,
            #     do_sample=True,
            #     temperature=temperature,
            #     max_new_tokens=max_tokens,
            #     pad_token_id=self.processor.tokenizer.pad_token_id,
            # )
            # # response = self.processor.tokenizer.decode(output[0][-inputs.input_ids.size(-1):], skip_special_tokens=True)
            # full_output_text = self.processor.decode(output[0], skip_special_tokens=True)
            # response = full_output_text.split("<|im_start|>assistant\n")[-1]
            # num_tokens = self.get_multimodal_tokens(prompt + response, image_paths)
            # print(prompt)
            # print(response)
            # print(num_tokens)
            # yield response, num_tokens

            # if i % 4 == 0 and i > 1:
            #     message_safety = safety_check(response)
            #     if message_safety is not None:
            #         history = undo_history(history)
            #         yield history, "", None
            #         raise gr.Error(message_safety)

            # Check the prompt before generation starts.
            self.maybe_raise_safety(prompt)

            # ! streaming: generate() yields tokens one by one thanks to the
            # sample_stream monkey-patch installed in load_model().
            generator = self._model.generate(
                **inputs,
                do_sample=True,
                temperature=temperature,
                max_new_tokens=max_tokens,
                pad_token_id=self.processor.tokenizer.pad_token_id,
            )
            out_tokens = []
            response = None
            for index, token in enumerate(generator):
                out_tokens.append(token.item())
                response = self.processor.tokenizer.decode(out_tokens)
                # Periodically re-check the partial response while streaming.
                self.maybe_raise_safety(response, gen_index=index)
                yield response, num_tokens

            del generator

            if response is not None:
                self.maybe_raise_safety(prompt)
                full_text = prompt + response
                num_tokens = self.get_multimodal_tokens(full_text, image_paths)
                yield response, num_tokens
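
# Illustrative driver loop (not part of the Space itself). It assumes the engine can be
# constructed with no arguments (the base TransformersEngine constructor is not shown here),
# that MODEL_PATH points at a valid checkpoint, and that a CUDA device is available; the
# prompt template below is only a guess based on the delimiters used above.
#
#   engine = SeaLMMMv0Engine()
#   engine.load_model()
#   prompt = "<|im_start|>user\nDescribe this image.</s><|im_start|>assistant\n"
#   for partial, n_tokens in engine.generate_yield_string(
#           prompt, temperature=0.7, max_tokens=256, image_paths=["example.jpg"]):
#       print(f"\r{len(partial)} chars, ~{n_tokens} tokens", end="")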