import torch
import torch.nn as nn
from typing import Optional, Tuple
import logging
from torch import Tensor
from diffusers import ModelMixin
from transformers.models.t5.modeling_t5 import T5LayerSelfAttention, T5LayerFF, T5LayerNorm
logger = logging.getLogger(__name__)
class T5EncoderBlock(nn.Module):
def __init__(self, config, has_relative_attention_bias=False):
super().__init__()
self.layer = nn.ModuleList()
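        # Same sub-layer layout as transformers' T5Block, minus the cross-attention
        # sub-layer: index 0 is self-attention, the last entry is the feed-forward layer.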
self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
self.layer.append(T5LayerFF(config))
def forward(
self,
hidden_states,
attention_mask=None,
position_bias=None,
layer_head_mask=None,
output_attentions=False,
):
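        """Run the self-attention and feed-forward sub-layers on `hidden_states`.

        Returns a tuple of the updated hidden states followed by the self-attention
        position bias (and the attention weights when `output_attentions=True`).
        """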
        # Encoder-only block: no cached key/values are used.
        self_attn_past_key_value = None
self_attention_outputs = self.layer[0](
hidden_states,
attention_mask=attention_mask,
position_bias=position_bias,
layer_head_mask=layer_head_mask,
past_key_value=self_attn_past_key_value,
use_cache=False,
output_attentions=output_attentions,
)
hidden_states, present_key_value_state = self_attention_outputs[:2]
attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
# clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16:
clamp_value = torch.where(
torch.isinf(hidden_states).any(),
torch.finfo(hidden_states.dtype).max - 1000,
torch.finfo(hidden_states.dtype).max,
)
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
# Apply Feed Forward layer
hidden_states = self.layer[-1](hidden_states)
# clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16:
clamp_value = torch.where(
torch.isinf(hidden_states).any(),
torch.finfo(hidden_states.dtype).max - 1000,
torch.finfo(hidden_states.dtype).max,
)
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
outputs = (hidden_states,) + attention_outputs
        return outputs  # hidden-states, (self-attention position bias), (self-attention weights)
class T5EncoderBlockByT5Mapper(ModelMixin):
def __init__(self, byt5_config, num_layers, sdxl_channels=None):
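        """Lightweight mapper that refines ByT5 encoder features.

        Args:
            byt5_config: `transformers.T5Config` of the ByT5 encoder whose hidden
                states are fed into this mapper.
            num_layers: number of extra `T5EncoderBlock`s to apply; with 0, only the
                layer norm (and optional channel mapping) is applied.
            sdxl_channels: if given, hidden states are projected to this channel
                dimension and layer-normalized again.
        """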
super().__init__()
if num_layers > 0:
self.blocks = nn.ModuleList(
[
T5EncoderBlock(
byt5_config,
has_relative_attention_bias=bool(i == 0))
for i in range(num_layers)
]
)
else:
self.blocks = None
self.layer_norm = T5LayerNorm(byt5_config.d_model, eps=byt5_config.layer_norm_epsilon)
if sdxl_channels is not None:
self.channel_mapper = nn.Linear(byt5_config.d_model, sdxl_channels)
self.final_layer_norm = T5LayerNorm(sdxl_channels, eps=byt5_config.layer_norm_epsilon)
else:
self.channel_mapper = None
self.final_layer_norm = None
def get_extended_attention_mask(
        self, attention_mask: Tensor, input_shape: Tuple[int, ...], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
) -> Tensor:
"""
        Makes a broadcastable attention mask so that padded tokens are ignored.
Arguments:
attention_mask (`torch.Tensor`):
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
input_shape (`Tuple[int]`):
The shape of the input to the model.
Returns:
            `torch.Tensor` The extended additive attention mask, cast to `dtype` (defaults to `self.dtype`).
"""
if dtype is None:
dtype = self.dtype
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
if attention_mask.dim() == 3:
extended_attention_mask = attention_mask[:, None, :, :]
elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length];
            # make it broadcastable to [batch_size, num_heads, seq_length, seq_length].
extended_attention_mask = attention_mask[:, None, None, :]
else:
raise ValueError(
f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
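        # Example: a padding mask [1, 1, 0] in fp32 becomes [0.0, 0.0, -3.4028e+38],
        # so masked positions contribute (effectively) -inf to the attention logits.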
return extended_attention_mask
def forward(self, inputs_embeds, attention_mask):
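        """Refine ByT5 embeddings and optionally project them to `sdxl_channels`.

        Arguments:
            inputs_embeds (`torch.Tensor`):
                ByT5 encoder hidden states of shape `[batch_size, seq_length, d_model]`.
            attention_mask (`torch.Tensor`):
                Padding mask of shape `[batch_size, seq_length]`, 1 for real tokens and 0 for padding.
        """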
input_shape = inputs_embeds.size()[:-1]
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
hidden_states = inputs_embeds
position_bias = None
if self.blocks is not None:
for layer_module in self.blocks:
layer_outputs = layer_module(
hidden_states,
attention_mask=extended_attention_mask,
position_bias=position_bias,
)
hidden_states, position_bias = layer_outputs
hidden_states = self.layer_norm(hidden_states)
if self.channel_mapper is not None:
hidden_states = self.channel_mapper(hidden_states)
hidden_states = self.final_layer_norm(hidden_states)
return hidden_states
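if __name__ == "__main__":
    # Minimal smoke-test sketch. The T5Config values below are illustrative
    # placeholders, not the hyperparameters of any real ByT5 checkpoint; in
    # practice the config would come from the pretrained ByT5 encoder.
    from transformers import T5Config

    byt5_config = T5Config(
        vocab_size=384,
        d_model=64,
        d_kv=16,
        d_ff=256,
        num_heads=4,
        dropout_rate=0.0,
    )
    mapper = T5EncoderBlockByT5Mapper(byt5_config, num_layers=2, sdxl_channels=128)
    mapper.eval()

    inputs_embeds = torch.randn(2, 8, byt5_config.d_model)
    attention_mask = torch.ones(2, 8, dtype=torch.long)
    with torch.no_grad():
        out = mapper(inputs_embeds, attention_mask)
    print(out.shape)  # torch.Size([2, 8, 128])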