How to train llama-2-7B-32k from llama-2-7B?
Taking llama-2-7b as an example, I want to know how to train it so that its context can be extended to 32k. I saw that OpenChatKit only contains code for fine-tuning llama-2-7b-32k. If I want to train llama-2-7b into llama-2-7b-32k myself, starting from the base weights, what should I do?
Together published a blog post that describes the process in great detail:
On the modeling side, we follow Meta’s recent paper and use linear interpolation to extend the context length. This provides a powerful way to extend the context length for models with rotary positional embeddings. We take the LLaMA-2 checkpoint, and continue pre-training/fine-tuning it with linear interpolation for 1.5B tokens.
Start by reading the paper: https://arxiv.org/abs/2306.15595
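For intuition, here is a minimal sketch of the linear interpolation idea from the paper: position indices are scaled down so that the extended window maps back into the position range the model was originally trained on. The function name rope_angles and the constants below are illustrative assumptions (4096 is the llama-2 training context, 128 is the 7B head dimension); this is not the OpenChatKit implementation.

```python
import math

def rope_angles(position, dim, base=10000.0, scale=1.0):
    """Rotary embedding angles for a single token position.

    scale < 1 applies linear position interpolation: the position index
    is multiplied by scale before the usual RoPE frequencies are applied.
    (Hypothetical helper, written for illustration only.)
    """
    scaled_position = position * scale
    return [scaled_position / (base ** (2 * i / dim)) for i in range(dim // 2)]

ORIGINAL_CONTEXT = 4096    # context length llama-2 was trained with
EXTENDED_CONTEXT = 32768   # target context length (32k)
SCALE = ORIGINAL_CONTEXT / EXTENDED_CONTEXT  # 0.125

# The last position of the extended window lands inside the original range.
last_angles = rope_angles(EXTENDED_CONTEXT - 1, dim=128, scale=SCALE)
assert last_angles[0] < ORIGINAL_CONTEXT

# The rotation actually applied to a query/key pair uses cos and sin of each angle.
cos_sin = [(math.cos(a), math.sin(a)) for a in last_angles]
```

The model still needs continued training with the scaled positions, as the blog post describes; the scaling alone only keeps the positions within the range seen during pre-training.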
So this part of the code is not open source, right? I'll look into this, thank you~
@Sayoyo the training code is here https://github.com/togethercomputer/OpenChatKit ; and the datasets are here https://huggingface.co/datasets/togethercomputer/Long-Data-Collections
@Sayoyo the paper states:
Models extended via Position Interpolation retain its original architecture and can reuse most pre-existing optimization and infrastructure.
While the Together blog post states:
continue pre-training/fine-tuning it
You would have to try it out, but this would suggest that starting from llama-2 base weights can be done using the code in the GitHub repo. The important part seems to be using the RotaryEmbedding defined in training/modules/llama_modules.py during training. You can see that this gets picked up in the GPTStageBase class and ultimately used in training in the async implementation of Gpipe via get_pp_module().
More specifically, to go from llama-2 base you could try to pass the weights into the prepare.py script:
python pretrained/Llama-2-7B-32K-beta/prepare.py --model-name huggyllama/llama-7b # you might need these locally
Then look in training/finetune_llama-2-7b-32k-mqa.sh for ideas on what parameters you want to use while crunching through the long dataset that @zhangce shared, or try your own long dataset!
As stated elsewhere, you can expect the training process to require much more VRAM than for other 7B models.
Please report back with your results.
Hi @zhangce, I'm just wondering whether I can pretrain/fine-tune a model, for example llama-2-13b, to get llama-2-13b-32k using OpenChatKit. If that is possible, can you show me the path? Thanks