Runtime autograd error due to inplace operations

#4
by xianbin - opened

Error

While fine-tuning the Gemma2 models with TRL, the following error was encountered:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [CUDABFloat16Type [1, 308, 256000]], which is output 0 of TanhBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
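
The hint at the end of the message can be followed with a small, CPU-only sketch. This is not the Gemma2 code: the tensor `x` and the factor `2.0` are made up for illustration. With anomaly detection on, PyTorch should also emit a warning containing the traceback of the forward call that produced the offending tensor.

```python
import torch

x = torch.randn(3, requires_grad=True)  # hypothetical stand-in for the model's logits
message = None

# Context-manager form of torch.autograd.set_detect_anomaly(True).
with torch.autograd.detect_anomaly():
    y = torch.tanh(x)   # autograd saves this output for TanhBackward0
    y.mul_(2.0)         # in-place edit of the saved output
    try:
        y.sum().backward()
    except RuntimeError as err:
        message = str(err)

print(message)
```

Anomaly detection slows training noticeably, so it is best enabled only while hunting for the failing operation.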

Cause

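This was found to be caused by in-place operations in the Gemma2 transformer model definition that modify a tensor autograd has saved for gradient computation. In the error above, that tensor is the output of torch.tanh, which the in-place mul_ that follows then overwrites.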
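
A minimal sketch of the failure mode, reduced to the soft-capping pattern. The helper names `softcap_inplace` and `softcap_out_of_place`, the tensor shape, and the cap value `30.0` are assumptions for illustration, not taken from the model code.

```python
import torch

def softcap_inplace(x, cap):
    x = x / cap          # out-of-place here so the leaf tensor itself is untouched
    x = torch.tanh(x)    # autograd saves this output for the backward pass
    x.mul_(cap)          # in-place edit bumps the saved tensor's version counter
    return x

def softcap_out_of_place(x, cap):
    # Same math, but every step returns a new tensor.
    return torch.tanh(x / cap) * cap

logits = torch.randn(2, 4, requires_grad=True)

try:
    softcap_inplace(logits, 30.0).sum().backward()
    inplace_failed = False
except RuntimeError:
    inplace_failed = True

softcap_out_of_place(logits, 30.0).sum().backward()
grad_ok = logits.grad is not None

print(inplace_failed, grad_ok)
```

In this sketch, making the final multiplication out-of-place is enough for backward to succeed.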

Possible solution

The following lines of code should be modified in diff_gemma2.py (and, by extension, modeling_gemma2.py):

Lines 163-165:

            attention_mask *= torch.tril(
                torch.ones_like(attention_mask),
                diagonal=(self.sliding_window - cache_position[-1]),
            )

Replacement:

            attention_mask = torch.mul(
                attention_mask,
                torch.tril(
                    torch.ones_like(attention_mask),
                    diagonal=(self.sliding_window - cache_position[-1]),
                ),
            )
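
The difference between `*=` and `torch.mul` can be checked on a toy mask. The 4x4 shape and `diagonal=1` below are arbitrary values chosen for illustration, not the model's sliding-window settings.

```python
import torch

mask = torch.ones(4, 4)
alias = mask                                   # second reference to the same storage
window = torch.tril(torch.ones_like(mask), diagonal=1)

out = torch.mul(mask, window)                  # new tensor; mask and alias are untouched
unchanged = torch.equal(alias, torch.ones(4, 4))

mask *= window                                 # in-place; every reference sees the change
changed = torch.equal(alias, window)

print(unchanged, changed)
```

Because `torch.mul` allocates a new tensor, any other holder of the original mask, including autograd's saved tensors, keeps seeing the unmodified values.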

Lines 119-121:

            attn_weights.div_(self.config.attn_logit_softcapping)
            attn_weights = torch.tanh(attn_weights)
            attn_weights.mul_(self.config.attn_logit_softcapping)

Replacement:

            attn_weights = torch.div(attn_weights, self.config.attn_logit_softcapping)
            attn_weights = self.attn_weights_tanh(attn_weights)
            attn_weights = torch.mul(attn_weights, self.config.attn_logit_softcapping)

Place this in the __init__ method of Gemma2Attention:

            self.attn_weights_tanh = nn.Tanh()

Lines 202-204:

            logits.div_(self.config.final_logit_softcapping)
            logits = torch.tanh(logits)
            logits.mul_(self.config.final_logit_softcapping)

Replacement:

            logits = torch.div(logits, self.config.final_logit_softcapping)
            logits = self.final_logit_tanh(logits)
            logits = torch.mul(logits, self.config.final_logit_softcapping)

Place this in the __init__ method of Gemma2ForCausalLM:

            self.final_logit_tanh = nn.Tanh()
Google org

Yes, will fix this in a bit!

Google org

Hi @xianbin, I hope the issue has been resolved. Could you please confirm? If you have any concerns, let us know and we will assist you; otherwise, feel free to close the discussion.

Thank you.

xianbin changed discussion status to closed