|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import copy |
|
import math |
|
import os |
|
import warnings |
|
from dataclasses import dataclass |
|
from typing import Optional, Tuple, Union, Callable |
|
|
|
import torch |
|
from torch import nn |
|
from torch.nn import CrossEntropyLoss |
|
from torch.utils.checkpoint import checkpoint |
|
|
|
# Use torch's built-in Identity when this torch build provides it; otherwise fall back to an
# equivalent local definition.
Identity = getattr(nn, "Identity", None)
if Identity is None:

    class Identity(nn.Module):
        r"""No-op module: ignores any constructor arguments and returns its input unchanged."""

        def __init__(self, *args, **kwargs):
            super().__init__()

        def forward(self, input):
            # The very same object is handed back; nothing is copied.
            return input
|
|
|
from transformers.models.t5.modeling_t5 import ( |
|
T5LayerSelfAttention, |
|
T5LayerCrossAttention, |
|
T5LayerFF, |
|
T5PreTrainedModel, |
|
T5LayerNorm, |
|
PARALLELIZE_DOCSTRING, |
|
DEPARALLELIZE_DOCSTRING, |
|
__HEAD_MASK_WARNING_MSG, |
|
T5_START_DOCSTRING, |
|
T5_INPUTS_DOCSTRING |
|
) |
|
from transformers.modeling_outputs import ( |
|
BaseModelOutputWithPastAndCrossAttentions, |
|
Seq2SeqLMOutput, |
|
BaseModelOutput |
|
) |
|
from transformers.utils import ( |
|
DUMMY_INPUTS, |
|
DUMMY_MASK, |
|
add_start_docstrings, |
|
add_start_docstrings_to_model_forward, |
|
is_torch_fx_proxy, |
|
logging, |
|
replace_return_docstrings, |
|
ModelOutput, |
|
) |
|
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map |
|
from transformers import T5Config |
|
from transformers.configuration_utils import PretrainedConfig |
|
from transformers.activations import get_activation |
|
|
|
|
|
# Module-level logger from transformers' logging utilities; used for the runtime warnings
# emitted by the decoder block and decoder stack in this file.
logger = logging.get_logger(__name__)

# Name of the config class for documentation purposes.
# NOTE(review): not referenced in the visible part of this file; presumably used by docstring
# decorators further down — confirm.
_CONFIG_FOR_DOC_DDT5 = "T5Config"
|
|
|
def get_last_token_index(mask):
    """Return, for each row of a 2-D mask, the index of its last non-zero entry.

    Args:
        mask: tensor of shape ``(batch_size, seq_length)``. Any dtype is accepted (bool,
            integer or floating point); an entry counts as "set" when it is non-zero.

    Returns:
        ``torch.LongTensor`` of shape ``(batch_size,)``. A row with no non-zero entry maps
        to index 0 (``argmax`` returns the first of several equal maxima).

    Raises:
        ValueError: if ``mask`` is not 2-dimensional.
    """
    if mask.dim() != 2:
        raise ValueError(f"mask must be 2-D (batch_size, seq_length), got shape {tuple(mask.shape)}")
    positions = torch.arange(mask.shape[1], device=mask.device)
    # Zero out unset positions; the largest surviving position index is then the last set one.
    # Comparing with 0 (instead of multiplying by the raw mask values) keeps the result correct
    # for bool/float masks and for masks whose values are not strictly 0/1.
    masked_positions = positions * (mask != 0)
    return masked_positions.argmax(dim=1)
|
|
|
|
|
class SequenceSummary(nn.Module):
    r"""
    Compute a single vector summary of a sequence hidden states.

    Args:
        config ([`PretrainedConfig`]):
            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
            config class of your model for the default values it uses):

            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:

                - `"last"` -- Take the last token hidden state (like XLNet)
                - `"first"` -- Take the first token hidden state (like Bert)
                - `"mean"` -- Take the mean of all tokens hidden states
                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
                - `"attn"` -- Attend over all token hidden states with `num_queries` learned query vectors
                  (`nn.MultiheadAttention` with `config.num_attention_heads` heads), then apply a LayerNorm.
                  Unlike the other types, the result keeps a query axis: `(batch_size, num_queries, ...)`.

            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
              (otherwise to `config.hidden_size`).
            - **summary_activation** (`Optional[str]`) -- Name of an activation (resolved with `get_activation`) to
              apply to the output; `None` adds no activation.
            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
            - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
            - **layer_norm_eps** / **layer_norm_epsilon** (`float`) -- Epsilon of the LayerNorm used by `"attn"`
              (first one found wins; defaults to `1e-6`).
        num_queries (`int`, *optional*, defaults to 1):
            Number of learned query vectors; only used when `summary_type == "attn"`.

    `forward` raises `ValueError` if `summary_type` is not one of the accepted values.
    """

    def __init__(self, config: PretrainedConfig, num_queries=1):
        super().__init__()

        self.summary_type = getattr(config, "summary_type", "last")
        if self.summary_type == "attn":
            # Learned queries, kaiming-uniform initialized with a=sqrt(5).
            self.queries = nn.Parameter(torch.empty(num_queries, config.hidden_size))
            nn.init.kaiming_uniform_(self.queries, a=math.sqrt(5))
            self.MultiheadAttention = nn.MultiheadAttention(
                config.hidden_size,
                config.num_attention_heads,
                batch_first=True,
            )
            # T5-style configs call this field `layer_norm_epsilon`; honor it so this LayerNorm matches
            # the epsilon used by the rest of the model instead of silently using the 1e-6 fallback.
            layer_norm_eps = getattr(config, "layer_norm_eps", getattr(config, "layer_norm_epsilon", 1e-6))
            self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=layer_norm_eps)

        self.summary = Identity()
        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = nn.Linear(config.hidden_size, num_classes)

        activation_string = getattr(config, "summary_activation", None)
        self.activation: Callable = get_activation(activation_string) if activation_string else Identity()

        self.first_dropout = Identity()
        if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
            self.first_dropout = nn.Dropout(config.summary_first_dropout)

        self.last_dropout = Identity()
        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
            self.last_dropout = nn.Dropout(config.summary_last_dropout)

    def forward(
        self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
    ) -> torch.FloatTensor:
        """
        Compute a single vector summary of a sequence hidden states.

        Args:
            hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
                The hidden states of the last layer.
            cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
                Position of the classification token; only used if `summary_type == "cls_index"`.
                When omitted, the last token of each sequence is used.

        Returns:
            `torch.FloatTensor`: The summary of the sequence hidden states, `(batch_size, features)` for every
            summary type except `"attn"`, which returns `(batch_size, num_queries, features)`.

        Raises:
            ValueError: if `self.summary_type` is not a supported summary type.
        """
        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = hidden_states.mean(dim=1)
        elif self.summary_type == "cls_index":
            if cls_index is None:
                # Default: gather the last position of every sequence.
                cls_index = torch.full_like(
                    hidden_states[..., :1, :],
                    hidden_states.shape[-2] - 1,
                    dtype=torch.long,
                )
            else:
                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
            # cls_index now has a size-1 sequence axis and full hidden size, so gather picks one position.
            output = hidden_states.gather(-2, cls_index).squeeze(-2)
        elif self.summary_type == "attn":
            batch_size = hidden_states.size(0)
            # Share the learned queries across the batch without materializing a copy per example.
            queries = self.queries.unsqueeze(0).expand(batch_size, -1, -1)
            # NOTE(review): no key_padding_mask is passed, so padded positions are attended to as well.
            output = self.MultiheadAttention(queries, hidden_states, hidden_states, need_weights=False)[0]
            output = self.LayerNorm(output)
        else:
            # Previously this fell through and failed with an UnboundLocalError on `output`.
            raise ValueError(
                f"Unsupported summary_type {self.summary_type!r}; expected one of "
                "'last', 'first', 'mean', 'cls_index', 'attn'."
            )

        output = self.first_dropout(output)
        output = self.summary(output)
        output = self.activation(output)
        output = self.last_dropout(output)

        return output
