import torch
from transformers.models.gemma2 import modeling_gemma2
# Monkey patch the Gemma2Model's forward function
# Save a reference to the original forward function
original_forward = modeling_gemma2.Gemma2Model.forward
# Define the patched version of the forward function
def patched_forward(self,
                    input_ids=None,
                    attention_mask=None,
                    position_ids=None,
                    past_key_values=None,
                    inputs_embeds=None,
                    use_cache=None,
                    output_attentions=None,
                    output_hidden_states=None,
                    return_dict=None,
                    cache_position=None):
    # Update parameters based on the input or configuration defaults
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    # Ensure that exactly one of input_ids or inputs_embeds is specified
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
    # Handle gradient checkpointing case to ensure compatibility with caching
    if self.gradient_checkpointing and self.training and use_cache:
        # The module-level logger lives in modeling_gemma2, not in this file
        modeling_gemma2.logger.warning_once(
            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
        )
        use_cache = False
    # Embed tokens if inputs_embeds is not provided
    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)
    # Handle caching mechanism
    if use_cache and past_key_values is None and not self.training:
        batch_size, seq_len, _ = inputs_embeds.shape
        past_key_values = modeling_gemma2.HybridCache(
            self.config,
            batch_size=batch_size,
            max_cache_len=seq_len,
            device=self.device,
            dtype=inputs_embeds.dtype,
        )
    # Handle cache position
    if cache_position is None:
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )
    # Handle position IDs
    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)
    # Compute causal mask
    causal_mask = self._update_causal_mask(
        attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
    )
    # Initialize the hidden states from the input embeddings (positions are applied
    # via rotary embeddings inside the decoder layers, not added here)
    hidden_states = inputs_embeds
    # Create the normalizer tensor on the correct device
    normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype, device=hidden_states.device)
    hidden_states = hidden_states * normalizer
    # Initialize variables to store outputs if required
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    # Pass through decoder layers
    for decoder_layer in self.layers:
        # Store the hidden state if requested
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        # Use gradient checkpointing if applicable
        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                causal_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
                cache_position,
            )
        else:
            # Normal forward pass through the layer
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )
        # Update hidden states with the output from the layer
        hidden_states = layer_outputs[0]
        # Store self-attentions if requested
        if output_attentions:
            all_self_attns += (layer_outputs[1],)
    # Apply final normalization
    hidden_states = self.norm(hidden_states)
    # Store the last hidden state if required
    if output_hidden_states:
        all_hidden_states += (hidden_states,)
    # Handle caching mechanism
    next_cache = past_key_values if use_cache else None
    # Prepare output
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return modeling_gemma2.BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )


# Apply the patched forward function to the Gemma2Model class directly:
# modeling_gemma2.Gemma2Model.forward = patched_forward
# Optional: or call apply_patch() to install the patch on demand.
def apply_patch():
    modeling_gemma2.Gemma2Model.forward = patched_forward
    print("Gemma2Model's forward function has been patched.")
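

# The helpers below are not part of the original patch; they are a minimal usage sketch.
# remove_patch() simply restores the `original_forward` reference saved at the top of
# this file, undoing apply_patch().
def remove_patch():
    modeling_gemma2.Gemma2Model.forward = original_forward
    print("Gemma2Model's original forward function has been restored.")


# Example run (assumptions: a Gemma2 checkpoint such as "google/gemma-2-2b" is available
# and authorized for download, and the installed transformers version still exposes the
# attributes this patch relies on, e.g. _update_causal_mask and HybridCache).
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    apply_patch()
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
    model = AutoModel.from_pretrained("google/gemma-2-2b")  # loads the base Gemma2Model
    inputs = tokenizer("Hello, Gemma!", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)  # runs patched_forward
    print(outputs.last_hidden_state.shape)
    remove_patch()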