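"""ByT5 encoder blocks and a lightweight mapper module.

`T5EncoderBlock` is a self-attention + feed-forward block built from Hugging Face T5 layers.
`T5EncoderBlockByT5Mapper` stacks such blocks on top of ByT5 hidden states and can optionally
project them to a target channel width (e.g. an SDXL conditioning dimension).
"""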
import logging
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor

from diffusers import ModelMixin
from transformers.models.t5.modeling_t5 import T5LayerFF, T5LayerNorm, T5LayerSelfAttention

logger = logging.getLogger(__name__)

class T5EncoderBlock(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.layer = nn.ModuleList()
        self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
        self.layer.append(T5LayerFF(config))

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        output_attentions=False,
    ):
        self_attention_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=None,  # no KV cache is needed in this encoder-only block
            use_cache=False,
            output_attentions=output_attentions,
        )
        hidden_states, present_key_value_state = self_attention_outputs[:2]
        # Keep the relative position bias (and the attention weights, if requested)
        attention_outputs = self_attention_outputs[2:]
        # clamp inf values to enable fp16 training
        if hidden_states.dtype == torch.float16:
            clamp_value = torch.where(
                torch.isinf(hidden_states).any(),
                torch.finfo(hidden_states.dtype).max - 1000,
                torch.finfo(hidden_states.dtype).max,
            )
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
        # Apply the feed-forward layer
        hidden_states = self.layer[-1](hidden_states)

        # clamp inf values to enable fp16 training
        if hidden_states.dtype == torch.float16:
            clamp_value = torch.where(
                torch.isinf(hidden_states).any(),
                torch.finfo(hidden_states.dtype).max - 1000,
                torch.finfo(hidden_states.dtype).max,
            )
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,) + attention_outputs
        # outputs = (hidden_states, position_bias), plus the self-attention weights
        # when output_attentions=True.
        return outputs
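
# Illustrative note (an assumption, not part of the original code): with the default
# output_attentions=False, a block returns a 2-tuple, so a single block can be used as
#
#     block = T5EncoderBlock(byt5_config, has_relative_attention_bias=True)
#     hidden_states, position_bias = block(torch.randn(1, 8, byt5_config.d_model))
#
# and the returned `position_bias` can be fed to subsequent blocks so the relative
# attention bias is computed only once.
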
class T5EncoderBlockByT5Mapper(ModelMixin):
    def __init__(self, byt5_config, num_layers, sdxl_channels=None):
        super().__init__()
        if num_layers > 0:
            self.blocks = nn.ModuleList(
                [
                    T5EncoderBlock(byt5_config, has_relative_attention_bias=(i == 0))
                    for i in range(num_layers)
                ]
            )
        else:
            self.blocks = None

        self.layer_norm = T5LayerNorm(byt5_config.d_model, eps=byt5_config.layer_norm_epsilon)

        if sdxl_channels is not None:
            self.channel_mapper = nn.Linear(byt5_config.d_model, sdxl_channels)
            self.final_layer_norm = T5LayerNorm(sdxl_channels, eps=byt5_config.layer_norm_epsilon)
        else:
            self.channel_mapper = None
            self.final_layer_norm = None
    def get_extended_attention_mask(
        self,
        attention_mask: Tensor,
        input_shape: Tuple[int, ...],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Tensor:
        """
        Makes a broadcastable attention mask so that padded (masked) tokens are ignored.

        Arguments:
            attention_mask (`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (`Tuple[int, ...]`):
                The shape of the input to the model.

        Returns:
            `torch.Tensor`: The extended attention mask, cast to `dtype` (the model's dtype by default).
        """
        if dtype is None:
            dtype = self.dtype

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves, in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]: this encoder-only
            # module just makes it broadcastable to [batch_size, num_heads, seq_length, seq_length].
            extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                f"Wrong shape for input of shape {input_shape} or attention_mask of shape {attention_mask.shape}"
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation creates a tensor which is 0.0 for
        # positions we want to attend and the dtype's smallest value for masked positions.
        # Since it is added to the raw scores before the softmax, this is
        # effectively the same as removing them entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
        return extended_attention_mask
    def forward(self, inputs_embeds, attention_mask):
        input_shape = inputs_embeds.size()[:-1]
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

        hidden_states = inputs_embeds
        position_bias = None

        if self.blocks is not None:
            for layer_module in self.blocks:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask=extended_attention_mask,
                    position_bias=position_bias,
                )
                # With output_attentions=False each block returns (hidden_states, position_bias);
                # the bias is computed by the first block and reused by the remaining ones.
                hidden_states, position_bias = layer_outputs

        hidden_states = self.layer_norm(hidden_states)

        if self.channel_mapper is not None:
            hidden_states = self.channel_mapper(hidden_states)
            hidden_states = self.final_layer_norm(hidden_states)

        return hidden_states
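

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module). The T5Config values
    # below are small illustrative assumptions, not the real ByT5 hyper-parameters; in
    # practice the config would come from a pretrained ByT5 checkpoint.
    from transformers import T5Config

    byt5_config = T5Config(d_model=64, d_kv=16, d_ff=128, num_heads=4, vocab_size=384)
    mapper = T5EncoderBlockByT5Mapper(byt5_config, num_layers=2, sdxl_channels=32)

    inputs_embeds = torch.randn(2, 16, byt5_config.d_model)  # (batch, seq_len, d_model)
    attention_mask = torch.ones(2, 16, dtype=torch.long)     # 1 = attend, 0 = padding
    with torch.no_grad():
        out = mapper(inputs_embeds, attention_mask)
    print(out.shape)  # expected: torch.Size([2, 16, 32])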