from contextlib import nullcontext
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from transformers import PreTrainedModel
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache
from transformers.modeling_outputs import ModelOutput
from transformers.models.clip.configuration_clip import CLIPConfig
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from transformers import AutoModel, AutoModelForCausalLM
from transformers.models.llava.configuration_llava import LlavaConfig
from transformers.models.llava.modeling_llava import (
    LlavaCausalLMOutputWithPast,
    LlavaMultiModalProjector,
    LlavaPreTrainedModel,
    LLAVA_START_DOCSTRING,
    LLAVA_INPUTS_DOCSTRING,
    LlavaForConditionalGeneration,
)
from transformers.models.blip_2.configuration_blip_2 import (
    Blip2Config,
    Blip2QFormerConfig,
)
import os
from transformers.models.blip_2.modeling_blip_2 import (
    Blip2Config,
    Blip2QFormerModel,
    Blip2PreTrainedModel,
    BLIP_2_INPUTS_DOCSTRING,
)
from transformers.utils.import_utils import is_flash_attn_greater_or_equal_2_10

# from .configuration_sealmm import SeaLMMConfig

logger = logging.get_logger(__name__)

# _CONFIG_FOR_DOC = "LlavaConfig"
_CONFIG_FOR_DOC = "SeaLMMConfig"


class SeaLMMConfig(LlavaConfig):
    def __init__(self, *args, **kwargs):
        self.projector_num_layers = kwargs.get("projector_num_layers", 1)
        super().__init__(*args, **kwargs)
""" | |
Llava | |
vision_config.num_hidden_layers = vision_config.num_hidden_layers + config.vision_feature_layer + 1 | |
# "num_hidden_layers": 24, | |
""" | |
IMAGE_TOKEN = "<|image|>" | |
DEBUG = bool(int(os.environ.get("DEBUG", "0"))) | |
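
# Illustrative usage sketch for the config above (doctest-style comments, not executed;
# everything relies only on LlavaConfig defaults):
# >>> cfg = SeaLMMConfig(projector_num_layers=2)   # extra kwarg on top of LlavaConfig
# >>> cfg.projector_num_layers
# 2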


# Note: module-level, per-sample variant of `_merge_input_ids_with_image_features`.
# It is not called anywhere in this file and has no return statement.
def by_sample_merge_input_ids_with_image_features(
    self, image_features, inputs_embeds, input_ids, attention_mask=None, position_ids=None
):
    """
    input_ids: [tlen]
    input_embeds: [tlen, dt]
    img_embeds: [ilen, ifeat, di]
    e.g.:
        input_ids: [
            a b c d e f X g h i j k X l m
        ]
        img_embeds: [3, ifeat, di]  # img_embeds has padding
    """
    num_images, num_image_patches, embed_dim = image_features.shape
    sequence_length = input_ids.size(0)
    # input_ids is a single (1-D) sequence here, so index the last position directly
    left_padding = not torch.sum(input_ids[-1] == torch.tensor(self.pad_token_id))
    assert not left_padding, 'should only use right padding'
    # 1. Create a mask to know where special image tokens are
    special_image_token_mask = input_ids == self.config.image_token_index
    num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
    # Compute the maximum embed dimension
    max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length


from transformers.models.clip.configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
from transformers.models.clip.modeling_clip import (
    contrastive_loss,
    clip_loss,
    CLIPVisionModelOutput,
    CLIPTextModelOutput,
    CLIPOutput,
    CLIPTextEmbeddings,
    CLIPVisionEmbeddings,
    CLIPAttention,
    CLIPMLP,
    CLIPEncoderLayer,
    CLIPPreTrainedModel,
    CLIPTextTransformer,
    CLIPTextModel,
    CLIPVisionTransformer,
    CLIPVisionModel,
    CLIPModel,
    CLIPEncoder,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
    CLIP_START_DOCSTRING,
    CLIP_TEXT_INPUTS_DOCSTRING,
    CLIP_VISION_INPUTS_DOCSTRING,
    CLIP_INPUTS_DOCSTRING,
)
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling


# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
    import torch.nn.functional as F

    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
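
# Worked example (comments only) of what `_get_unpad_data` produces; the mask values are
# made up for illustration:
#   attention_mask = [[1, 1, 0],
#                     [1, 1, 1]]
#   -> indices             = [0, 1, 3, 4, 5]  # positions of non-padding tokens in the flattened mask
#   -> cu_seqlens          = [0, 2, 5]        # cumulative sequence lengths, prepended with 0
#   -> max_seqlen_in_batch = 3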


class CLIPFlashAttention2(CLIPAttention):
    """
    CLIP flash attention module. This module inherits from `CLIPAttention` as the weights of the module stay
    untouched. The only required change would be on the forward pass, where it needs to correctly call the public API
    of flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, config, is_causal=False):
        super().__init__(config)
        self.is_causal = is_causal

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""
        if output_attentions:
            raise ValueError("CLIPFlashAttention2 does not support output_attentions")
        if self.is_causal and causal_attention_mask is None:
            raise ValueError("CLIPFlashAttention2 has is_causal=True but no causal_attention_mask was provided")

        bsz, tgt_len, embed_dim = hidden_states.size()

        # [batch_size, tgt_len, embed_dim]
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # [batch_size, tgt_len, embed_dim] -> [batch_size, tgt_len, num_heads, head_dim]
        query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous()
        key_states = key_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous()
        value_states = value_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous()

        # only apply attention dropout at training time
        dropout_rate = self.dropout if self.training else 0.0

        attn_output = self._flash_attention_forward(
            query_states=query_states,
            key_states=key_states,
            value_states=value_states,
            attention_mask=attention_mask,
            query_length=tgt_len,
            dropout=dropout_rate,
            softmax_scale=self.scale,
        )

        # [batch_size, tgt_len, num_heads, head_dim] -> [batch_size, tgt_len, embed_dim]
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
        attn_output = self.out_proj(attn_output)

        return attn_output, None

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
    def _flash_attention_forward(
        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
    ) -> torch.Tensor:
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
        first unpad the input, then compute the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to the Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to the Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to the Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`float`, *optional*):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
        """
        from flash_attn import flash_attn_func, flash_attn_varlen_func
        from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=self.is_causal,
            )

            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            attn_output = flash_attn_func(
                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
            )

        return attn_output

    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            # There is a memcpy here, that is very bad.
            cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=query_layer.device)
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The :query_length slice assumes right padding.
            attention_mask = attention_mask[:, :query_length]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
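
    # Shape summary (comments only), matching the flash-attn varlen API this method feeds:
    #   query_layer / key_layer / value_layer are returned unpadded as [total_tokens, num_heads, head_dim],
    #   indices_q maps those rows back into the padded [batch_size * seq_len] layout for `pad_input`,
    #   and (cu_seqlens_q, cu_seqlens_k) are int32 cumulative lengths of shape [batch_size + 1].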


class SeaLMMCLIPEncoderLayer(CLIPEncoderLayer):
    def __init__(self, config: CLIPConfig):
        # Skip CLIPEncoderLayer.__init__ so the attention module can be swapped for flash attention
        super(CLIPEncoderLayer, self).__init__()
        self.embed_dim = config.hidden_size
        # self.self_attn = LlavaCLIPFlashAttention(config)
        if is_flash_attn_greater_or_equal_2_10():
            self.self_attn = CLIPFlashAttention2(config)
        else:
            self.self_attn = CLIPAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = CLIPMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)


class SeaLMMCLIPEncoder(CLIPEncoder):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self-attention layers. Each layer is a
    [`CLIPEncoderLayer`].

    Args:
        config: CLIPConfig
    """

    def __init__(self, config: CLIPConfig):
        super(CLIPEncoder, self).__init__()
        self.config = config
        self.layers = nn.ModuleList([SeaLMMCLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated
                vectors than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
                [What are attention masks?](../glossary#attention-mask)
            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Causal mask for the text model. Mask values selected in `[0, 1]`:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # ! this encoder never returns per-layer outputs, regardless of the config
        output_hidden_states = False
        output_attentions = False
        # return_dict = False

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # if self.gradient_checkpointing and self.training:
            #     layer_outputs = self._gradient_checkpointing_func(
            #         encoder_layer.__call__,
            #         hidden_states,
            #         attention_mask,
            #         causal_attention_mask,
            #         output_attentions,
            #     )
            # else:
            # ! enforce no checkpointing here
            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
                causal_attention_mask,
                output_attentions=output_attentions,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )


class SeaLMMVisionTransformer(nn.Module):
    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = CLIPVisionEmbeddings(config)
        # NOTE: `pre_layrnorm` (sic) matches the attribute name in HF CLIP so pretrained weights load correctly
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        # self.encoder = CLIPEncoder(config)
        self.encoder = SeaLMMCLIPEncoder(config)
        # self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        """
        assert output_attentions is None
        assert output_hidden_states is None
        # assert return_dict is None
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.pre_layrnorm(hidden_states)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]

        if not return_dict:
            raise ValueError("return_dict=False is not supported")

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            # pooler_output=pooled_output,
            pooler_output=None,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class SeaLMMCLIPVisionModel(CLIPPreTrainedModel):
    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"
    _no_split_modules = ["SeaLMMCLIPEncoderLayer"]

    def __init__(self, config: CLIPVisionConfig):
        super().__init__(config)
        self.vision_model = SeaLMMVisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, CLIPVisionModel

        >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled CLS states
        ```"""
        # return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


class SeaLMMMultiModalProjector(SeaLMMCLIPEncoder):
    def __init__(self, config: SeaLMMConfig):
        super(CLIPEncoder, self).__init__()
        self.config = config
        self.projector_num_layers = getattr(config, "projector_num_layers", 2)
        self.vision_config = config.vision_config
        self.num_vision_feature_layer = int(0 - config.vision_feature_layer) - 1
        assert self.num_vision_feature_layer > 0
        self.layers = nn.ModuleList(
            [
                # LlavaCLIPFasterEncoderLayer(self.vision_config)
                SeaLMMCLIPEncoderLayer(self.vision_config)
                for _ in range(self.projector_num_layers)
            ]
        )
        projector_layernorm_eps = getattr(config, "projector_layernorm_eps", 1e-05)
        self.projector_layernorm = nn.LayerNorm(
            # len(config.vision_feature_layers) * config.vision_config.hidden_size, eps=projector_layernorm_eps
            config.vision_config.hidden_size, eps=projector_layernorm_eps
        )
        self.linear_1 = nn.Linear(
            # len(config.vision_feature_layers) * config.vision_config.hidden_size,
            config.vision_config.hidden_size,
            config.text_config.hidden_size,
            bias=True,
        )
        # self.act = ACT2FN[config.projector_hidden_act]
        # self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
        self.gradient_checkpointing = False

    def forward(self, hidden_states, attention_mask=None, causal_attention_mask=None):
        """
        `hidden_states` must not be stripped of its CLS token: the extra CLIP encoder layers below
        attend over the full sequence, and the CLS position is only dropped afterwards.
        """
        output_attentions = False
        for idx, encoder_layer in enumerate(self.layers):
            # if output_hidden_states:
            #     encoder_states = encoder_states + (hidden_states,)
            # if self.gradient_checkpointing and self.training:
            #     layer_outputs = self._gradient_checkpointing_func(
            #         encoder_layer.__call__,
            #         hidden_states,
            #         attention_mask,
            #         causal_attention_mask,
            #         output_attentions,
            #     )
            # else:
            # ! turn off checkpointing
            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
                causal_attention_mask,
                output_attentions=output_attentions,
            )
            hidden_states = layer_outputs[0]

        # drop the CLS token, then project vision features into the text embedding space
        hidden_states = hidden_states[:, 1:]
        hidden_states = self.projector_layernorm(hidden_states)
        hidden_states = self.linear_1(hidden_states)
        # hidden_states = self.act(hidden_states)
        # hidden_states = self.linear_2(hidden_states)
        return hidden_states
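
# Shape sketch for the projector above (comments only; the 577/576 figures assume a
# CLIP ViT-L/14 tower at 336x336 resolution, i.e. 24*24 patches plus one CLS token):
#   input hidden_states:                             [num_images, 577, vision_hidden_size]
#   after the extra encoder layers and dropping CLS: [num_images, 576, vision_hidden_size]
#   after projector_layernorm + linear_1:            [num_images, 576, text_hidden_size]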


class SeaLMMForCausalLM(LlavaPreTrainedModel):
    def __init__(self, config: SeaLMMConfig, vision_tower=None, language_model=None):
        super().__init__(config)
        # self.vision_tower = AutoModel.from_config(config.vision_config)
        # self.vision_tower = vision_tower or LlavaCLIPVisionModel(config=config.vision_config)
        self.vision_tower = vision_tower or SeaLMMCLIPVisionModel(config=config.vision_config)

        self.multi_modal_projector = SeaLMMMultiModalProjector(config)
        # self.vocab_size = config.text_config.vocab_size
        self.language_model = language_model or AutoModelForCausalLM.from_config(
            config.text_config, attn_implementation=config._attn_implementation
        )
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
        self.post_init()
        # The vision tower is frozen by default (its forward then runs under torch.no_grad()).
        self.freeze_vision_tower = True

    def unfreeze_vision_tower(self):
        logger.info(f'UNFREEZE {self.freeze_vision_tower=}')
        self.freeze_vision_tower = False

    def freeze_vision_tower_(self):
        # Note: the trailing underscore avoids a name clash with the boolean attribute
        # `freeze_vision_tower` set in __init__, which would otherwise shadow this method.
        logger.info(f'FREEZE {self.freeze_vision_tower=}')
        self.freeze_vision_tower = True

    @classmethod
    def create_model_config_from_components(
        cls,
        lm_config=None,
        vision_config=None,
        tokenizer=None,
        vision_feature_layer=None,
        projector_num_layers=1,
        **kwargs,
    ) -> SeaLMMConfig:
        # self.projector_num_layers = kwargs.get("projector_num_layers", 1)
        config = SeaLMMConfig(vision_config, lm_config, projector_num_layers=projector_num_layers, **kwargs)
        config.vision_feature_layer = config.vision_feature_layer if vision_feature_layer is None else vision_feature_layer
        if config.vision_feature_layer < 0:
            # e.g. vision_feature_layer=-2 keeps all but the last vision layer
            config.vision_config.num_hidden_layers = config.vision_config.num_hidden_layers + config.vision_feature_layer + 1
        else:
            config.vision_config.num_hidden_layers = config.vision_feature_layer + 1

        if IMAGE_TOKEN not in tokenizer.get_vocab():
            tokenizer.add_special_tokens({"cls_token": IMAGE_TOKEN})
        config.image_token_index = tokenizer.cls_token_id
        config.vocab_size = config.text_config.vocab_size
        config.architectures = ["SeaLMMForCausalLM"]
        return config
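
    # Illustrative construction sketch (comments only; the checkpoint names below are
    # placeholders/assumptions, not references to this repository's actual checkpoints):
    # >>> from transformers import AutoConfig, AutoTokenizer, CLIPVisionConfig
    # >>> lm_config = AutoConfig.from_pretrained("<text-decoder-checkpoint>")
    # >>> vision_config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14-336")
    # >>> tokenizer = AutoTokenizer.from_pretrained("<text-decoder-checkpoint>")
    # >>> config = SeaLMMForCausalLM.create_model_config_from_components(
    # ...     lm_config=lm_config, vision_config=vision_config, tokenizer=tokenizer,
    # ...     vision_feature_layer=-2, projector_num_layers=1,
    # ... )
    # >>> model = SeaLMMForCausalLM(config)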

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        return self.language_model.get_decoder()

    def tie_weights(self):
        return self.language_model.tie_weights()

    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        # update vocab size
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.config.vocab_size = model_embeds.num_embeddings
        self.vocab_size = model_embeds.num_embeddings
        return model_embeds

    # @torch.no_grad
    def _merge_input_ids_with_image_features(
        self, image_features, inputs_embeds, input_ids, attention_mask, position_ids, labels=None
    ):
        """
        input_ids:      [b, tlen]
        inputs_embeds:  [b, tlen, dt]
        image_features: [b, ilen, ifeat, di]
        labels:         None or [b, tlen] --> must be extended alongside input_ids

        input_ids may contain image_token_index; the number of image_token_index <= ilen, e.g.:
            input_ids: [
                a b c d e f X g h i j k X l m
                o p q r X s t u v _ _ _ _ _ _
            ]
        after expansion, input_ids should be: [
                a b c d e f X X X X X g h i j k X X X X X l m
                o p q r X X X X X s t u v _ _ _ _ _ _ _ _ _ _
            ]
        and labels should be: [
                a b c d e f _ _ _ _ _ g h i j k _ _ _ _ _ l m
                o p q r _ _ _ _ _ s t u v _ _ _ _ _ _ _ _ _ _
            ]
        # mask replace image onto it
        # Use torch.vmap for simplicity
        def sample_merge():
            input_ids: [tlen]
            input_embeds: [tlen, dt]
            img_embeds: [ilen, ifeat, di]
            e.g.:
                input_ids: [
                    a b c d e f X g h i j k X l m
                ]
                img_embeds: [3, ifeat, di]  # img_embeds has padding
        """
        # Index bookkeeping happens under no_grad; the differentiable scatter of text and image
        # embeddings below stays outside it so gradients can reach the projector and the embedding table.
        with torch.no_grad():
            num_images, num_image_patches, embed_dim = image_features.shape
            batch_size, sequence_length = input_ids.shape
            # left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
            left_padding = torch.any(attention_mask[:, 0] == 0)
            # assert not left_padding or batch_size == 1
            # 1. Create a mask to know where special image tokens are
            special_image_token_mask = input_ids == self.config.image_token_index
            num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
            # Reserve for padding of num_images
            total_num_special_image_tokens = torch.sum(special_image_token_mask)
            assert total_num_special_image_tokens == num_images, f'{total_num_special_image_tokens=} != {num_images=} | {image_features.shape} {input_ids}'
            # Compute the maximum embed dimension
            max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
            batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)

            # 2. Compute the positions where text should be written
            # Calculate new positions for text tokens in merged image-text sequence.
            # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
            # `torch.cumsum` computes how each image token shifts subsequent text token positions.
            # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
            new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
            nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
            if left_padding:
                new_token_positions += nb_image_pad[:, None]  # offset for left padding
            text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

        # 3. Create the full embedding, already padded to the maximum position
        final_embedding = torch.zeros(
            batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
        )
        final_attention_mask = torch.zeros(
            batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
        )
        final_labels = None
        if labels is not None:
            final_labels = torch.full_like(final_attention_mask, self.config.ignore_index).to(torch.long)
        # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
        # set the corresponding tensors into their correct target device.
        target_device = inputs_embeds.device
        batch_indices, non_image_indices, text_to_overwrite = (
            batch_indices.to(target_device),
            non_image_indices.to(target_device),
            text_to_overwrite.to(target_device),
        )
        attention_mask = attention_mask.to(target_device)

        # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
        # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
        final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
        if labels is not None:
            final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]

        # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
        image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
        # image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
        if left_padding:
            image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
        else:
            val = torch.arange(max_embed_dim).unsqueeze(0).to(target_device).expand(batch_size, max_embed_dim) < new_token_positions[:, -1:].to(target_device)
            image_to_overwrite &= val

        if image_to_overwrite.sum() != image_features.shape[:-1].numel():
            raise ValueError(
                f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
                f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
            )

        final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
        final_attention_mask |= image_to_overwrite
        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

        if not left_padding:
            # Make sure the layout is consistent: all real tokens first, then padding
            seq_lens = final_attention_mask.sum(-1)
            for i, (mask, seq_len) in enumerate(zip(final_attention_mask, seq_lens)):
                # seq_len = mask.sum(-1)
                assert torch.all(mask[:seq_len] == 1), f'final 1 mask[{i}]: {seq_len} {final_attention_mask.tolist()=}'
                assert torch.all(mask[seq_len:] == 0), f'final 0 mask[{i}]: {seq_len} {final_attention_mask.tolist()=}'

        # if DEBUG:
        #     print(f'final_attention_mask=\n{final_attention_mask.tolist()}')
        #     print(f'text_to_overwrite=\n{text_to_overwrite.int().tolist()}')
        #     print(f'image_to_overwrite=\n{image_to_overwrite.int().tolist()}')
        #     print(f'position_ids=\n{position_ids.tolist()}')
        #     print(f'labels=\n{labels.tolist()}')
        #     print(f'final_labels=\n{final_labels.tolist()}')

        return final_embedding, final_attention_mask, position_ids, final_labels
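
    # Worked size example for the merge above (comments only; 576 patches per image assumes a
    # CLIP ViT-L/14 tower at 336x336 resolution):
    #   sequence_length = 15 text positions containing 1 image token, num_image_patches = 576
    #   -> max_embed_dim = 1 * (576 - 1) + 15 = 590
    #   i.e. each image token is expanded in place into 576 patch embeddings while the
    #   surrounding text embeddings keep their relative order.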

    def extract_image_features(self, pixel_values, vision_feature_select_strategy=None):
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )
        # run the (optionally frozen) vision tower without gradients, but keep the projector differentiable
        with (torch.no_grad() if self.freeze_vision_tower else nullcontext()):
            image_outputs = self.vision_tower(pixel_values)
        hidden_states = image_outputs.last_hidden_state
        image_features = self.multi_modal_projector(hidden_states)
        return image_features

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[int] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, LlavaForConditionalGeneration

        >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
        >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

        >>> prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(text=prompt, images=image, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_length=30)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "\nUSER: What's the content of the image?\nASSISTANT: The image features a stop sign on a street corner"
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
        )
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )

        if inputs_embeds is None:
            # 1. Extract the input embeddings
            for_inputs_embeds_ids = input_ids.clone()
            # image tokens have no entry in the text embedding table; map them to id 0 before the lookup
            for_inputs_embeds_ids[(input_ids == self.config.image_token_index)] = 0
            # inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds = self.get_input_embeddings()(for_inputs_embeds_ids)

            # 2. Merge text and images
            if pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) > 0:
                num_images = pixel_values.size(0)
                batch_size, sequence_length = input_ids.shape
                special_image_token_mask = input_ids == self.config.image_token_index
                # Reserve for padding of num_images
                total_num_special_image_tokens = torch.sum(special_image_token_mask)
                assert num_images == total_num_special_image_tokens, (
                    f'{num_images} != {total_num_special_image_tokens} | {special_image_token_mask}'
                )
                # pixel_values = pixel_values[:total_num_special_image_tokens]
                # image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
                # with (torch.no_grad() if self.freeze_vision_tower else nullcontext()):
                #     image_outputs = self.vision_tower(pixel_values)
                #     # this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
                #     # selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
                #     selected_image_feature = image_outputs.last_hidden_state
                #     if vision_feature_select_strategy == "default":
                #         selected_image_feature = selected_image_feature[:, 1:]
                #     elif vision_feature_select_strategy == "full":
                #         selected_image_feature = selected_image_feature
                #     else:
                #         raise ValueError(
                #             f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
                #         )
                #     image_features = self.multi_modal_projector(selected_image_feature)
                # print(f"{pixel_values.size()=}")

                # ! extract_image_features handles all image feature extraction
                image_features = self.extract_image_features(pixel_values)
                # if DEBUG:
                #     image_features = image_features[:, :3]

                inputs_embeds, attention_mask, position_ids, labels = self._merge_input_ids_with_image_features(
                    image_features, inputs_embeds, input_ids, attention_mask, position_ids,
                    labels=labels
                )
                # if labels is None:
                #     # ! this is wrong!
                #     labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
                # print(inputs_embeds.size())
            elif pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) == 0:
                # there are no images
                pass
            else:
                # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
                # generation with cache
                # ! (phi) why do we need to do this?
                # if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
                #     # ! it can possibly be a bug because with mistral, the first layer_key is like this
                #     # ! MUST UNDERSTAND and fix error
                #     # Retrieve the first layer to inspect the logits and mask out the hidden states
                #     # that are set to 0
                #     first_layer_past_key_value = past_key_values[0][0][:, 0, :, 0]
                #     batch_index, non_attended_tokens = torch.where(first_layer_past_key_value == 0)
                #     # Get the target length
                #     target_seqlen = first_layer_past_key_value.shape[-1] + 1
                #     extended_attention_mask = torch.ones(
                #         (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
                #         dtype=attention_mask.dtype,
                #         device=attention_mask.device,
                #     )
                #     # print(f'{extended_attention_mask.shape} | {batch_index=} | {non_attended_tokens=}')
                #     # Zero-out the places where we don't need to attend
                #     extended_attention_mask[batch_index, non_attended_tokens] = 0
                #     attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
                #     position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1

                # ! fix: https://github.com/huggingface/transformers/blob/c90268de7560c3fef21a927e0bfcf2b611a8711e/src/transformers/models/llava/modeling_llava.py
                # https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
                if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
                    # Retrieve the first layer to inspect the logits and mask out the hidden states
                    # that are set to 0
                    first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

                    # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
                    batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

                    # Get the target length
                    target_seqlen = first_layer_past_key_value.shape[-1] + 1

                    extended_attention_mask = torch.ones(
                        (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
                        dtype=attention_mask.dtype,
                        device=attention_mask.device,
                    )

                    # Filter out only the tokens that can be un-attended, this can happen
                    # in the case one uses Llava + Fused modules where the cache on the
                    # first iteration is already big enough, or if one passes custom cache
                    valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
                    new_batch_index = batch_index[valid_indices]
                    new_non_attended_tokens = non_attended_tokens[valid_indices]

                    # Zero-out the places where we don't need to attend
                    extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

                    attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
                    position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
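
                    # Size example for the mask extension above (comments only; 576 patches per image
                    # assumes a CLIP ViT-L/14 tower at 336x336):
                    #   prefill prompt of 20 text tokens incl. one <|image|> token -> merged cache length
                    #   20 + (576 - 1) = 595; at the first decode step attention_mask covers 21 positions,
                    #   target_seqlen = 595 + 1 = 596, so 575 ones are appended (minus any cache positions
                    #   detected as padding, which are zeroed out).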

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = outputs[0]

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            if attention_mask is not None:
                shift_attention_mask = attention_mask[..., 1:]
                shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
                shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
            else:
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return LlavaCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs
    ):
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
            elif self.config.image_token_index in input_ids:
                input_ids = input_ids[:, input_ids.shape[1] - 1:]
            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
            # older attention values, as their corresponding values are not part of the input.
            if cache_length < past_length and attention_mask is not None:
                attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]):]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "pixel_values": pixel_values,
            }
        )
        return model_inputs

    def _reorder_cache(self, *args, **kwargs):
        return self.language_model._reorder_cache(*args, **kwargs)
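

# End-to-end usage sketch (comments only; "<sealmm-checkpoint>" and "<image.jpg>" are placeholders,
# and preprocessing `pixel_values` with a CLIP image processor is an assumption, not part of this file):
# >>> import torch
# >>> from PIL import Image
# >>> from transformers import AutoTokenizer, CLIPImageProcessor
# >>> tokenizer = AutoTokenizer.from_pretrained("<sealmm-checkpoint>")
# >>> model = SeaLMMForCausalLM.from_pretrained("<sealmm-checkpoint>", torch_dtype=torch.bfloat16)
# >>> image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
# >>> pixel_values = image_processor(images=Image.open("<image.jpg>"), return_tensors="pt").pixel_values
# >>> prompt = f"{IMAGE_TOKEN}\nUSER: What's in the image?\nASSISTANT:"
# >>> inputs = tokenizer(prompt, return_tensors="pt")
# >>> out = model.generate(**inputs, pixel_values=pixel_values, max_new_tokens=30)
# >>> tokenizer.decode(out[0], skip_special_tokens=True)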