from datetime import datetime
from typing import Dict, Union, Optional

import deepspeed
import torch
import PIL.Image
from torch.nn.functional import softmax, gumbel_softmax
from torch import Tensor
from transformers import PretrainedConfig, PreTrainedModel, AutoImageProcessor, AutoConfig, AutoModel
from transformers import CLIPVisionModel, CLIPImageProcessor
from transformers.integrations import is_deepspeed_zero3_enabled

from .utils import BEGIN_LINE, END_LINE, rank0_print

MODEL_TYPE = "clip_visual_tokenizer"


class BaseVisualTokenizerConfig(PretrainedConfig):
    def __init__(self,
                 vocab_size=16384,
                 tokenize_function="softmax",
                 tau=1.0,
                 depths=None,
                 use_indicators=False,
                 drop_cls_token=False,
                 backbone_config: Optional[Union[PretrainedConfig, dict]] = None,
                 hidden_stride: int = 1,
                 **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.tokenize_function = tokenize_function
        self.tau = tau
        # `depths` may be passed as a '|'-separated string of integers.
        if isinstance(depths, str):
            depths = [int(x) for x in depths.split('|')]
        self.depths = depths
        self.backbone_kwargs = {}
        self.use_indicators = use_indicators
        self.drop_cls_token = drop_cls_token
        if backbone_config is not None:
            assert isinstance(backbone_config, (PretrainedConfig, dict)), \
                f"expect `backbone_config` to be instance of PretrainedConfig or dict, but got {type(backbone_config)} type"
            if not isinstance(backbone_config, PretrainedConfig):
                model_type = backbone_config.pop('model_type')
                backbone_config = AutoConfig.for_model(model_type, **backbone_config)
        self.backbone_config = backbone_config
        self.hidden_stride = hidden_stride
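# A minimal construction sketch (illustrative, not part of the original module): when
# `backbone_config` is a dict, it must carry a 'model_type' key that AutoConfig.for_model
# can resolve, e.g. the CLIP vision config type:
#
#   config = BaseVisualTokenizerConfig(
#       vocab_size=16384,
#       tokenize_function="softmax",
#       backbone_config={"model_type": "clip_vision_model", "num_hidden_layers": 24},
#   )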


class BaseVisualTokenizer(PreTrainedModel):
    base_model_prefix = "backbone"
    main_input_name = None
    _image_processor_class = None
    _image_processor_kwargs = {}
    _backbone_class = None
    _backbone_name_or_path = None

    def __init__(self, config: BaseVisualTokenizerConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        if kwargs.get('train_from_scratch'):
            # Training from scratch: load both the image processor and the pretrained
            # backbone weights from `_backbone_name_or_path`.
            self.image_processor = self._image_processor_class.from_pretrained(self._backbone_name_or_path,
                                                                                **self._image_processor_kwargs)
            self.backbone = self._backbone_class.from_pretrained(self._backbone_name_or_path,
                                                                 **self.config.backbone_kwargs)
            self.config.backbone_config = self.backbone.config
        else:
            # Otherwise build the backbone from the stored config; weights are expected
            # to come from a checkpoint.
            self.image_processor = AutoImageProcessor.from_pretrained(kwargs['image_processor_name_or_path'])
            self.backbone = AutoModel.from_config(self.config.backbone_config)
        self.head = None

        assert all((self.image_processor.do_resize,
                    not getattr(self.image_processor, 'do_center_crop', False),
                    self.image_processor.do_rescale,
                    self.image_processor.do_normalize
                    )), f"image_processor `{self.image_processor}` is not supported currently"

    def get_backbone(self):
        return self.backbone

    def get_monitor_tensors(self):
        raise NotImplementedError

    def get_image_processor(self):
        return self.image_processor

    def get_head(self):
        return self.head

    def get_image_size(self):
        raise NotImplementedError

    def preprocess_image(self, image: PIL.Image.Image, convert_to_rgb=True):
        if convert_to_rgb and image.mode != 'RGB':
            image = image.convert('RGB')

        # The tokenizer expects a square input of `side x side` pixels.
        sides = self.get_image_size()
        if sides[0] != sides[1]:
            raise ValueError('get_image_size() returns non-square size')
        side = sides[0]

        # Resize so that the longer edge equals `side` while preserving the aspect ratio.
        width, height = image.size
        if width == height:
            new_width = new_height = side
        elif width > height:
            new_width = side
            new_height = int(height / width * new_width)
        else:
            new_height = side
            new_width = int(width / height * new_height)
        new_size = dict(height=new_height, width=new_width)
        pixel_values = self.image_processor.preprocess(image, size=new_size, return_tensors='pt')['pixel_values']

        # Pad the resized image into a centered square canvas of zeros.
        square_values = torch.zeros([1, 3, side, side], dtype=pixel_values.dtype, device=pixel_values.device)
        new_height, new_width = pixel_values.shape[2:]
        if new_height == new_width:
            square_values[:, :, :, :] = pixel_values
        elif new_height > new_width:
            from_index = (side - new_width) // 2
            square_values[:, :, :, from_index:from_index + new_width] = pixel_values
        else:
            from_index = (side - new_height) // 2
            square_values[:, :, from_index:from_index + new_height, :] = pixel_values

        return square_values
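    # Worked example (illustrative, assuming a 336x336 target size as with the CLIP
    # ViT-L/14-336 processor): a 640x480 image is resized to width=336 and
    # height=int(480 / 640 * 336)=252, then placed into the square canvas with
    # from_index=(336 - 252) // 2 = 42 along the height dimension, yielding a
    # 1x3x336x336 tensor.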

    def get_layer_norm(self):
        # Note: `layer_norm` is expected to be defined by subclasses that use it.
        return self.layer_norm

    def tokenize(self, logits):
        def st_argmax(y_soft, dim):
            # Straight-through argmax: hard one-hot in the forward pass, while gradients
            # flow through `y_soft` in the backward pass.
            index = y_soft.max(dim, keepdim=True)[1]
            y_hard = torch.zeros_like(y_soft, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
            ret = y_hard - y_soft.detach() + y_soft
            return ret

        if self.config.tokenize_function == 'softmax':
            tokens = softmax(logits, dim=-1)
        elif self.config.tokenize_function == 'gumbel_argmax':
            tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True)
        elif self.config.tokenize_function == 'st_argmax':
            tokens = st_argmax(logits, dim=-1)
        else:
            raise ValueError(
                f'Invalid `tokenize_function`, expected softmax, gumbel_argmax or st_argmax, '
                f'but got {self.config.tokenize_function}')
        return tokens
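    # Illustration (not from the original source): for logits [[2.0, 0.5, 0.1]],
    # 'softmax' keeps a soft distribution (~[0.73, 0.16, 0.11]), while 'gumbel_argmax'
    # and 'st_argmax' emit a one-hot row such as [1., 0., 0.] in the forward pass and
    # keep gradients flowing through the soft values / logits in the backward pass.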


class ClipVisualTokenizerConfig(BaseVisualTokenizerConfig):
    model_type = MODEL_TYPE

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if self.depths:
            assert len(self.depths) == 1
            self.backbone_kwargs['num_hidden_layers'] = self.depths[0]


class ClipVisualTokenizer(BaseVisualTokenizer):
    config_class = ClipVisualTokenizerConfig
    supports_gradient_checkpointing = True
    _no_split_modules = ["CLIPEncoderLayer"]
    _image_processor_class = CLIPImageProcessor
    _image_processor_kwargs = dict(do_center_crop=False)
    _backbone_class = CLIPVisionModel
    _backbone_name_or_path = "openai/clip-vit-large-patch14-336"

    def __init__(self, config: ClipVisualTokenizerConfig = None, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        head_dim = self.config.vocab_size
        if self.config.use_indicators:
            head_dim -= 2  # the last two vocabulary ids are reserved for the begin/end indicators
        self.head = torch.nn.Sequential(
            torch.nn.Linear(self.backbone.config.hidden_size, head_dim, bias=False),
            torch.nn.LayerNorm(head_dim)
        )

    def re_init_layers(self, re_init_layer_begin):
        layer_dict = self.get_re_init_layer_dict(re_init_layer_begin)
        for name, layer in layer_dict.items():
            rank0_print(BEGIN_LINE)
            rank0_print(f'[{datetime.now()}] Before layer re-initialization of {name}: ')
            for k, v in layer.named_parameters():
                with deepspeed.zero.GatheredParameters([v]):
                    rank0_print(f'{k}: {v}')
            with deepspeed.zero.GatheredParameters(list(layer.parameters(recurse=True)), modifier_rank=0):
                if not is_deepspeed_zero3_enabled() or deepspeed.comm.get_rank() == 0:
                    layer.apply(self.backbone._init_weights)
            rank0_print(f'[{datetime.now()}] After layer re-initialization of {name}:')
            for k, v in layer.named_parameters():
                with deepspeed.zero.GatheredParameters([v]):
                    rank0_print(f'{k}: {v}')
            rank0_print(END_LINE)

    def get_re_init_layer_dict(self, re_init_layer_begin: int) -> Dict[str, torch.nn.Module]:
        assert re_init_layer_begin >= 0, "negative index is prohibited"
        layer_dict = dict()
        for i in range(re_init_layer_begin, self.backbone.config.num_hidden_layers):
            layer_dict[f'backbone.vision_model.encoder.layers.{i}'] = self.backbone.vision_model.encoder.layers[i]
        return layer_dict

    def get_monitor_tensors(self):
        return dict(
            backbone_bottom=self.backbone.vision_model.encoder.layers[0].self_attn.k_proj.weight,
            backbone_top=self.backbone.vision_model.encoder.layers[-1].self_attn.out_proj.weight,
            head=self.head[0].weight
        )

    def get_image_size(self):
        height = self.image_processor.crop_size["height"]
        width = self.image_processor.crop_size["width"]
        return height, width

    def forward(self, pixel_values) -> Tensor:
        output = self.backbone(
            pixel_values, output_hidden_states=True, return_dict=True)
        features = output.last_hidden_state
        if self.config.drop_cls_token:
            features = features[:, 1:, :]
        logits = self.head(features)
        tokens = self.tokenize(logits)
        if self.config.use_indicators:
            # The head outputs `vocab_size - 2` logits; pad two zero columns so each token
            # distribution covers the full vocabulary, whose last two ids are reserved for
            # the begin/end indicators.
            batch_size, token_len, _ = tokens.shape
            padding_tensor = torch.zeros(size=(batch_size, token_len, 2),
                                         dtype=tokens.dtype,
                                         device=tokens.device,
                                         layout=tokens.layout,
                                         requires_grad=False)
            tokens = torch.cat((tokens, padding_tensor), dim=2)

            # Prepend/append one-hot indicator tokens: id `vocab_size - 2` marks the
            # beginning of the image and id `vocab_size - 1` marks the end.
            begin_indicator = torch.zeros(size=(batch_size, 1),
                                          dtype=torch.long,
                                          device=tokens.device,
                                          requires_grad=False) + self.config.vocab_size - 2
            begin_indicator_token = torch.nn.functional.one_hot(begin_indicator,
                                                                num_classes=self.config.vocab_size).to(
                dtype=tokens.dtype)
            end_indicator = torch.zeros(size=(batch_size, 1),
                                        dtype=torch.long,
                                        device=tokens.device,
                                        requires_grad=False) + self.config.vocab_size - 1
            end_indicator_token = torch.nn.functional.one_hot(end_indicator,
                                                              num_classes=self.config.vocab_size).to(dtype=tokens.dtype)
            tokens = torch.cat((begin_indicator_token, tokens, end_indicator_token), dim=1)
        return tokens


AutoConfig.register(MODEL_TYPE, ClipVisualTokenizerConfig)
AutoModel.register(ClipVisualTokenizerConfig, ClipVisualTokenizer)
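# A minimal usage sketch (illustrative assumptions: network access to download
# "openai/clip-vit-large-patch14-336" and a local "example.jpg"); not part of the
# original module:
#
#   from PIL import Image
#
#   config = ClipVisualTokenizerConfig(vocab_size=16384, use_indicators=True)
#   tokenizer = ClipVisualTokenizer(config, train_from_scratch=True)
#   pixel_values = tokenizer.preprocess_image(Image.open("example.jpg"))
#   tokens = tokenizer(pixel_values)  # one probability row per visual token (+2 indicator rows)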