import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    logging,
)

from .configuration_gpt_refact import GPTRefactConfig

logger = logging.get_logger(__name__)


@torch.jit.script
def upcast_masked_softmax(
    x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype
):
    """
    Softmax over the last dimension computed in `softmax_dtype`.

    `x` is cast to `softmax_dtype` and multiplied by `scale`; positions where `mask` is False are
    replaced by `mask_value` (i.e. `mask` is True for positions that are KEPT). The result is cast
    back to the dtype `x` had on entry.
    """
    input_dtype = x.dtype
    x = x.to(softmax_dtype) * scale
    x = torch.where(mask, x, mask_value)
    x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
    return x


@torch.jit.script
def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype):
    """
    Unmasked variant of `upcast_masked_softmax`: cast to `softmax_dtype`, multiply by `scale`,
    softmax over the last dimension, cast back to the input dtype.
    """
    input_dtype = x.dtype
    x = x.to(softmax_dtype) * scale
    x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
    return x


@torch.jit.script
def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor):
    """
    Softmax over the last dimension in the dtype of `x`; positions where `mask` is False are
    replaced by `mask_value` before the softmax (`mask` is True for positions that are KEPT).
    """
    x = torch.where(mask, x, mask_value)
    x = torch.nn.functional.softmax(x, dim=-1)
    return x


@torch.jit.script
def _get_slopes(attn_heads: int, dev: torch.device) -> torch.Tensor:
    r"""
    ## Get head-specific slope $m$ for each head

    * `attn_heads` is the number of heads in the attention layer $n$
    * `dev` is the device the slopes tensor is created on

    The slope for first head is

    $$\frac{1}{2^{\frac{8}{n}}} = 2^{-\frac{8}{n}}$$

    The slopes for the rest of the heads are in a geometric series with a ratio same as above.

    For instance when the number of heads is $8$ the slopes are
    $$\frac{1}{2^1}, \frac{1}{2^2}, \dots, \frac{1}{2^8}$$

    Returns a 1-D tensor with `attn_heads` entries.
    """

    # Get the closest power of 2 to `attn_heads`.
    # If `attn_heads` is not a power of 2, then we first calculate slopes to the closest (smaller) power of 2,
    # and then add the remaining slopes.
    n = 2 ** math.floor(math.log(attn_heads, 2))
    # $2^{-\frac{8}{n}}$
    m_0 = 2.0 ** (-8.0 / n)
    # $2^{-1\frac{8}{n}}, 2^{-2 \frac{8}{n}}, 2^{-3 \frac{8}{n}}, \dots$
    m = torch.pow(m_0, torch.arange(1, 1 + n, device=dev))

    # If `attn_heads` is not a power of 2, then we add the remaining slopes.
    # We calculate the remaining slopes for $n * 2$ (avoiding slopes added previously).
    # And pick the slopes up to `attn_heads`.
    if n < attn_heads:
        # $2^{-\frac{8}{2n}}$
        m_hat_0 = 2.0 ** (-4.0 / n)
        # $2^{-1\frac{8}{2n}}, 2^{-3 \frac{8}{2n}}, 2^{-5 \frac{8}{2n}}, \dots$
        # Note that we take steps by $2$ to avoid slopes added previously.
        m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (attn_heads - n), 2, device=dev))
        # Concatenate the slopes with the remaining slopes.
        m = torch.cat([m, m_hat])
    return m


@torch.jit.script
def get_alibi_biases(
    B: int, T: int, attn_heads: int, dev: torch.device, dtype: torch.dtype, causal: bool = True
) -> torch.Tensor:
    """
    ## Calculate the attention biases matrix

    * `B` is the batch size the bias is repeated for
    * `T` is the (key) sequence length
    * `attn_heads` is the number of heads in the attention layer
    * `dev` / `dtype` are the device and dtype of the returned tensor
    * `causal` selects a lower-triangular distance mask (default) or a full one

    This returns a contiguous tensor of shape `[B, attn_heads, T, T]` with ALiBi attention biases.
    """

    # Boolean mask used only to derive distances: lower-triangular when causal, all-True otherwise.
    if causal:
        mask = (torch.triu(torch.ones((T, T), device=dev)) == 1).transpose(0, 1)
    else:
        mask = torch.ones((T, T), device=dev, dtype=torch.bool)

    # Get slopes $m$ for each head
    m = _get_slopes(attn_heads, dev)

    # Calculate distances $[0, 1, \dots, N]$
    # Here we calculate the distances using the mask.
    #
    # Since it's causal mask we can just use $[0, 1, \dots, N]$ too.
    # `distance = torch.arange(mask.shape[1], dtype=torch.long, device=mask.device)[None, :]`
    distance = mask.cumsum(dim=-1)

    # Multiply them pair-wise to get the AliBi bias matrix
    biases = distance[:, :, None] * m[None, None, :]
    # [T, T, heads] -> [1, heads, T, T], then one copy per batch element.
    biases = biases.permute(2, 0, 1)[None, :, :T, :T]
    biases = biases.repeat(B, 1, 1, 1)
    return biases.to(dtype).contiguous()


