# glyph_sdxl/modules/byt5_block_byt5_mapper.py
import logging
from typing import Tuple

import torch
import torch.nn as nn
from torch import Tensor
from diffusers import ModelMixin
from transformers.models.t5.modeling_t5 import T5LayerSelfAttention, T5LayerFF, T5LayerNorm

logger = logging.getLogger(__name__)


class T5EncoderBlock(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.layer = nn.ModuleList()
        self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
        self.layer.append(T5LayerFF(config))

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        output_attentions=False,
    ):
        self_attn_past_key_value, cross_attn_past_key_value = None, None

        self_attention_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=self_attn_past_key_value,
            use_cache=False,
            output_attentions=output_attentions,
        )
        hidden_states, present_key_value_state = self_attention_outputs[:2]
        attention_outputs = self_attention_outputs[2:]  # keep the relative position bias (and attention weights, if requested)

        # clamp inf values to enable fp16 training
        if hidden_states.dtype == torch.float16:
            clamp_value = torch.where(
                torch.isinf(hidden_states).any(),
                torch.finfo(hidden_states.dtype).max - 1000,
                torch.finfo(hidden_states.dtype).max,
            )
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        # Apply the feed-forward layer
        hidden_states = self.layer[-1](hidden_states)

        # clamp inf values to enable fp16 training
        if hidden_states.dtype == torch.float16:
            clamp_value = torch.where(
                torch.isinf(hidden_states).any(),
                torch.finfo(hidden_states.dtype).max - 1000,
                torch.finfo(hidden_states.dtype).max,
            )
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,) + attention_outputs
        return outputs  # hidden-states, (self-attention position bias), (self-attention weights)
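

# Note: unlike the full T5Block in transformers, T5EncoderBlock above keeps only
# the self-attention and feed-forward sub-layers (no cross-attention) and always
# runs with use_cache=False. With the stock transformers T5 layers this means
# forward() returns (hidden_states, position_bias), plus the attention weights
# when output_attentions=True.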


class T5EncoderBlockByT5Mapper(ModelMixin):
    def __init__(self, byt5_config, num_layers, sdxl_channels=None):
        super().__init__()
        if num_layers > 0:
            self.blocks = nn.ModuleList(
                [
                    T5EncoderBlock(
                        byt5_config,
                        has_relative_attention_bias=bool(i == 0))
                    for i in range(num_layers)
                ]
            )
        else:
            self.blocks = None
        self.layer_norm = T5LayerNorm(byt5_config.d_model, eps=byt5_config.layer_norm_epsilon)
        if sdxl_channels is not None:
            self.channel_mapper = nn.Linear(byt5_config.d_model, sdxl_channels)
            self.final_layer_norm = T5LayerNorm(sdxl_channels, eps=byt5_config.layer_norm_epsilon)
        else:
            self.channel_mapper = None
            self.final_layer_norm = None
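
    # The mapper stacks `num_layers` fresh ByT5 encoder blocks on top of the
    # ByT5 encoder output, re-normalizes with T5LayerNorm and, when
    # `sdxl_channels` is given, projects d_model -> sdxl_channels (presumably to
    # match the channel width expected by SDXL's cross-attention layers).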

    def get_extended_attention_mask(
        self, attention_mask: Tensor, input_shape: Tuple[int], device: torch.device = None, dtype: torch.float = None
    ) -> Tensor:
        """
        Makes a broadcastable attention mask so that padded (masked) tokens are ignored.

        Arguments:
            attention_mask (`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (`Tuple[int]`):
                The shape of the input to the model.

        Returns:
            `torch.Tensor`: The extended attention mask, cast to `dtype` (defaults to the model dtype).
        """
        if dtype is None:
            dtype = self.dtype

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and the dtype's smallest value for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
        return extended_attention_mask
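
    # Worked example of the additive-mask convention above, in float32:
    #   attention_mask = [[1, 1, 0]]                  # shape (batch, seq_len)
    #   extended mask ~= [[[[0.0, 0.0, -3.4e38]]]]    # shape (batch, 1, 1, seq_len)
    # Adding the extended mask to the raw attention scores before the softmax
    # drives the attention weight of masked (padded) positions to ~0.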

    def forward(self, inputs_embeds, attention_mask):
        input_shape = inputs_embeds.size()[:-1]
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
        hidden_states = inputs_embeds
        position_bias = None
        if self.blocks is not None:
            for layer_module in self.blocks:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask=extended_attention_mask,
                    position_bias=position_bias,
                )
                hidden_states, position_bias = layer_outputs
        hidden_states = self.layer_norm(hidden_states)
        if self.channel_mapper is not None:
            hidden_states = self.channel_mapper(hidden_states)
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states
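

if __name__ == "__main__":
    # Minimal smoke-test sketch. The checkpoint name "google/byt5-small" and the
    # hyper-parameters below (num_layers=2, sdxl_channels=2048, sequence length)
    # are illustrative assumptions, not the values used by the Glyph-SDXL
    # training configs; swap in the project's own config where needed.
    from transformers import AutoConfig

    byt5_config = AutoConfig.from_pretrained("google/byt5-small")
    mapper = T5EncoderBlockByT5Mapper(byt5_config, num_layers=2, sdxl_channels=2048)

    batch_size, seq_len = 2, 16
    inputs_embeds = torch.randn(batch_size, seq_len, byt5_config.d_model)
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

    mapped = mapper(inputs_embeds, attention_mask)
    print(mapped.shape)  # expected: torch.Size([2, 16, 2048])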