|
|
|
|
|
class T5DecoderBlock(nn.Module):
    """One T5 layer: self-attention, optional cross-attention, then the feed-forward sub-layer.

    The cross-attention sub-layer is built only when `config.add_cross_attention` is set, so
    `self.layer` is either `[self-attn, cross-attn, FF]` or `[self-attn, FF]`; the feed-forward
    sub-layer is always `self.layer[-1]`.

    Args:
        config: T5-style config. `is_decoder` and `add_cross_attention` are read here and the
            whole config is forwarded to the sub-layers.
        has_relative_attention_bias (`bool`, defaults to `False`): forwarded to the
            self-attention sub-layer.
    """

    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.has_cross_attention = config.add_cross_attention
        self.layer = nn.ModuleList()
        self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
        if self.has_cross_attention:
            self.layer.append(T5LayerCrossAttention(config))

        self.layer.append(T5LayerFF(config))

    @staticmethod
    def _clamp_inf_fp16(hidden_states):
        """Clamp float16 activations to a finite range when they contain +/-inf; other inputs pass through."""
        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
        return hidden_states

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        return_dict=True,
    ):
        """Apply self-attention, cross-attention (when enabled and encoder states are given) and FF.

        Args:
            hidden_states: input activations.
            attention_mask / position_bias / layer_head_mask: passed to the self-attention sub-layer.
            encoder_hidden_states / encoder_attention_mask / encoder_decoder_position_bias /
                cross_attn_layer_head_mask: passed to the cross-attention sub-layer.
            past_key_value: cache from a previous step; 2 entries (self-attention), or 4 when this
                block runs cross-attention (self-attention followed by cross-attention entries).
            use_cache: when true, the concatenated cache is returned at index 1.
            output_attentions: forwarded to the sub-layers.
            return_dict: accepted for interface compatibility; unused.

        Returns:
            Tuple `(hidden_states, [present_key_value_state if use_cache], *self_attention_extras,
            *cross_attention_extras)` where the extras are whatever the sub-layers return after their
            first two outputs (as indexed by `T5DecoderStack`: position bias, then optional weights).

        Raises:
            ValueError: if `past_key_value` has the wrong number of entries.
        """
        # Cross-attention only runs when this block owns the sub-layer AND encoder states are given;
        # the cache layout (2 vs 4 entries) must follow that same condition.
        do_cross_attention = self.has_cross_attention and encoder_hidden_states is not None

        if past_key_value is not None:
            if not self.is_decoder:
                logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.")
            expected_num_past_key_values = 4 if do_cross_attention else 2

            if len(past_key_value) != expected_num_past_key_values:
                raise ValueError(
                    f"There should be {expected_num_past_key_values} past states. "
                    f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
                    f"Got {len(past_key_value)} past key / value states"
                )

            self_attn_past_key_value = past_key_value[:2]
            cross_attn_past_key_value = past_key_value[2:]
        else:
            self_attn_past_key_value, cross_attn_past_key_value = None, None

        self_attention_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=self_attn_past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states, present_key_value_state = self_attention_outputs[:2]
        # Everything after (hidden_states, cache) is kept and returned to the caller.
        attention_outputs = self_attention_outputs[2:]
        hidden_states = self._clamp_inf_fp16(hidden_states)

        if do_cross_attention:
            # Query length handed to cross-attention is read from the self-attention cache
            # (dim 2 of its first tensor) when a cache is being produced.
            if present_key_value_state is not None:
                query_length = present_key_value_state[0].shape[2]
            else:
                query_length = None

            cross_attention_outputs = self.layer[1](
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                query_length=query_length,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            hidden_states = self._clamp_inf_fp16(cross_attention_outputs[0])

            # Append the cross-attention cache entries after the self-attention ones.
            if present_key_value_state is not None:
                present_key_value_state = present_key_value_state + cross_attention_outputs[1]

            attention_outputs = attention_outputs + cross_attention_outputs[2:]

        # Feed-forward sub-layer is always last.
        hidden_states = self._clamp_inf_fp16(self.layer[-1](hidden_states))

        if use_cache:
            return (hidden_states, present_key_value_state) + attention_outputs
        return (hidden_states,) + attention_outputs
|
|
|
|
|
@dataclass
class BaseModelOutputWithPastAndCrossAttentionsAndPositionBias(ModelOutput):
    """
    Output of `T5DecoderStack`: base model outputs (with key/value cache and cross-attentions) plus the
    self-attention position bias.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model (after the final layer
            norm and dropout).

            If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
            hidden_size)` is output.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple with one entry per layer. Each entry holds 2 self-attention tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and, when the stack runs
            cross-attention (`config.add_cross_attention=True` and `encoder_hidden_states` given), 2 additional
            cross-attention tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally in the
            cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (the input of every layer, plus the final output) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the input of each layer plus the final output.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            encoder_sequence_length)`.

            Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute
            the weighted average in the cross-attention heads.
        position_bias (`torch.FloatTensor`, *optional*):
            Self-attention position bias returned by the last layer of the stack. It is produced by the first
            layer (or supplied by the caller via the `position_bias` argument) and threaded through all later
            layers, so it can be handed to another stack, e.g. the cross-attention decoder.
            Documented shape: `(batch_size, num_heads, sequence_length, sequence_length)`.
            NOTE(review): the exact shape comes from the attention sub-layer outside this file — confirm.
    """

    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    position_bias: Optional[torch.FloatTensor] = None
|
|
|
|
|
class T5DecoderStack(T5PreTrainedModel):
    """Stack of `T5DecoderBlock` layers followed by a final layer norm and dropout.

    Besides hidden states, cache and attentions, `forward` also returns the self-attention
    `position_bias` produced by the blocks, so it can be fed back into another stack through the
    `position_bias` argument (e.g. one built with `has_relative_attention_bias=False`).

    Args:
        config: T5-style config. Reads `is_decoder`, `add_cross_attention`, `num_layers`, `d_model`,
            `layer_norm_epsilon`, `dropout_rate` and, in `forward`, `use_cache`, `output_attentions`,
            `output_hidden_states` and `use_return_dict`.
        embed_tokens (`nn.Module`, *optional*): token embedding used when `input_ids` are given.
        has_relative_attention_bias (`bool`, defaults to `True`): whether the first block owns a
            relative attention bias; the remaining blocks never do.
    """

    def __init__(self, config, embed_tokens=None, has_relative_attention_bias=True):
        super().__init__(config)

        self.embed_tokens = embed_tokens
        self.is_decoder = config.is_decoder
        self.has_cross_attention = config.add_cross_attention

        # Only block 0 may own a relative attention bias; every later block receives the position
        # bias returned by the block before it (see the loop in `forward`).
        self.block = nn.ModuleList(
            [
                T5DecoderBlock(config, has_relative_attention_bias=bool(i == 0) and has_relative_attention_bias)
                for i in range(config.num_layers)
            ]
        )
        self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

        # Base-class post-construction hook.
        self.post_init()

        # Model-parallel / checkpointing state (off by default).
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        # Without an explicit map, let get_device_map assign blocks over all visible CUDA devices.
        self.device_map = (
            get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
        )
        assert_device_map(self.device_map, len(self.block))
        self.model_parallel = True
        self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
        self.last_device = "cuda:" + str(max(self.device_map.keys()))

        for k, v in self.device_map.items():
            for layer in v:
                cuda_device = "cuda:" + str(k)
                self.block[layer] = self.block[layer].to(cuda_device)

        # Embeddings live on the first device, the final layer norm on the last one.
        self.embed_tokens = self.embed_tokens.to(self.first_device) if self.embed_tokens is not None else self.embed_tokens
        self.final_layer_norm = self.final_layer_norm.to(self.last_device)

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        self.model_parallel = False
        self.device_map = None
        self.first_device = "cpu"
        self.last_device = "cpu"
        for i in range(len(self.block)):
            self.block[i] = self.block[i].to("cpu")
        self.embed_tokens = self.embed_tokens.to("cpu") if self.embed_tokens is not None else self.embed_tokens
        self.final_layer_norm = self.final_layer_norm.to("cpu")
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, new_embeddings):
        self.embed_tokens = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        position_bias=None,
        encoder_decoder_position_bias=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Embed (if needed), run every block, then apply the final layer norm and dropout.

        Args:
            input_ids / inputs_embeds: exactly one must be given; `input_ids` are embedded with
                `self.embed_tokens`.
            attention_mask: mask over cached + new positions; all ones when omitted.
            encoder_hidden_states / encoder_attention_mask: memory for cross-attention, used only
                when `config.add_cross_attention` is set; the mask defaults to all ones.
            position_bias / encoder_decoder_position_bias: optional biases given to the first block;
                afterwards each block's returned biases are passed on to the next block.
            head_mask / cross_attn_head_mask: per-layer head masks.
            past_key_values: per-layer cache from a previous call.
            use_cache, output_attentions, output_hidden_states, return_dict: default to the matching
                config values when None.

        Returns:
            `BaseModelOutputWithPastAndCrossAttentionsAndPositionBias` if `return_dict`, otherwise a
            tuple of the non-None entries of (last hidden state, cache, all hidden states,
            self-attentions, cross-attentions), always followed by `position_bias`.
        """
        if self.model_parallel:
            torch.cuda.set_device(self.first_device)
            self.embed_tokens = self.embed_tokens.to(self.first_device) if self.embed_tokens is not None else self.embed_tokens
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")

        if inputs_embeds is None:
            assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_shape

        # Keys span the cached positions plus the new ones.
        mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length

        if use_cache is True:
            assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"

        use_cross_attention = self.has_cross_attention and encoder_hidden_states is not None

        if attention_mask is None:
            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
        if use_cross_attention and encoder_attention_mask is None:
            encoder_seq_length = encoder_hidden_states.shape[1]
            encoder_attention_mask = torch.ones(
                batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
            )

        if past_key_values is None:
            past_key_values = [None] * len(self.block)

        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

        # encoder_attention_mask is guaranteed non-None here whenever cross-attention runs.
        encoder_extended_attention_mask = (
            self.invert_attention_mask(encoder_attention_mask) if use_cross_attention else None
        )

        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
        cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
        present_key_value_states = () if use_cache else None
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and self.has_cross_attention) else None

        hidden_states = self.dropout(inputs_embeds)

        for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
            layer_head_mask = head_mask[i]
            cross_attn_layer_head_mask = cross_attn_head_mask[i]

            if self.model_parallel:
                torch.cuda.set_device(hidden_states.device)
                # Keep every tensor the block consumes on the block's device. The block receives
                # `extended_attention_mask` (not the raw `attention_mask`), so that is what must follow.
                if extended_attention_mask is not None:
                    extended_attention_mask = extended_attention_mask.to(hidden_states.device)
                if position_bias is not None:
                    position_bias = position_bias.to(hidden_states.device)
                if encoder_hidden_states is not None:
                    encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
                if encoder_extended_attention_mask is not None:
                    encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
                if encoder_decoder_position_bias is not None:
                    encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
                if layer_head_mask is not None:
                    layer_head_mask = layer_head_mask.to(hidden_states.device)
                if cross_attn_layer_head_mask is not None:
                    cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    # checkpoint() only forwards positional tensors; bind the two flags here.
                    def custom_forward(*inputs):
                        return tuple(module(*inputs, use_cache, output_attentions))

                    return custom_forward

                layer_outputs = checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    extended_attention_mask,
                    position_bias,
                    encoder_hidden_states,
                    encoder_extended_attention_mask,
                    encoder_decoder_position_bias,
                    layer_head_mask,
                    cross_attn_layer_head_mask,
                    None,  # past_key_value: no cache under checkpointing
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask=extended_attention_mask,
                    position_bias=position_bias,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_extended_attention_mask,
                    encoder_decoder_position_bias=encoder_decoder_position_bias,
                    layer_head_mask=layer_head_mask,
                    cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                    past_key_value=past_key_value,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            # Without a cache the block omits the cache slot; insert a placeholder so the
            # indices below are the same in both cases.
            if not use_cache:
                layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]

            # layer_outputs: (hidden, cache, self position bias, [self attn weights],
            #                 [cross position bias], [cross attn weights])
            hidden_states, present_key_value_state = layer_outputs[:2]

            # Pass this block's position biases on to the next block.
            position_bias = layer_outputs[2]
            if use_cross_attention:
                encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]

            if use_cache:
                present_key_value_states = present_key_value_states + (present_key_value_state,)

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[3],)
                if self.has_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[5],)

            # Hand hidden states to the next device once this device's last block has run.
            if self.model_parallel:
                for k, v in self.device_map.items():
                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            outputs = tuple(
                v
                for v in [
                    hidden_states,
                    present_key_value_states,
                    all_hidden_states,
                    all_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
            # position_bias is always appended, even when None.
            return outputs + (position_bias,)
        return BaseModelOutputWithPastAndCrossAttentionsAndPositionBias(
            last_hidden_state=hidden_states,
            past_key_values=present_key_value_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
            position_bias=position_bias,
        )
|
|
|
|
|
@dataclass
class DualDecoderModelOutput(ModelOutput):
    """
    Base class for dual-decoder model outputs (a self-attention decoder plus a cross-attention decoder),
    including pre-computed hidden states (key/value caches) that can speed up sequential decoding.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the decoder of the model.
            NOTE(review): presumably the cross-attention decoder, since the self-attention decoder's output is
            reported separately in `self_decoder_last_hidden_state` — confirm.

            If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
            hidden_size)` is output.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
        cross_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the cross-attention decoder at the output of each layer.
        cross_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the cross-attention decoder's self-attention layers, after the attention softmax,
            used to compute the weighted average in those attention heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            encoder_sequence_length)`.

            Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
            weighted average in the cross-attention heads.
        self_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the self-attention decoder of the model.
        self_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the self-attention decoder at the output of each layer plus the optional initial embedding outputs.
        self_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the self-attention decoder, after the attention softmax, used to compute the weighted
            average in the self-attention heads.
    """

    # Field order matters: ModelOutput exposes these positionally (tuple conversion / indexing).
    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    cross_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    self_decoder_last_hidden_state: Optional[torch.FloatTensor] = None
    self_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    self_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
|
|
@dataclass
class DualDecoderLMOutput(ModelOutput):
    """
    Base class for dual-decoder language model outputs (a self-attention decoder plus a cross-attention decoder
    with a language modeling head).

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
            NOTE(review): the field below is annotated with three levels of tuple nesting (unlike
            `DualDecoderModelOutput.past_key_values`); possibly one cache per decoder — confirm which is correct.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
        cross_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the cross-attention decoder at the output of each layer.
        cross_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the cross-attention decoder's self-attention layers, after the attention softmax,
            used to compute the weighted average in those attention heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            encoder_sequence_length)`.

            Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
            weighted average in the cross-attention heads.
        self_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the self-attention decoder of the model.
        self_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the self-attention decoder at the output of each layer plus the optional initial embedding outputs.
        self_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the self-attention decoder, after the attention softmax, used to compute the weighted
            average in the self-attention heads.
    """

    # Field order matters: ModelOutput exposes these positionally (tuple conversion / indexing).
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
    cross_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    cross_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    self_decoder_last_hidden_state: Optional[torch.FloatTensor] = None
    self_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    self_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
|
|
@dataclass
class DualDecoderDoubleHeadsOutput(ModelOutput):
    """
    Output type of [`T5DualDecoderDoubleHeadsModel`]: language-modeling logits, the output of the sequence-summary
    head, and the intermediate outputs of both decoder stacks (the self-attention decoder and the cross-attention
    decoder).

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (cross-entropy of `logits` against `labels`; positions labelled `-100` are
            ignored).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        ss_logits (`torch.FloatTensor`):
            Output of the sequence-summary head applied to the last hidden state of the self-attention decoder, i.e.
            a global representation of each sequence. When `config.pad_token_id` is set and `input_ids` are given,
            the summarized position is derived from the number of non-padding tokens (intended to be the last real
            token). NOTE(review): the exact shape depends on the `SequenceSummary` configuration (presumably
            `(batch_size, hidden_size)` or `(batch_size, num_labels)`) - confirm.
        past_key_values (`tuple`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Pair `(self_decoder_past_key_values, cross_decoder_past_key_values)` holding the key/value caches of the
            self-attention decoder and of the cross-attention decoder. It is `None` unless both stacks returned a
            cache. Each element is a per-layer tuple of cached tensors that can be passed back (see the
            `past_key_values` input) to speed up sequential decoding.
        cross_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the cross-attention decoder at the output of each layer.
        cross_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Self-attention weights of the cross-attention decoder, after the attention softmax, used to compute the
            weighted average in its self-attention heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            encoder_sequence_length)`.

            Attention weights of the cross-attention decoder's cross-attention layers (attending over
            `encoder_hidden_states`), after the attention softmax, used to compute the weighted average in the
            cross-attention heads.
        self_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the self-attention decoder. This is also
            the input (`inputs_embeds`) of the cross-attention decoder.
        self_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the self-attention decoder at the output of each layer plus the optional initial
            embedding outputs.
        self_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the self-attention decoder, after the attention softmax, used to compute the
            weighted average in the self-attention heads.
    """

    # NOTE: field order defines the positional layout returned by `ModelOutput.to_tuple()`; do not reorder.
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    ss_logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
    cross_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    cross_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    self_decoder_last_hidden_state: Optional[torch.FloatTensor] = None
    self_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    self_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
|
|
|
|
@add_start_docstrings("""T5 Dual Decoder with a `language modeling` head on top.""", T5_START_DOCSTRING)
class T5DualDecoderLMHeadModel(T5PreTrainedModel):
    # Architecture:
    #   * `self.encoder` - a `T5DecoderStack` with self-attention only (no cross-attention) that embeds the input
    #     tokens. Despite its attribute name it is the first *decoder* stack.
    #   * `self.decoder` - a `T5DecoderStack` with cross-attention over `encoder_hidden_states`. It consumes the
    #     first stack's hidden states as `inputs_embeds` and reuses the first stack's position bias.
    #   * `self.lm_head` - projects the second stack's output onto the vocabulary.

    def __init__(self, config: T5Config, add_pooling_layer: bool = True):
        # `add_pooling_layer` is accepted for signature compatibility but is not used.
        # The passed config is mutated: the model as a whole is decoder-only.
        config.is_encoder_decoder = False
        config.is_decoder = True
        super().__init__(config)
        self.model_dim = config.d_model

        # Token embeddings shared by both stacks.
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        # First stack: causal self-attention only.
        self_decoder_config = copy.deepcopy(config)
        self_decoder_config.is_decoder = True
        self_decoder_config.is_encoder_decoder = False
        self_decoder_config.add_cross_attention = False
        self.encoder = T5DecoderStack(self_decoder_config, self.shared)

        # Second stack: self-attention + cross-attention, `num_decoder_layers` deep. It is built without a
        # relative attention bias and instead receives the first stack's position bias in `forward`.
        cross_decoder_config = copy.deepcopy(config)
        cross_decoder_config.is_decoder = True
        cross_decoder_config.is_encoder_decoder = False
        cross_decoder_config.add_cross_attention = True
        cross_decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5DecoderStack(cross_decoder_config, self.shared, has_relative_attention_bias=False)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        # One device map, sized on the first stack's blocks, is applied to both stacks.
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.decoder.parallelize(self.device_map)
        self.lm_head = self.lm_head.to(self.decoder.first_device)
        self.model_parallel = True

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        # Undo `parallelize`: move every submodule back to the CPU and release cached GPU memory.
        self.encoder.deparallelize()
        self.decoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.decoder = self.decoder.to("cpu")
        self.lm_head = self.lm_head.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        # Keep both stacks pointing at the same shared embedding module.
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def get_encoder(self):
        # Returns the first (self-attention only) decoder stack.
        return self.encoder

    def get_decoder(self):
        # Returns the second (cross-attention) decoder stack.
        return self.decoder

    @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=DualDecoderLMOutput, config_class=_CONFIG_FOR_DOC_DDT5)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], DualDecoderLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
            labels in `[0, ..., config.vocab_size - 1]`. When neither `input_ids` nor `inputs_embeds` is given, the
            inputs are obtained by shifting `labels` one position to the right. When `input_ids` are given they must
            have the same length as `labels`.

        Returns:

        Examples:

        ```python
        >>> from transformers import T5Tokenizer

        >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
        >>> model = T5DualDecoderLMHeadModel.from_pretrained("t5-small")

        >>> # training: the inputs are derived from `labels` by shifting them one position to the right
        >>> labels = tokenizer("studies have shown that owning a dog is good for you", return_tensors="pt").input_ids
        >>> outputs = model(labels=labels)
        >>> loss = outputs.loss
        >>> logits = outputs.logits

        >>> # inference: continue a prompt
        >>> input_ids = tokenizer("studies have shown that", return_tensors="pt").input_ids
        >>> generated = model.generate(input_ids)
        >>> text = tokenizer.decode(generated[0], skip_special_tokens=True)
        ```"""
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                # A bare `__HEAD_MASK_WARNING_MSG` inside a class body is name-mangled to
                # `_T5DualDecoderLMHeadModel__HEAD_MASK_WARNING_MSG` and raises NameError, so read the
                # module-level import through `globals()` instead.
                warnings.warn(globals()["__HEAD_MASK_WARNING_MSG"], FutureWarning)
                decoder_head_mask = head_mask

        # `past_key_values` is the pair (first-stack cache, second-stack cache) assembled at the end of this method.
        if past_key_values is not None:
            self_decoder_past_key_value, cross_decoder_past_key_value = past_key_values[0], past_key_values[1]
        else:
            self_decoder_past_key_value, cross_decoder_past_key_value = None, None

        if labels is not None and input_ids is None and inputs_embeds is None:
            # Teacher forcing: derive the inputs by shifting the labels one position to the right.
            input_ids = self._shift_right(labels)

        # Both stacks always run with `return_dict=True` so their outputs can be read by attribute
        # (`.past_key_values`, `.last_hidden_state`, ...) whatever the caller asked for. The caller's
        # `return_dict` only decides how the final output is packaged.
        self_decoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=self_decoder_past_key_value,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        hidden_states = self_decoder_outputs[0]
        # NOTE(review): relies on `T5DecoderStack` exposing its position bias as the last output element - confirm.
        position_bias = self_decoder_outputs[-1]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)

        cross_decoder_outputs = self.decoder(
            attention_mask=attention_mask,
            inputs_embeds=hidden_states,
            position_bias=position_bias,
            past_key_values=cross_decoder_past_key_value,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        sequence_output = cross_decoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab (the lm head shares weights with the embeddings).
            sequence_output = sequence_output * (self.model_dim**-0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            # Labels may live on another device than the logits when the model is parallelized.
            labels = labels.to(lm_logits.device)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

        # Only expose a cache when both stacks produced one.
        if self_decoder_outputs.past_key_values is None or cross_decoder_outputs.past_key_values is None:
            past_key_values = None
        else:
            past_key_values = (self_decoder_outputs.past_key_values, cross_decoder_outputs.past_key_values)

        output = DualDecoderLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=past_key_values,
            cross_decoder_hidden_states=cross_decoder_outputs.hidden_states,
            cross_decoder_attentions=cross_decoder_outputs.attentions,
            cross_attentions=cross_decoder_outputs.cross_attentions,
            self_decoder_last_hidden_state=self_decoder_outputs.last_hidden_state,
            self_decoder_hidden_states=self_decoder_outputs.hidden_states,
            self_decoder_attentions=self_decoder_outputs.attentions,
        )
        if not return_dict:
            # Non-None fields in declaration order (loss first when present).
            return output.to_tuple()
        return output

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        **kwargs
    ):
        # With a cache, only the most recent token has to be fed through the model.
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "past_key_values": past,
            "encoder_hidden_states": encoder_hidden_states,
            "encoder_attention_mask": encoder_attention_mask,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return self._shift_right(labels)

    def _reorder_cache(self, past, beam_idx):
        # `past` is the (first-stack cache, second-stack cache) pair; reorder each along the batch/beam dim.
        if past is None:
            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
            return past

        return (self._reorder_cache_single(past[0], beam_idx), self._reorder_cache_single(past[1], beam_idx))

    def _reorder_cache_single(self, past, beam_idx):
        # Select the cached states of the surviving beams (dim 0) for every tensor of every layer.
        if past is None:
            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
            return past

        reordered_decoder_past = ()
        for layer_past_states in past:
            reordered_layer_past_states = ()
            for layer_past_state in layer_past_states:
                reordered_layer_past_states = reordered_layer_past_states + (
                    layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
                )

            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
            assert len(reordered_layer_past_states) == len(layer_past_states)

            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
        return reordered_decoder_past