class Attention(nn.Module):
    """
    Self-attention with per-head queries and a single shared key/value head (`kv_attn_heads = 1`),
    with an additive ALiBi bias on the attention logits.
    """

    def __init__(self, config, layer_idx=None):
        super().__init__()
        # Lazily created 0-dim tensor used as fill value for masked logits, see `_get_mask_value`.
        self.mask_value = None

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.kv_attn_heads = 1

        self.scale = self.head_dim ** -0.5

        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.layer_idx = layer_idx
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
        self.scale_attention_softmax_in_fp32 = (
            config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32
        )

        self.q = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.k = nn.Linear(self.embed_dim, self.head_dim, bias=False)
        self.v = nn.Linear(self.embed_dim, self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)

    def _get_mask_value(self, device, dtype):
        """
        Return a 0-dim tensor holding the most negative finite value of `dtype` on `device`.

        `torch.where` needs a tensor operand; the tensor is cached on the module and only rebuilt
        when the requested dtype or device changes.
        """
        cached = self.mask_value
        if cached is None or cached.dtype != dtype or cached.device != device:
            cached = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device)
            self.mask_value = cached
        return cached

    def _attn(self, query, key, value, attention_mask=None, alibi=None):
        """
        Compute attention probabilities and the attended values.

        Args:
            query: per-head queries; `forward` passes `(batch, heads, q_len, head_dim)`.
            key: keys already transposed for the matmul; `forward` passes `(batch, 1, head_dim, k_len)`.
            value: values; `forward` passes `(batch, 1, k_len, head_dim)`.
            attention_mask: optional boolean tensor broadcastable to the logits, True at positions
                that must NOT be attended to.
            alibi: optional additive bias broadcastable to the logits.

        Returns:
            `(attn_output, attn_weights)`.
        """
        dtype = query.dtype
        softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype
        upcast = dtype != softmax_dtype

        # Layer-dependent factor used to keep low-precision logits small. Whatever is multiplied back
        # inside the upcast softmax must be divided out here first, otherwise the upcast path would
        # compute a different distribution than the non-upcast path.
        unscale = 1.0
        if self.scale_attention_softmax_in_fp32 and upcast and self.layer_idx is not None:
            unscale = float(self.layer_idx + 1)

        attn_weights = torch.matmul(query * (self.scale / unscale), key)
        if alibi is not None:
            if unscale != 1.0:
                alibi = alibi / unscale
            attn_weights = alibi + attn_weights

        if upcast:
            if attention_mask is None:
                attn_weights = upcast_softmax(attn_weights, unscale, softmax_dtype)
            else:
                mask_value = self._get_mask_value(attn_weights.device, softmax_dtype)
                # `upcast_masked_softmax` keeps positions where its mask is True, whereas
                # `attention_mask` is True at masked positions (same convention as the
                # `masked_fill` branch below), hence the inversion.
                attn_weights = upcast_masked_softmax(
                    attn_weights, ~attention_mask, mask_value, unscale, softmax_dtype
                )
        else:
            if attention_mask is not None:
                attn_weights = torch.masked_fill(attn_weights, attention_mask, -10000)
            attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def _split_heads(self, tensor):
        """Reshape `(..., seq, embed_dim)` to `(batch, num_heads, seq, head_dim)`."""
        new_shape = tensor.shape[:-1] + (self.num_heads, self.head_dim)
        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        alibi: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]],
        Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
    ]:
        """
        Run self-attention over `hidden_states` of shape `(batch, seq, embed_dim)`.

        Args:
            hidden_states: input activations.
            layer_past: optional `(key, value)` pair from previous steps, concatenated along the
                sequence dimension in front of the new keys/values.
            attention_mask: optional boolean mask, True at positions to exclude (see `_attn`).
            alibi: optional additive attention bias.
            use_cache: when exactly `True`, return the updated `(key, value)` pair.
            output_attentions: when truthy, append the attention probabilities to the output.

        Returns:
            `(attn_output, present)` or `(attn_output, present, attn_weights)`; `present` is `None`
            unless `use_cache is True`.
        """
        b, t, _ = hidden_states.shape
        query = self.q(hidden_states)
        key = self.k(hidden_states)
        value = self.v(hidden_states)

        query = self._split_heads(query)
        # Single shared key/value head: (b, t, head_dim) -> (b, 1, t, head_dim).
        key = key.view(b, t, self.kv_attn_heads, self.head_dim).permute(0, 2, 1, 3)
        value = value.view(b, t, self.kv_attn_heads, self.head_dim).permute(0, 2, 1, 3)

        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = (key, value)
        else:
            present = None

        attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, alibi)
        # (b, heads, t, head_dim) -> (b, t, embed_dim)
        attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
        attn_output = self.c_proj(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs  # a, present, (attentions)


class MLP(nn.Module):
    """
    Gated feed-forward block: `c_proj(silu(linear_1(x)) * linear_3(x))`, all projections bias-free.

    The hidden width is `2/3 * intermediate_size`, rounded up to a multiple of `multiple_of`.
    """

    def __init__(self, intermediate_size, config, multiple_of: int = 256):
        super().__init__()
        embed_dim = config.hidden_size
        hidden_dim = intermediate_size
        hidden_dim = int(2 * hidden_dim / 3)
        # Round up to the next multiple of `multiple_of`.
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.linear_1 = nn.Linear(embed_dim, hidden_dim, bias=False)
        self.linear_3 = nn.Linear(embed_dim, hidden_dim, bias=False)
        self.c_proj = nn.Linear(hidden_dim, embed_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated projection to `x` (last dimension `hidden_size`) and project back."""
        x1 = F.silu(self.linear_1(x))
        x2 = self.linear_3(x)
        x = self.c_proj(x1 * x2)
        return x


class LayerNormNoBias(nn.Module):
    """Layer normalization over the last dimension with a learnable scale and no bias term."""

    def __init__(self, shape: int, eps: float = 1e-5):
        super().__init__()
        self.shape = (shape,)
        self.eps = eps
        # Left uninitialized here; `GPTRefactPreTrainedModel._init_weights` fills it with 1.0.
        self.weight = nn.Parameter(torch.empty(self.shape))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return `F.layer_norm(x)` using `self.weight` as scale and no bias."""
        return F.layer_norm(x, self.shape, self.weight, None, self.eps)


class GPTRefactBlock(nn.Module):
    """Pre-norm transformer block: `x + attn(ln_1(x))` followed by `mix + mlp(ln_2(mix))`."""

    def __init__(self, config, layer_idx=None):
        super().__init__()
        hidden_size = config.hidden_size
        self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

        self.ln_1 = LayerNormNoBias(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = Attention(config, layer_idx=layer_idx)
        self.ln_2 = LayerNormNoBias(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = MLP(self.inner_dim, config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        alibi: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[
        Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    ]:
        """
        Apply attention and MLP sub-layers with residual connections.

        Arguments are forwarded unchanged to `Attention.forward`.

        Returns:
            `(hidden_states, present, attentions)` with `present` omitted when `use_cache` is falsy
            and `attentions` omitted when `output_attentions` is falsy.
        """
        hidden_states_norm = self.ln_1(hidden_states)
        attn_outputs = self.attn(
            hidden_states_norm,
            layer_past=layer_past,
            attention_mask=attention_mask,
            alibi=alibi,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
        outputs = attn_outputs[1:]
        # residual connection
        mix = attn_output + hidden_states
        norm_mix = self.ln_2(mix)
        feed_forward_hidden_states = self.mlp(norm_mix)
        # residual connection
        hidden_states = mix + feed_forward_hidden_states

        if use_cache:
            outputs = (hidden_states,) + outputs
        else:
            # Drop the `present` slot (it is None when caching is off).
            outputs = (hidden_states,) + outputs[1:]

        return outputs  # hidden_states, present, (attentions)


class GPTRefactPreTrainedModel(PreTrainedModel):
    """Base class wiring GPT-Refact models into the `PreTrainedModel` machinery (config, init)."""

    config_class = GPTRefactConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["GPTRefactBlock"]
    _skip_keys_device_placement = "past_key_values"

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialize the weights of `module` according to its type."""
        if isinstance(module, (MLP, Attention)):
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            module.c_proj.weight.data.normal_(
                mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))
            )
            # Mark as done so the generic `nn.Linear` branch does not overwrite this init.
            module.c_proj._is_hf_initialized = True
        elif isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, LayerNormNoBias):
            module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(self, module, value=False):
        """Toggle gradient checkpointing on `GPTRefactModel` instances."""
        if isinstance(module, GPTRefactModel):
            module.gradient_checkpointing = value


class GPTRefactModel(GPTRefactPreTrainedModel):
    """Token embedding followed by a stack of `GPTRefactBlock`s with ALiBi position biases."""

    def __init__(self, config):
        super().__init__(config)
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.multi_query = config.multi_query
        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)

        self.h = nn.ModuleList([GPTRefactBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])

        max_positions = config.max_position_embeddings
        # NOTE(review): this lower-triangular buffer is not read anywhere in this file; it is kept
        # so that the module's attributes stay unchanged.
        self.register_buffer(
            "bias", torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)), persistent=False
        )

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @staticmethod
    def _make_mask(seq_len: int, past_key_values_length: int, device: Optional[torch.device] = None):
        """
        Build the causal mask for `seq_len` new tokens preceded by `past_key_values_length` cached ones.

        Returns a boolean tensor of shape `(seq_len, seq_len + past_key_values_length)` that is True
        where the key lies in the future of the query (i.e. must be masked). Query `i` sits at
        absolute position `past_key_values_length + i`, so everything from column
        `past_key_values_length + i + 1` onwards is masked. For a single new token this is all-False,
        and without a cache it is the strict upper triangle.
        """
        mask = torch.ones((seq_len, seq_len + past_key_values_length), dtype=torch.bool, device=device)
        return torch.triu(mask, diagonal=past_key_values_length + 1)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        """
        Run the transformer stack.

        Exactly one of `input_ids` / `inputs_embeds` must be given. `past_key_values` holds one
        `(key, value)` pair per layer. Flags left as `None` fall back to the model config.

        NOTE(review): a caller-supplied `attention_mask` is used as-is in place of the causal mask
        and must be a boolean tensor broadcastable to `(batch, heads, q_len, k_len)` that is True at
        positions to exclude. It is not translated from the 1 = keep padding-mask convention.

        Raises:
            ValueError: if both or neither of `input_ids` / `inputs_embeds` are given, or the batch
                size is not positive.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if batch_size <= 0:
            raise ValueError("batch_size has to be defined and > 0")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
            # Cached keys are stored as (batch, 1, seq, head_dim); see `Attention.forward`.
            past_length = past_key_values[0][0].size(-2)

        # Self-attention mask.
        query_length = input_shape[-1]
        seq_length_with_past = past_length + query_length
        if attention_mask is None:
            attention_mask = self._make_mask(query_length, past_length, device)
        else:
            attention_mask = attention_mask.to(device)

        hidden_states = self.wte(input_ids) if inputs_embeds is None else inputs_embeds

        # Biases are built for the full key length; only the rows of the new query positions are kept.
        alibi = get_alibi_biases(
            hidden_states.shape[0], seq_length_with_past, self.num_heads, device, self.wte.weight.dtype
        )[:, :, -query_length:, :]

        output_shape = input_shape + (hidden_states.size(-1),)

        presents = [] if use_cache else None
        all_self_attentions = () if output_attentions else None
        # NOTE(review): the blocks in this file return at most (hidden_states, present, attentions),
        # so the cross-attention index below would be out of range if `add_cross_attention` were set.
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        all_hidden_states = () if output_hidden_states else None
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, use_cache, output_attentions)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    None,
                    attention_mask,
                    alibi,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=attention_mask,
                    alibi=alibi,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]
            if use_cache:
                presents.append(outputs[1])

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)

        hidden_states = hidden_states.view(output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
                if v is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )


class GPTRefactForCausalLM(GPTRefactPreTrainedModel):
    """`GPTRefactModel` followed by a final bias-free layer norm and a language-modeling head."""

    _tied_weights_keys = ["lm_head.weight", "ln_f.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.transformer = GPTRefactModel(config)
        self.ln_f = LayerNormNoBias(self.transformer.embed_dim, eps=config.layer_norm_epsilon)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        """
        Assemble the keyword arguments for one generation step.

        `inputs_embeds` is used only when there is no cache yet; with a cache only the last token of
        `input_ids` is fed. `past_key_values` and `kwargs["use_cache"]` are always forwarded.
        """
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            if past_key_values is not None:
                model_inputs = {"input_ids": input_ids[..., -1:]}
            else:
                model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
            }
        )
        return model_inputs

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        x = self.ln_f(hidden_states)
        lm_logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )

    @staticmethod
    def _reorder_cache(
        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.

        Each layer entry is the `(key, value)` pair produced by `Attention.forward`; every tensor in
        it is re-indexed along the batch dimension. A bare tensor entry is accepted as well.
        """
        reordered = []
        for layer_past in past_key_values:
            if isinstance(layer_past, torch.Tensor):
                reordered.append(layer_past.index_select(0, beam_idx.to(layer_past.device)))
            else:
                reordered.append(
                    tuple(state.index_select(0, beam_idx.to(state.device)) for state in layer_past)
                )
        return tuple(reordered)