jasonkrone committed · Commit 13bd99e · verified · Parent(s): 9925bff

Fix For NaN Logits in HuggingFace Distribution of OpenELM


I found that left padding of inputs led to NaN logits. The fix (credit to [this thread](https://github.com/huggingface/transformers/issues/32390)) is to change the line ```min_dtype = torch.finfo(dtype).min``` to ```min_dtype = torch.finfo(dtype).min / 2``` in the function ```_update_causal_mask```.
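To illustrate why halving the fill value matters, here is a minimal, self-contained sketch (not taken from the OpenELM code) of the mechanism as I understand it from the linked issue: a left-padded query position receives the `min_dtype` fill from both the causal mask and the padding attention mask, their sum overflows fp16 to `-inf`, and a fully `-inf` row turns into NaN under softmax. With `min_dtype / 2` the sum stays finite.

```python
import torch

dtype = torch.float16
min_dtype = torch.finfo(dtype).min            # -65504.0 in fp16

# A left-padded query position gets the fill value from both the causal mask
# and the padding attention mask; their sum overflows fp16 to -inf.
masked_row = torch.full((4,), min_dtype, dtype=dtype) + min_dtype
print(masked_row)                       # tensor([-inf, -inf, -inf, -inf], dtype=torch.float16)
# Cast up only for the softmax call; the overflow already happened in fp16.
print(masked_row.float().softmax(-1))   # tensor([nan, nan, nan, nan]) -> NaNs reach the logits

# With min_dtype / 2 the two contributions sum to exactly min_dtype, which is
# still finite, so softmax over the fully masked row stays well-defined.
fixed_row = torch.full((4,), min_dtype / 2, dtype=dtype) + min_dtype / 2
print(fixed_row)                        # tensor([-65504., -65504., -65504., -65504.])
print(fixed_row.float().softmax(-1))    # tensor([0.2500, 0.2500, 0.2500, 0.2500])
```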

I presume all other OpenELM model sizes and variations require the same fix.

Note: the ```if not is_tracing and torch.any(attention_mask != 1):``` condition in ```_update_causal_mask``` appears to address the same issue; however, that mitigation only runs when ```self.config._attn_implementation == "sdpa"```, whereas the NaN logits also appear when ```self.config._attn_implementation == "eager"```.
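For completeness, a rough reproduction sketch is below. The checkpoint id, tokenizer choice, prompts, and device handling are my own illustrative assumptions rather than part of this commit (OpenELM ships no tokenizer, so the Llama-2 tokenizer is commonly used with it); the key ingredients are fp16, left padding, and a padded batch.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

# Illustrative checkpoint; I presume any OpenELM size behaves the same.
model = AutoModelForCausalLM.from_pretrained(
    "apple/OpenELM-270M",
    trust_remote_code=True,          # OpenELM uses custom modeling code
    torch_dtype=torch.float16,
    attn_implementation="eager",     # NaNs appear here too, not only with sdpa
).to(device)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"      # left padding is what triggers the NaNs

batch = tokenizer(
    ["Hi", "A much longer prompt that forces the short one to be padded"],
    return_tensors="pt",
    padding=True,
)
batch = {k: v.to(device) for k, v in batch.items()}

with torch.no_grad():
    logits = model(**batch).logits

print(torch.isnan(logits).any())     # True before the patch, False after
```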

P.S. thanks for your work on OpenELM!

Files changed (1): modeling_openelm.py (+1 −1)
modeling_openelm.py CHANGED
@@ -766,7 +766,7 @@ class OpenELMModel(OpenELMPreTrainedModel):
         )
 
         # We use the current dtype to avoid any overflows
-        min_dtype = torch.finfo(dtype).min
+        min_dtype = torch.finfo(dtype).min / 2
         causal_mask = (
             self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype)
             * min_dtype