winglian committed · Commit 05b398a · unverified · 1 Parent(s): e634118

fix some of the edge cases for Jamba (#1452)

* fix some of the edge cases for Jamba

* update requirements for jamba

.github/workflows/pypi.yml CHANGED
@@ -25,7 +25,7 @@ jobs:
 
       - name: Install dependencies
         run: |
-          pip3 install wheel
+          pip3 install wheel packaging
           pip3 install -e .
           pip3 install -r requirements-tests.txt
 
.github/workflows/tests.yml CHANGED
@@ -48,6 +48,8 @@ jobs:
 
       - name: Install dependencies
         run: |
+          pip3 install --upgrade pip
+          pip3 install --upgrade packaging
           pip3 install -U -e .
           pip3 install -r requirements-tests.txt
 
examples/jamba/README.md CHANGED
@@ -1,5 +1,10 @@
 # Jamba
 
-qlora w/ deepspeed needs at least 2x GPUs and 35GiB VRAM per GPU
-
-qlora single-gpu - training will start, but loss is off by an order of magnitude
+- ✅ qlora w/ deepspeed Zero-2 needs at least 2x GPUs and
+  - 35GiB VRAM per GPU w/ minimal context length
+  - 56GiB VRAM per GPU (w/ multipack enabled)
+- ✅ qlora w/ deepspeed Zero-3 needs at least 2x GPUs and 67GiB VRAM (wtf?)
+- ✅ qlora single-gpu, ~51GiB VRAM
+- ✅ multipack
+- ❓ FSDP
+- ❓ 8-bit LoRA
examples/jamba/qlora.yaml ADDED
@@ -0,0 +1,62 @@
+base_model: ai21labs/Jamba-v0.1
+trust_remote_code: true
+
+load_in_8bit: false
+load_in_4bit: true
+strict: false
+
+datasets:
+  - path: mhenrichsen/alpaca_2k_test
+    type: alpaca
+dataset_prepared_path:
+val_set_size: 0.0
+output_dir: ./out
+
+sequence_len: 4096
+sample_packing: false
+pad_to_sequence_len: false
+eval_sample_packing: false
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+adapter: qlora
+lora_r: 8
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+
+low_cpu_mem_usage: true
+gradient_accumulation_steps: 4
+micro_batch_size: 1
+num_epochs: 2
+optimizer: paged_adamw_8bit
+lr_scheduler: cosine
+learning_rate: 0.00001
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 10
+evals_per_epoch:
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+special_tokens:
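
For orientation, the quantization and LoRA keys in this example config map fairly directly onto `transformers`/`peft` objects. The sketch below is illustrative only, not axolotl's actual loader; the file path, the NF4/bf16 compute-dtype choices, and the `task_type` value are assumptions, not values taken from this commit.

```python
# Illustrative sketch only (not axolotl's loader): show how the qlora.yaml keys
# above roughly translate into peft / transformers configuration objects.
import torch
import yaml
from peft import LoraConfig
from transformers import BitsAndBytesConfig

with open("examples/jamba/qlora.yaml", encoding="utf-8") as f:  # assumed path
    cfg = yaml.safe_load(f)

# load_in_4bit: true + adapter: qlora -> 4-bit NF4 quantization of the base model
bnb_config = BitsAndBytesConfig(
    load_in_4bit=cfg["load_in_4bit"],
    bnb_4bit_quant_type="nf4",              # assumed default for qlora
    bnb_4bit_compute_dtype=torch.bfloat16,  # assumed, to match bf16: auto
)

# lora_r / lora_alpha / lora_dropout -> LoRA adapter hyperparameters
lora_config = LoraConfig(
    r=cfg["lora_r"],
    lora_alpha=cfg["lora_alpha"],
    lora_dropout=cfg["lora_dropout"],
    task_type="CAUSAL_LM",  # assumed task type for causal LM fine-tuning
)

print(bnb_config)
print(lora_config)
```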
requirements.txt CHANGED
@@ -32,7 +32,7 @@ fschat==0.2.36
 gradio==3.50.2
 tensorboard
 
-mamba-ssm==1.1.1
+mamba-ssm==1.2.0.post1
 
 # remote filesystems
 s3fs
setup.py CHANGED
@@ -78,7 +78,7 @@ setup(
             "deepspeed-kernels",
         ],
         "mamba-ssm": [
-            "mamba-ssm==1.0.1",
+            "mamba-ssm==1.2.0.post1",
         ],
         "auto-gptq": [
             "auto-gptq==0.5.1",
src/axolotl/monkeypatch/multipack.py CHANGED
@@ -48,14 +48,16 @@ def patch_for_multipack(model_type, model_name=None):
             get_unpad_data
         )
     elif model_type == "gemmoe":
-        model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
-        # we need to load the model here in order for modeling_gemmoe to be available
-        with init_empty_weights():
-            AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
-        module_name = model_config.__class__.__module__.replace(
-            ".configuration_gemmoe", ".modeling_gemmoe"
-        )
-        modeling_gemmoe = importlib.import_module(module_name)
-        modeling_gemmoe._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
+        patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
+    elif model_type == "jamba":
+        patch_remote(model_name, ".configuration_jamba", ".modeling_jamba")
+
+
+def patch_remote(model_name, config_name, modeling_name):
+    model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+    # we need to load the model here in order for modeling_* to be available
+    with init_empty_weights():
+        AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
+    module_name = model_config.__class__.__module__.replace(config_name, modeling_name)
+    modeling_arch = importlib.import_module(module_name)
+    modeling_arch._get_unpad_data = get_unpad_data  # pylint: disable=protected-access
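
The gemmoe-specific patching is now folded into a reusable `patch_remote` helper, and Jamba reuses it. Below is a hedged sketch of how the entry point might be invoked for a remote-code model; the call is hypothetical driver code, not a call site added by this commit, and it assumes the Jamba repo ships `configuration_jamba.py`/`modeling_jamba.py` as trusted remote code.

```python
# Hypothetical usage sketch (not part of this commit): apply the multipack
# unpad-data patch for a remote-code architecture such as Jamba before the
# model itself is instantiated.
from axolotl.monkeypatch.multipack import patch_for_multipack

patch_for_multipack("jamba", model_name="ai21labs/Jamba-v0.1")
```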
src/axolotl/utils/models.py CHANGED
@@ -456,6 +456,10 @@ def load_model(
             "bnb_4bit_quant_type": "nf4",
             "bnb_4bit_quant_storage": torch.bfloat16,
         }
+        if cfg.model_config_type == "jamba" and not cfg.deepspeed:
+            # for some reason, this causes the loss to be off by an order of magnitude
+            # but deepspeed needs this still in bfloat16
+            bnb_config["bnb_4bit_quant_storage"] = torch.float32
 
         if cfg.bnb_config_kwargs:
             bnb_config.update(cfg.bnb_config_kwargs)
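
Taken on its own, the new branch just picks the 4-bit quant-storage dtype based on whether DeepSpeed is driving training. The sketch below mirrors that decision in isolation; it assumes a transformers release whose `BitsAndBytesConfig` accepts `bnb_4bit_quant_storage`, and the helper name and arguments are made up for illustration and do not exist in axolotl.

```python
# Standalone sketch of the quant-storage decision above; the helper is
# hypothetical and only mirrors the logic of the patched branch in load_model.
import torch
from transformers import BitsAndBytesConfig


def pick_jamba_bnb_config(model_config_type: str, using_deepspeed: bool) -> BitsAndBytesConfig:
    quant_storage = torch.bfloat16
    if model_config_type == "jamba" and not using_deepspeed:
        # per the patch: float32 storage sidesteps the order-of-magnitude loss
        # blow-up seen without deepspeed, while deepspeed still wants bfloat16
        quant_storage = torch.float32
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_quant_storage=quant_storage,
    )


print(pick_jamba_bnb_config("jamba", using_deepspeed=False).bnb_4bit_quant_storage)
```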