Don't download, Google scuttled this model
Directly from Reddit:
[P] How I found 8 bugs in Google's Gemma 6T token model
Hey r/MachineLearning! You might have seen me post on Twitter, but I'll post here too in case you don't know about the 8 bugs in multiple implementations of Google's Gemma :)
The fixes should already be pushed into the main branch of HF's transformers, and Keras, PyTorch Gemma, and vLLM should have gotten the fixes too :) https://github.com/huggingface/transformers/pull/29402
I run an OSS package called Unsloth, which also makes Gemma finetuning 2.5x faster and use 70% less VRAM :)
By comparing 5 implementations, I found the following issues:
Must add the BOS token or else losses will be very high.
There’s a typo for model in the technical report!
sqrt(3072)=55.4256 but bfloat16 is 55.5.
Layernorm (w+1) must be in float32.
Keras mixed_bfloat16 RoPE is wrong.
RoPE is sensitive to y*(1/x) vs y/x.
RoPE should be float32 - already pushed to transformers 4.38.2.
GELU should be approx tanh not exact.
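To make the two bfloat16 items in the list concrete, here is a rough sketch that emulates bfloat16 rounding with the standard library only. The `to_bfloat16` helper is my own illustration (not taken from any of the implementations), it ignores NaN/inf, and reading 3072 as a hidden size and using `w = 0.001` as a sample layernorm weight are assumptions.

```python
import math
import struct


def to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even), returned as a Python float.

    Hypothetical helper for illustration; it does not handle NaN or infinity.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
    bits += 0x7FFF + ((bits >> 16) & 1)                  # round-to-nearest-even bias
    bits &= 0xFFFF0000                                   # keep sign, exponent, top 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


hidden_size = 3072                      # assumed meaning of the 3072 in the list
scale_fp32 = math.sqrt(hidden_size)     # about 55.4256
scale_bf16 = to_bfloat16(scale_fp32)    # 55.5, since bfloat16 steps by 0.25 in [32, 64)

w = 0.001                               # assumed small layernorm weight
gain_bf16 = to_bfloat16(1.0 + w)        # 1.0, the small weight is rounded away
```

Near 1.0 the spacing between bfloat16 values is 2^-7 (about 0.0078), so any weight smaller than roughly half of that vanishes when (w+1) is computed in bfloat16. That seems to be why the (w+1) step needs float32.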
Adding all these changes allows the Log L2 Norm to decrease from the red line to the black line (lower is better). Remember this is a log scale! So the error decreased from 10_000 to 100, a factor of 100! The fixes are primarily for long sequence lengths.
The most glaring one: adding BOS tokens to finetuning runs tames the training loss at the start. Without BOS, losses become very high.
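A minimal sketch of guarding against a missing BOS token before finetuning. The `ensure_bos` helper, the `BOS_ID = 2` value, and the example token ids are all assumptions for illustration; check your tokenizer for the real BOS id.

```python
from typing import List

BOS_ID = 2  # assumed BOS token id; verify against your tokenizer


def ensure_bos(token_ids: List[int], bos_id: int = BOS_ID) -> List[int]:
    """Return a copy of token_ids that starts with BOS, without adding a second one."""
    if token_ids and token_ids[0] == bos_id:
        return list(token_ids)
    return [bos_id] + list(token_ids)


example = ensure_bos([1596, 603, 476, 2121])  # made-up token ids
```

Running every training sequence through a check like this is cheap. It also avoids the opposite mistake of a doubled BOS when the tokenizer has already added one.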
Another very problematic issue was that RoPE embeddings were computed in bfloat16 rather than float32. This ruined very long context lengths, since positions [8190, 8191] got rounded to [8192, 8192] in bfloat16, so distinct positions became indistinguishable. This destroyed finetunes on very long sequence lengths.
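The position collapse can be reproduced without any ML framework. This is a sketch that emulates bfloat16 rounding with the standard library; `to_bfloat16` and `to_float32` are hypothetical helpers of mine, not part of any Gemma implementation.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even); ignores NaN/inf."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
    bits += 0x7FFF + ((bits >> 16) & 1)                  # round-to-nearest-even bias
    bits &= 0xFFFF0000                                   # drop the low 16 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def to_float32(x: float) -> float:
    """Round x to float32 precision."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


positions = [8190, 8191]
as_bf16 = [to_bfloat16(float(p)) for p in positions]  # [8192.0, 8192.0]
as_fp32 = [to_float32(float(p)) for p in positions]   # [8190.0, 8191.0]
```

Between 4096 and 8192, bfloat16 can only represent multiples of 32, so both positions land on 8192 and would get the same rotary angle. float32 represents every integer in this range exactly.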
Another major issue was that nearly all implementations except the JAX-based ones used exact GELU, whilst approximate (tanh) GELU is the correct choice.
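For reference, the two GELU variants can be written out in plain Python. These are the standard textbook formulas, not code lifted from any of the Gemma implementations.

```python
import math


def gelu_exact(x: float) -> float:
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x: float) -> float:
    """Tanh approximation of GELU."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))


# The two variants agree closely but not exactly.
diff_at_one = abs(gelu_exact(1.0) - gelu_tanh(1.0))
```

The gap per activation is small, but it is applied in every MLP layer, so a model trained with one variant can drift when run with the other. In PyTorch the tanh variant is selected with `torch.nn.functional.gelu(x, approximate="tanh")`.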
I also have a Twitter thread on the fixes: https://twitter.com/danielhanchen/status/1765446273661075609, and a full Colab notebook walking through more issues: https://colab.research.google.com/drive/1fxDWAfPIbC-bHwDSVj5SBmEJ6KG3bUu5?usp=sharing Also a longer blog post: https://unsloth.ai/blog/gemma-bugs
I also made Gemma finetuning 2.5x faster with 60% less VRAM, in a Colab notebook: https://colab.research.google.com/drive/10NbwlsRChbma1v55m8LAPYG15uQv6HLo?usp=sharing There's also a $50K Kaggle competition specifically for Gemma: https://www.kaggle.com/competitions/data-assistants-with-gemma :)
Hello! Surya from the Gemma team here -- I've been in touch with Daniel over the last few weeks, and he's done really excellent work in finding these! We've been pushing fixes for all the issues he's found, and hopefully it should be much more stable to finetune Gemma models now. For context, the issues were inconsistencies between the different Jax / PyTorch / Flax / Keras implementations that impacted finetuning performance in particular. Note that most of these don't affect regular inference-only workloads from our finetuned models.
We'll do our best to make sure that all the different implementations and notebooks we put out are consistent and work out of the box -- if you or others find more issues, please let us know! We know things won't always be perfect, but we are very eager to improve the models and engage with folk about how to make them more useful :)
So it's OK to try and use this one from this page now?
Yes, please give it a try! It'll also depend on how you're using them (are you interested in finetuning or just generations, which frameworks, etc.).
@suryabhupa if I simply want to benchmark the model on EleutherAI's LM Evaluation Harness, should the model be good to go?
Yes, I believe so! If you find that there are strange generations, please let us know.
@suryabhupa can you please provide an FSDP example for Nvidia GPUs? Right now there is only one for TPU.
All of our internal dev has been on TPU, so we don't have any examples of PyTorch FSDP, unfortunately :( There might be other forks or GitHub repos where folks try this, or other model derivatives of Gemma; I would suggest trying those.
I saw this: https://huggingface.co/google/gemma-7b/blob/main/examples/example_fsdp.py, but not sure if it was what you're looking for
I can't make it work with llama.cpp. I tried everything, and at best the model talks to itself :(
Sure, there is a Gemma Discord channel: https://ai.google.dev/gemma/docs/discord
I know. I already joined it some time ago.