Fix For NaN Logits in HuggingFace Distribution of OpenELM
Browse filesI found that left padding of inputs led to NaN logits. The fix (credit to [this thread](https://github.com/huggingface/transformers/issues/32390), is to change the line ```min_dtype = torch.finfo(dtype).min``` to ```min_dtype = torch.finfo(dtype).min / 2``` in the function ```_update_causal_mask```.
I presume all other OpenELM model sizes and variations require the same fix.
Note: the ```if not is_tracing and torch.any(attention_mask != 1):``` condition in the ```_update_causal_mask``` function seems to be addressing the same issue; however, this mitigation only occurs when ```self.config._attn_implementation == "sdpa"```, whereas the issue is present even if ```self.config._attn_implementation == "eager"```.
P.S. thanks for your work on OpenELM!
- modeling_openelm.py +1 -1
@@ -766,7 +766,7 @@ class OpenELMModel(OpenELMPreTrainedModel):
|
|
766 |
)
|
767 |
|
768 |
# We use the current dtype to avoid any overflows
|
769 |
-
min_dtype = torch.finfo(dtype).min
|
770 |
causal_mask = (
|
771 |
self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype)
|
772 |
* min_dtype
|
|
|
766 |
)
|
767 |
|
768 |
# We use the current dtype to avoid any overflows
|
769 |
+
min_dtype = torch.finfo(dtype).min / 2
|
770 |
causal_mask = (
|
771 |
self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype)
|
772 |
* min_dtype
|