import torch

from transformers.utils import logging
from transformers.models.gemma2 import modeling_gemma2

# Module-level logger so the patched forward can emit warnings like the original does
logger = logging.get_logger(__name__)

# Monkey patch the Gemma2Model's forward function
# Save a reference to the original forward function
original_forward = modeling_gemma2.Gemma2Model.forward


# Define the patched version of the forward function
def patched_forward(self,
                    input_ids=None,
                    attention_mask=None,
                    position_ids=None,
                    past_key_values=None,
                    inputs_embeds=None,
                    use_cache=None,
                    output_attentions=None,
                    output_hidden_states=None,
                    return_dict=None,
                    cache_position=None):
    # Update parameters based on the input or configuration defaults
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Ensure either input_ids or inputs_embeds is specified, but not both
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    # Handle gradient checkpointing case to ensure compatibility with caching
    if self.gradient_checkpointing and self.training and use_cache:
        logger.warning_once(
            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
        )
        use_cache = False

    # Embed tokens if inputs_embeds is not provided
    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    # Lazily create a HybridCache for inference when caching is requested but no cache was passed in
    if use_cache and past_key_values is None and not self.training:
        batch_size, seq_len, _ = inputs_embeds.shape
        past_key_values = modeling_gemma2.HybridCache(
            self.config,
            batch_size=batch_size,
            max_cache_len=seq_len,
            device=self.device,
            dtype=inputs_embeds.dtype,
        )

    # Handle cache position
    if cache_position is None:
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )

    # Handle position IDs
    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)

    # Compute causal mask
    causal_mask = self._update_causal_mask(
        attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
    )

    # Initialize the hidden states from the token embeddings
    hidden_states = inputs_embeds

    # Scale the embeddings by sqrt(hidden_size) (Gemma2's input normalization), on the correct device and dtype
    normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype, device=hidden_states.device)
    hidden_states = hidden_states * normalizer

    # Initialize variables to store outputs if required
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None

    # Pass through decoder layers
    for decoder_layer in self.layers:
        # Store the hidden state if requested
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # Use gradient checkpointing if applicable
        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                causal_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
                cache_position,
            )
        else:
            # Normal forward pass through the layer
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )

        # Update hidden states with the output from the layer
        hidden_states = layer_outputs[0]

        # Store self-attentions if requested
        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    # Apply final normalization
    hidden_states = self.norm(hidden_states)

    # Store the last hidden state if required
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    # Only return the cache if caching was requested
    next_cache = past_key_values if use_cache else None

    # Prepare output
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)

    return modeling_gemma2.BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )

# The patch is not installed at import time; call apply_patch() below to install it.
# modeling_gemma2.Gemma2Model.forward = patched_forward


# Install the patch explicitly by calling this function.
def apply_patch():
    modeling_gemma2.Gemma2Model.forward = patched_forward
    print("Gemma2Model's forward function has been patched.")
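

# A minimal usage sketch, added for illustration and not part of the original patch:
# remove_patch() shows how the saved original_forward can undo the monkey patch, and
# the __main__ block shows one way to exercise the patched model. The checkpoint id
# "google/gemma-2-2b" is an assumption; substitute whichever Gemma2 checkpoint you use.
def remove_patch():
    modeling_gemma2.Gemma2Model.forward = original_forward
    print("Gemma2Model's original forward function has been restored.")


if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    apply_patch()
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
    model = AutoModel.from_pretrained("google/gemma-2-2b")
    inputs = tokenizer("Hello, Gemma!", return_tensors="pt")
    outputs = model(**inputs)  # dispatched through patched_forward
    print(outputs.last_hidden_state.shape)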