|
|
|
|
|
|
|
@add_start_docstrings(
    """T5 Dual Decoder with a `language modeling` head and a sequence-summary head on top.""", T5_START_DOCSTRING
)
class T5DualDecoderDoubleHeadsModel(T5PreTrainedModel):
    # Architecture:
    #   * `self.encoder` - a `T5DecoderStack` with self-attention only (no cross-attention) that embeds the input
    #     tokens. Despite its attribute name it is the first *decoder* stack.
    #   * `self.decoder` - a `T5DecoderStack` with cross-attention over `encoder_hidden_states`. It consumes the
    #     first stack's hidden states as `inputs_embeds` and reuses the first stack's position bias.
    #   * `self.lm_head` - projects the second stack's output onto the vocabulary.
    #   * `self.ss_head` - `SequenceSummary` applied to the first stack's output (one vector per sequence).

    def __init__(self, config: T5Config, add_pooling_layer: bool = True):
        # `add_pooling_layer` is accepted for signature compatibility but is not used.
        # The passed config is mutated: the model as a whole is decoder-only.
        config.is_encoder_decoder = False
        config.is_decoder = True
        super().__init__(config)
        self.model_dim = config.d_model

        # Token embeddings shared by both stacks.
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        # First stack: causal self-attention only.
        self_decoder_config = copy.deepcopy(config)
        self_decoder_config.is_decoder = True
        self_decoder_config.is_encoder_decoder = False
        self_decoder_config.add_cross_attention = False
        self.encoder = T5DecoderStack(self_decoder_config, self.shared)

        # Second stack: self-attention + cross-attention, `num_decoder_layers` deep. It is built without a
        # relative attention bias and instead receives the first stack's position bias in `forward`.
        cross_decoder_config = copy.deepcopy(config)
        cross_decoder_config.is_decoder = True
        cross_decoder_config.is_encoder_decoder = False
        cross_decoder_config.add_cross_attention = True
        cross_decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5DecoderStack(cross_decoder_config, self.shared, has_relative_attention_bias=False)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # The summary head picks the position given by `cls_index` (computed in `forward`), so it is built from a
        # config copy with `summary_type="cls_index"`. Previously the copy was created but the unmodified
        # `config` was passed, so the computed `cls_index` could be ignored.
        sequence_summary_config = copy.deepcopy(config)
        sequence_summary_config.summary_type = "cls_index"
        self.ss_head = SequenceSummary(sequence_summary_config)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        # One device map, sized on the first stack's blocks, is applied to both stacks.
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.decoder.parallelize(self.device_map)
        self.lm_head = self.lm_head.to(self.decoder.first_device)
        self.ss_head = self.ss_head.to(self.decoder.first_device)
        self.model_parallel = True

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        # Undo `parallelize`: move every submodule back to the CPU and release cached GPU memory.
        self.encoder.deparallelize()
        self.decoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.decoder = self.decoder.to("cpu")
        self.lm_head = self.lm_head.to("cpu")
        self.ss_head = self.ss_head.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        # Keep both stacks pointing at the same shared embedding module.
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def get_encoder(self):
        # Returns the first (self-attention only) decoder stack.
        return self.encoder

    def get_decoder(self):
        # Returns the second (cross-attention) decoder stack.
        return self.decoder

    @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=DualDecoderDoubleHeadsOutput, config_class=_CONFIG_FOR_DOC_DDT5)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], DualDecoderDoubleHeadsOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
            labels in `[0, ..., config.vocab_size - 1]`. When neither `input_ids` nor `inputs_embeds` is given, the
            inputs are obtained by shifting `labels` one position to the right. When `input_ids` are given they must
            have the same length as `labels`.

        Returns:

        Examples:

        ```python
        >>> from transformers import T5Tokenizer

        >>> tokenizer = T5Tokenizer.from_pretrained("veld-t5-base")
        >>> model = T5DualDecoderDoubleHeadsModel.from_pretrained("veld-t5-base")

        >>> # training: the inputs are derived from `labels` by shifting them one position to the right
        >>> labels = tokenizer("studies have shown that owning a dog is good for you", return_tensors="pt").input_ids
        >>> outputs = model(labels=labels)
        >>> loss = outputs.loss
        >>> logits = outputs.logits
        >>> ss_logits = outputs.ss_logits

        >>> # inference: continue a prompt
        >>> input_ids = tokenizer("studies have shown that", return_tensors="pt").input_ids
        >>> generated = model.generate(input_ids)
        >>> text = tokenizer.decode(generated[0], skip_special_tokens=True)
        ```"""
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                # A bare `__HEAD_MASK_WARNING_MSG` inside a class body is name-mangled to
                # `_T5DualDecoderDoubleHeadsModel__HEAD_MASK_WARNING_MSG` and raises NameError, so read the
                # module-level import through `globals()` instead.
                warnings.warn(globals()["__HEAD_MASK_WARNING_MSG"], FutureWarning)
                decoder_head_mask = head_mask

        # `past_key_values` is the pair (first-stack cache, second-stack cache) assembled at the end of this method.
        if past_key_values is not None:
            self_decoder_past_key_value, cross_decoder_past_key_value = past_key_values[0], past_key_values[1]
        else:
            self_decoder_past_key_value, cross_decoder_past_key_value = None, None

        if labels is not None and input_ids is None and inputs_embeds is None:
            # Teacher forcing: derive the inputs by shifting the labels one position to the right.
            input_ids = self._shift_right(labels)

        # Both stacks always run with `return_dict=True` so their outputs can be read by attribute
        # (`.past_key_values`, `.last_hidden_state`, ...) whatever the caller asked for. The caller's
        # `return_dict` only decides how the final output is packaged.
        self_decoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=self_decoder_past_key_value,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        hidden_states = self_decoder_outputs[0]
        # NOTE(review): relies on `T5DecoderStack` exposing its position bias as the last output element - confirm.
        position_bias = self_decoder_outputs[-1]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)

        cross_decoder_outputs = self.decoder(
            attention_mask=attention_mask,
            inputs_embeds=hidden_states,
            position_bias=position_bias,
            past_key_values=cross_decoder_past_key_value,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        sequence_output = cross_decoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab (the lm head shares weights with the embeddings).
            sequence_output = sequence_output * (self.model_dim**-0.5)

        lm_logits = self.lm_head(sequence_output)

        # Position summarized by the sequence-summary head.
        if self.config.pad_token_id is None:
            cls_index = None
        elif input_ids is not None:
            # Count of non-padding tokens minus one, i.e. the last real token under right padding.
            # NOTE(review): inputs built with `_shift_right` may start with a decoder-start token equal to the pad
            # id, and with a cache `input_ids` holds only the newest token; confirm the index is as intended.
            cls_index = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
            cls_index = cls_index.to(hidden_states.device)
        else:
            cls_index = None
            logger.warning(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )
        ss_logits = self.ss_head(hidden_states, cls_index=cls_index)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            # Labels may live on another device than the logits when the model is parallelized.
            labels = labels.to(lm_logits.device)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

        # Only expose a cache when both stacks produced one.
        if self_decoder_outputs.past_key_values is None or cross_decoder_outputs.past_key_values is None:
            past_key_values = None
        else:
            past_key_values = (self_decoder_outputs.past_key_values, cross_decoder_outputs.past_key_values)

        output = DualDecoderDoubleHeadsOutput(
            loss=loss,
            logits=lm_logits,
            ss_logits=ss_logits,
            past_key_values=past_key_values,
            cross_decoder_hidden_states=cross_decoder_outputs.hidden_states,
            cross_decoder_attentions=cross_decoder_outputs.attentions,
            cross_attentions=cross_decoder_outputs.cross_attentions,
            self_decoder_last_hidden_state=self_decoder_outputs.last_hidden_state,
            self_decoder_hidden_states=self_decoder_outputs.hidden_states,
            self_decoder_attentions=self_decoder_outputs.attentions,
        )
        if not return_dict:
            # Non-None fields in declaration order (loss first when present).
            return output.to_tuple()
        return output

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        **kwargs
    ):
        # With a cache, only the most recent token has to be fed through the model.
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "past_key_values": past,
            "encoder_hidden_states": encoder_hidden_states,
            "encoder_attention_mask": encoder_attention_mask,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return self._shift_right(labels)

    def _reorder_cache(self, past, beam_idx):
        # `past` is the (first-stack cache, second-stack cache) pair; reorder each along the batch/beam dim.
        if past is None:
            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
            return past

        return (self._reorder_cache_single(past[0], beam_idx), self._reorder_cache_single(past[1], beam_idx))

    def _reorder_cache_single(self, past, beam_idx):
        # Select the cached states of the surviving beams (dim 0) for every tensor of every layer.
        if past is None:
            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
            return past

        reordered_decoder_past = ()
        for layer_past_states in past:
            reordered_layer_past_states = ()
            for layer_past_state in layer_past_states:
                reordered_layer_past_states = reordered_layer_past_states + (
                    layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
                )

            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
            assert len(reordered_layer_past_states) == len(layer_past_states)

            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
        return reordered_decoder_past
|
|
|
|
|
from transformers.configuration_utils import PretrainedConfig |
|
from transformers.modeling_utils import PreTrainedModel |
|
from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder import ( |
|
VISION_ENCODER_DECODER_START_DOCSTRING, |
|
VISION_ENCODER_DECODER_INPUTS_DOCSTRING, |
|
) |
|
from transformers.models.auto.configuration_auto import AutoConfig |
|
from transformers.models.auto.modeling_auto import AutoModel |
|
from transformers import ViTModel, ViTConfig |
|
|
|
from .configuration_veld import VELDConfig |
|
|
|
# Name of the config class referenced in generated VELD model docstrings; counterpart of `_CONFIG_FOR_DOC_DDT5`
# for the dual-decoder models (presumably passed as `config_class` to `replace_return_docstrings` - confirm).
_CONFIG_FOR_DOC_VELDT5 = "VELDConfig"
|
|
|
@dataclass
class VELDDoubleHeadsOutput(ModelOutput):
    """
    Output container for the VELD model (a vision encoder combined with the T5 dual decoder with double heads).

    NOTE(review): the forward pass that fills this container is not shown here. The descriptions of the
    VELD-specific fields (`c_loss`, `e_logits_g`, `e_logits_l`, `d_logits`) are inferred from their names and must
    be confirmed against the code that constructs this output.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss. NOTE(review): may also include other loss terms - confirm.
        c_loss (`torch.FloatTensor`, *optional*):
            Presumably an additional (e.g. contrastive) loss between encoder and decoder representations - confirm.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        e_logits_g (`torch.FloatTensor`, *optional*):
            Presumably a global (whole-image) representation produced from the encoder side - confirm.
        e_logits_l (`torch.FloatTensor`, *optional*):
            Presumably local (per-patch) representations produced from the encoder side - confirm.
        d_logits (`torch.FloatTensor`, *optional*):
            Presumably the decoder-side sequence summary (cf. `ss_logits` of `DualDecoderDoubleHeadsOutput`) - confirm.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Cached key/value states of the decoder that can be used (see `past_key_values` input) to speed up
            sequential decoding. NOTE(review): when produced by the dual decoder this is a pair
            `(self_decoder_past_key_values, cross_decoder_past_key_values)` rather than a single per-layer tuple.
        decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
        decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            encoder_sequence_length)`.

            Attention weights of the decoder's cross-attention layers, after the attention softmax, used to compute
            the weighted average in the cross-attention heads.
        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the (vision) encoder.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
    """

    # NOTE: field order defines the positional layout returned by `ModelOutput.to_tuple()`; do not reorder.
    loss: Optional[torch.FloatTensor] = None
    c_loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    e_logits_g: torch.FloatTensor = None
    e_logits_l: torch.FloatTensor = None
    d_logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
|
|
@add_start_docstrings(VISION_ENCODER_DECODER_START_DOCSTRING)
class VELDModel(PreTrainedModel):
    r"""
    [`VELDModel`] is a generic model class that will be instantiated as a transformer architecture with
    one of the base vision model classes of the library as encoder and another one as dual decoder when created with the
    [`VELDModel.from_encoder_decoder_pretrained`] class method.

    The encoder output is summarised by two `SequenceSummary` heads (configured with `summary_type="attn"`):
    `local_pooling` produces the sequence the decoder cross-attends to, and `global_pooling` further pools that
    sequence for the optional contrastive loss between encoder and decoder representations.
    """

    config_class = VELDConfig
    base_model_prefix = "veld"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        encoder: Optional[PreTrainedModel] = None,
        decoder: Optional[PreTrainedModel] = None,
    ):
        """Build the model from a `VELDConfig`, or from an encoder/decoder pair whose configs are combined.

        Args:
            config: model configuration; required unless both `encoder` and `decoder` are given.
            encoder: vision encoder; if omitted, a `ViTModel` without pooling layer is built from `config.encoder`.
            decoder: dual decoder; if omitted, a `T5DualDecoderDoubleHeadsModel` is built from `config.decoder`.

        Raises:
            ValueError: if neither a config nor both sub-models are given, if `config` is not a `VELDConfig`,
                if the decoder's `cross_attention_hidden_size` differs from the encoder's `hidden_size`,
                or if the encoder has an LM head.
        """
        if config is None and (encoder is None or decoder is None):
            raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
        if config is None:
            config = VELDConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
        elif not isinstance(config, self.config_class):
            raise ValueError(f"Config: {config} has to be of type {self.config_class}")

        if config.decoder.cross_attention_hidden_size is not None:
            if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
                raise ValueError(
                    "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
                    f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
                    f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
                    " `config.encoder.hidden_size`."
                )

        # Make sure input and output embeddings are not tied at the composite-model level.
        config.tie_word_embeddings = False
        super().__init__(config)

        if encoder is None:
            encoder = ViTModel(config.encoder, add_pooling_layer=False)
        if decoder is None:
            decoder = T5DualDecoderDoubleHeadsModel(config.decoder)

        self.encoder = encoder
        self.decoder = decoder

        if self.encoder.config.to_dict() != self.config.encoder.to_dict():
            logger.warning(
                f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
                f" {self.config.encoder}"
            )
        if self.decoder.config.to_dict() != self.config.decoder.to_dict():
            logger.warning(
                f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
                f" {self.config.decoder}"
            )

        # Share the sub-config objects so that changes made through `self.config` are seen by both sub-models.
        self.encoder.config = self.config.encoder
        self.decoder.config = self.config.decoder

        # Project encoder features to the decoder width when the decoder cannot consume them directly.
        if self._needs_enc_to_dec_proj():
            self.enc_to_dec_proj = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)

        if self.encoder.get_output_embeddings() is not None:
            raise ValueError(
                f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
            )

        # Both pooling heads operate on encoder-width features (they are built from the encoder config).
        pooling_config = copy.deepcopy(self.encoder.config)
        pooling_config.summary_type = "attn"
        self.global_pooling = SequenceSummary(pooling_config, num_queries=self.config.num_queries_global)
        self.local_pooling = SequenceSummary(pooling_config, num_queries=self.config.num_queries_local)

    def _needs_enc_to_dec_proj(self):
        """Return True when encoder features must go through `enc_to_dec_proj` before reaching the decoder."""
        return (
            self.encoder.config.hidden_size != self.decoder.config.hidden_size
            and self.decoder.config.cross_attention_hidden_size is None
        )

    def _set_gradient_checkpointing(self, module, value=False):
        """Forward the gradient-checkpointing toggle to both sub-models."""
        self.encoder._set_gradient_checkpointing(module, value=value)
        self.decoder._set_gradient_checkpointing(module, value=value)

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def get_output_embeddings(self):
        return self.decoder.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        return self.decoder.set_output_embeddings(new_embeddings)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Load a saved VELD model; `_fast_init` is forced off because it is not supported for this model."""
        if kwargs.get("_fast_init", False):
            logger.warning(
                "Fast initialization is currently not supported for VELDModel. "
                "Falling back to slow initialization..."
            )
        kwargs["_fast_init"] = False
        return super().from_pretrained(*args, **kwargs)

    @classmethod
    def from_encoder_decoder_pretrained(
        cls,
        encoder_pretrained_model_name_or_path: str = None,
        decoder_pretrained_model_name_or_path: str = None,
        *model_args,
        **kwargs
    ) -> PreTrainedModel:
        r"""
        Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
        checkpoints.

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
        the model, you need to first set it back in training mode with `model.train()`.

        Params:
            encoder_pretrained_model_name_or_path (`str`, *optional*):
                Information necessary to initiate the image encoder (loaded as a `ViTModel` without pooling layer).
                Can be either:

                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. An
                      example is `google/vit-base-patch16-224-in21k`.
                    - A path to a *directory* containing model weights saved using
                      [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
                    - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
                      this case, `from_tf` should be set to `True` and a configuration object should be provided as
                      `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
                      PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
                Information necessary to initiate the text decoder (loaded as a `T5DualDecoderDoubleHeadsModel`).
                Can be either:

                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                      Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
                      user or organization name, like `dbmdz/bert-base-german-cased`.
                    - A path to a *directory* containing model weights saved using
                      [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
                    - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
                      this case, `from_tf` should be set to `True` and a configuration object should be provided as
                      `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
                      PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            model_args (remaining positional arguments, *optional*):
                All remaining positional arguments will be passed to the encoder's `from_pretrained` call.

            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                `output_attentions=True`).

                    - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
                    - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
                    - To update the parent model configuration, do not use a prefix for each configuration parameter.

                An already-built sub-model can be passed as `encoder_model=` / `decoder_model=`; it is then used
                as-is instead of being loaded.

        Example:

        ```python
        >>> from modeling_veld import VELDModel

        >>> # initialize a vit-t5 from a pretrained ViT and a pretrained T5 model. Note that the cross-attention layers will be randomly initialized
        >>> model = VELDModel.from_encoder_decoder_pretrained(
        ...     "google/vit-base-patch16-224-in21k", "t5-base"
        ... )
        >>> # saving model after fine-tuning
        >>> model.save_pretrained("./vit-t5")
        >>> # load fine-tuned model
        >>> model = VELDModel.from_pretrained("./vit-t5")
        ```"""
        kwargs_encoder = {
            argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
        }
        kwargs_decoder = {
            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

        # Strip the prefixed arguments so that only parent-config arguments remain in `kwargs`.
        for key in kwargs_encoder.keys():
            del kwargs["encoder_" + key]
        for key in kwargs_decoder.keys():
            del kwargs["decoder_" + key]

        encoder = kwargs_encoder.pop("model", None)
        if encoder is None:
            if encoder_pretrained_model_name_or_path is None:
                raise ValueError(
                    "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
                    "to be defined."
                )

            if "config" not in kwargs_encoder:
                encoder_config, kwargs_encoder = ViTConfig.from_pretrained(
                    encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
                )

                if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
                    logger.info(
                        f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
                        "from a decoder model. Cross-attention and casual mask are disabled."
                    )
                    encoder_config.is_decoder = False
                    encoder_config.add_cross_attention = False

                kwargs_encoder["config"] = encoder_config

            encoder = ViTModel.from_pretrained(
                encoder_pretrained_model_name_or_path, *model_args, add_pooling_layer=False, **kwargs_encoder
            )

        decoder = kwargs_decoder.pop("model", None)
        if decoder is None:
            if decoder_pretrained_model_name_or_path is None:
                raise ValueError(
                    "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
                    "to be defined."
                )

            if "config" not in kwargs_decoder:
                decoder_config, kwargs_decoder = T5Config.from_pretrained(
                    decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
                )

                if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
                    logger.info(
                        f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
                        f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
                        f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
                    )
                    decoder_config.is_decoder = True
                    decoder_config.add_cross_attention = True

                kwargs_decoder["config"] = decoder_config

            if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
                logger.warning(
                    f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
                    f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
                    "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
                    "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
                    "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
                )

            decoder = T5DualDecoderDoubleHeadsModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)

        # Remaining (unprefixed) kwargs update the parent configuration.
        config = VELDConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)

        # Make sure input and output embeddings are not tied at the composite-model level.
        config.tie_word_embeddings = False
        return cls(encoder=encoder, decoder=decoder, config=config)

    def _contrastive_loss(self, decoder_features, encoder_features, logit_temperature, label_smoothing):
        """Symmetric cross-entropy over the cosine-similarity matrix of the two feature sets.

        Row i of `decoder_features` is treated as matching row i of `encoder_features`, so the targets are
        the diagonal. Both features are expected to be 2-D (`torch.mm` requires it). `encoder_features` are
        first projected to the decoder width when `enc_to_dec_proj` is in use.
        """
        if self._needs_enc_to_dec_proj():
            encoder_features = self.enc_to_dec_proj(encoder_features)

        encoder_features = nn.functional.normalize(encoder_features)
        decoder_features = nn.functional.normalize(decoder_features)

        scores = torch.mm(decoder_features, encoder_features.t()) / logit_temperature
        target = torch.arange(encoder_features.size(0), device=decoder_features.device)

        loss_fct = CrossEntropyLoss(label_smoothing=label_smoothing)
        return loss_fct(scores, target) + loss_fct(scores.t(), target)

    @add_start_docstrings_to_model_forward(VISION_ENCODER_DECODER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC_VELDT5)
    def forward(
        self,
        pixel_values=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        decoder_inputs_embeds=None,
        labels=None,
        return_contrastive_loss=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        logit_temperature=1.0,
        label_smoothing=0.0,
        **kwargs,
    ):
        r"""
        VELD-specific arguments (in addition to the standard inputs):
            return_contrastive_loss (`bool`, *optional*):
                When truthy and encoder outputs are available, the contrastive loss between the globally pooled
                encoder features and the decoder's `ss_logits` is returned as `c_loss`. The globally pooled
                features (`e_logits_g`) are computed whenever this argument is not `None`.
            logit_temperature (`float`, *optional*, defaults to 1.0):
                Divides the cosine-similarity matrix before the contrastive cross-entropy.
            label_smoothing (`float`, *optional*, defaults to 0.0):
                Label smoothing for the contrastive cross-entropy only; the LM loss does not use it.
            kwargs:
                Arguments prefixed with `decoder_` are forwarded (prefix removed) to the decoder; all other extra
                keyword arguments are forwarded to the encoder.

        The LM loss ignores label positions equal to -100 (the `CrossEntropyLoss` default `ignore_index`).

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, ViTFeatureExtractor
        >>> from modeling_veld import VELDModel
        >>> import requests
        >>> from PIL import Image
        >>> import torch

        >>> processor = ViTFeatureExtractor.from_pretrained("KETI-AIR/veld-base")
        >>> tokenizer = AutoTokenizer.from_pretrained("KETI-AIR/veld-base")
        >>> model = VELDModel.from_pretrained("KETI-AIR/veld-base")

        >>> # load image from the IAM dataset
        >>> url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

        >>> # training
        >>> pixel_values = processor(image, return_tensors="pt").pixel_values
        >>> text = "hello world"
        >>> labels = tokenizer(text, return_tensors="pt").input_ids
        >>> outputs = model(pixel_values=pixel_values, labels=labels)
        >>> loss = outputs.loss

        >>> # inference (generation)
        >>> generated_ids = model.generate(pixel_values, max_new_tokens=20)
        >>> generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
        kwargs_decoder = {
            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

        # The sub-models are always asked for ModelOutput objects because their fields are read by name
        # below; the caller's `return_dict` preference only shapes the value returned at the end.
        if encoder_outputs is None and pixel_values is not None:
            encoder_outputs = self.encoder(
                pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=True,
                **kwargs_encoder,
            )
        elif isinstance(encoder_outputs, tuple):
            encoder_outputs = BaseModelOutput(*encoder_outputs)

        encoder_hidden_states = None if encoder_outputs is None else encoder_outputs[0]
        pooler_output_local = None if encoder_outputs is None else self.local_pooling(encoder_hidden_states)
        # Global pooling runs on the (unprojected, encoder-width) local summary.
        pooler_output_global = (
            None
            if encoder_outputs is None or return_contrastive_loss is None
            else self.global_pooling(pooler_output_local).squeeze(1)
        )

        if pooler_output_local is not None and self._needs_enc_to_dec_proj():
            pooler_output_local = self.enc_to_dec_proj(pooler_output_local)

        # Every pooled position is attended to, so no cross-attention mask is built.
        encoder_attention_mask = None

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            decoder_input_ids = self.decoder.prepare_decoder_input_ids_from_labels(labels)

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=pooler_output_local,
            encoder_attention_mask=encoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=True,
            **kwargs_decoder,
        )

        loss = None
        if labels is not None:
            logits = decoder_outputs.logits
            loss_fct = CrossEntropyLoss()
            # `reshape` (not `view`) so non-contiguous labels work; move labels next to the logits.
            loss = loss_fct(
                logits.reshape(-1, self.decoder.config.vocab_size),
                labels.to(logits.device).reshape(-1),
            )

        c_loss = None
        if return_contrastive_loss and encoder_outputs is not None:
            c_loss = self._contrastive_loss(
                decoder_outputs.ss_logits, pooler_output_global, logit_temperature, label_smoothing
            )

        # The dual decoder reports its two stacks separately; expose them as one concatenated tuple.
        if decoder_outputs.self_decoder_hidden_states is not None and decoder_outputs.cross_decoder_hidden_states is not None:
            decoder_hidden_states = decoder_outputs.self_decoder_hidden_states + decoder_outputs.cross_decoder_hidden_states
        else:
            decoder_hidden_states = None

        if decoder_outputs.self_decoder_attentions is not None and decoder_outputs.cross_decoder_attentions is not None:
            decoder_attentions = decoder_outputs.self_decoder_attentions + decoder_outputs.cross_decoder_attentions
        else:
            decoder_attentions = None

        encoder_last_hidden_state = None if encoder_outputs is None else encoder_outputs.last_hidden_state
        encoder_all_hidden_states = None if encoder_outputs is None else encoder_outputs.hidden_states
        encoder_attentions = None if encoder_outputs is None else encoder_outputs.attentions

        if not return_dict:
            # Tuple layout: ([loss], [c_loss], logits, e_logits_g, e_logits_l, d_logits, past_key_values,
            # decoder_hidden_states, decoder_attentions, cross_attentions, encoder_last_hidden_state,
            # encoder_hidden_states, encoder_attentions); the bracketed losses appear only when computed.
            outputs = (
                decoder_outputs.logits,
                pooler_output_global,
                pooler_output_local,
                decoder_outputs.ss_logits,
                decoder_outputs.past_key_values,
                decoder_hidden_states,
                decoder_attentions,
                decoder_outputs.cross_attentions,
                encoder_last_hidden_state,
                encoder_all_hidden_states,
                encoder_attentions,
            )
            if c_loss is not None:
                outputs = (c_loss,) + outputs
            if loss is not None:
                return (loss,) + outputs
            return outputs

        return VELDDoubleHeadsOutput(
            loss=loss,
            c_loss=c_loss,
            logits=decoder_outputs.logits,
            e_logits_g=pooler_output_global,
            e_logits_l=pooler_output_local,
            d_logits=decoder_outputs.ss_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_hidden_states,
            decoder_attentions=decoder_attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_last_hidden_state,
            encoder_hidden_states=encoder_all_hidden_states,
            encoder_attentions=encoder_attentions,
        )

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        """Delegate label shifting to the decoder."""
        return self.decoder.prepare_decoder_input_ids_from_labels(labels)

    def prepare_inputs_for_generation(
        self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
    ):
        """Build the `forward` kwargs for one generation step, reusing the cached `encoder_outputs`."""
        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
        decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
        input_dict = {
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "decoder_input_ids": decoder_inputs["input_ids"],
            "encoder_outputs": encoder_outputs,
            "past_key_values": decoder_inputs["past_key_values"],
            "use_cache": use_cache,
        }
        return input_dict

    def resize_token_embeddings(self, *args, **kwargs):
        """Not supported on the composite model; resize the wrapped decoder instead."""
        raise NotImplementedError(
            "Resizing the embedding layers via the VELDModel directly is not supported. Please use the"
            " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))"
        )

    def _reorder_cache(self, past, beam_idx):
        """Delegate beam-search cache reordering to the decoder."""
        return self.decoder._reorder_cache(past, beam_idx)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build a VELD model from pretrained ViT + T5 checkpoints and run three forward passes
    # (text only; image + text with contrastive loss; image + text without it), printing both losses.
    from transformers import AutoTokenizer, ViTFeatureExtractor
    from PIL import Image

    VISION_PRETRAINED_MODEL = "google/vit-base-patch16-384"
    LANGUAGE_PRETRAINED_MODEL = "KETI-AIR/ke-t5-base"

    test_inputs = [
        "To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.",
        "To update the parent model configuration,",
    ]
    tokenizer = AutoTokenizer.from_pretrained(LANGUAGE_PRETRAINED_MODEL)
    inps = tokenizer(test_inputs, padding=True, truncation="longest_first", return_tensors="pt")

    # Padding positions must not count towards the LM loss: CrossEntropyLoss ignores -100 by default.
    # NOTE(review): assumes the decoder's prepare_decoder_input_ids_from_labels maps -100 back to the
    # pad id when shifting (as T5's `_shift_right` does) — confirm for T5DualDecoderDoubleHeadsModel.
    labels = inps.input_ids.masked_fill(inps.attention_mask == 0, -100)

    feature_extractor = ViTFeatureExtractor.from_pretrained(VISION_PRETRAINED_MODEL)
    images = []
    for image_path in ("images/sample.jpg", "images/sample2.jpg"):
        # Decode into memory as 3-channel RGB and close the file immediately.
        with Image.open(image_path) as image:
            images.append(image.convert("RGB"))
    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values

    model = VELDModel.from_encoder_decoder_pretrained(
        VISION_PRETRAINED_MODEL,
        LANGUAGE_PRETRAINED_MODEL,
    )

    # Text only: no encoder outputs, so c_loss is None.
    outputs = model(
        labels=labels,
        return_contrastive_loss=True,
        decoder_attention_mask=inps.attention_mask,
    )
    print(outputs.loss)
    print(outputs.c_loss)

    # Image + text with the contrastive loss.
    outputs = model(
        pixel_values=pixel_values,
        labels=labels,
        return_contrastive_loss=True,
        decoder_attention_mask=inps.attention_mask,
    )
    print(outputs.loss)
    print(outputs.c_loss)

    # Image + text without the contrastive loss.
    outputs = model(
        pixel_values=pixel_values,
        labels=labels,
        decoder_attention_mask=inps.attention_mask,
    )
    print(outputs.loss)
    print(outputs.c_loss)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|