Commit cebea37 (unverified), committed by winglian
Parents (2): 2ae936f 1f5d83e

Merge pull request #36 from OpenAccess-AI-Collective/qlora

requirements.txt CHANGED
@@ -1,10 +1,10 @@
 peft @ git+https://github.com/huggingface/peft.git
 transformers @ git+https://github.com/huggingface/transformers.git
+bitsandbytes>=0.39.0
 attrdict
 fire
 PyYAML==6.0
 black
-bitsandbytes==0.37.2
 datasets
 accelerate>=0.19.0
 sentencepiece
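The bitsandbytes pin moves from 0.37.2 to >=0.39.0 because the 4-bit (NF4) kernels that the new qlora path relies on only ship from 0.39.0 onward. A quick, illustrative version guard (not part of this PR) would look like:

```python
# Illustrative only (not part of this PR): fail fast if the installed
# bitsandbytes predates the 4-bit (NF4) support that QLoRA needs.
from packaging import version

import bitsandbytes as bnb

if version.parse(bnb.__version__) < version.parse("0.39.0"):
    raise RuntimeError(
        f"bitsandbytes {bnb.__version__} is too old; QLoRA needs >= 0.39.0"
    )
```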
scripts/finetune.py CHANGED
@@ -14,6 +14,7 @@ from attrdict import AttrDefault
 
 # add src to the pythonpath so we don't need to pip install this
 from axolotl.utils.tokenization import check_dataset_labels
+from axolotl.utils.validation import validate_config
 
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 src_dir = os.path.join(project_root, "src")
@@ -158,6 +159,8 @@ def train(
     cfg.fp16 = True
     cfg.bf16 = False
 
+    validate_config(cfg)
+
     # Load the model and tokenizer
     logging.info("loading model, tokenizer, and peft_config...")
     model, tokenizer, peft_config = load_model(
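The new `validate_config(cfg)` hook comes from `axolotl.utils.validation`, a module whose contents are not part of this diff. As a purely hypothetical sketch of the kind of early check such a hook can run before any model weights are downloaded:

```python
# Hypothetical sketch only -- the real axolotl.utils.validation module is not
# shown in this diff. Illustrates the kind of early sanity check the new
# validate_config(cfg) call can run before load_model does any work.
def validate_config(cfg):
    # GPTQ-style 4-bit loading (load_4bit) and the bitsandbytes-based qlora
    # adapter are different quantization backends; asking for both is a mistake.
    if cfg.load_4bit and cfg.adapter == "qlora":
        raise ValueError("load_4bit and adapter: qlora cannot be combined")
    # 8-bit and 4-bit bitsandbytes loading are mutually exclusive in transformers.
    if cfg.load_in_8bit and cfg.load_in_4bit:
        raise ValueError("load_in_8bit and load_in_4bit cannot both be true")
```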
src/axolotl/prompters.py CHANGED
@@ -11,6 +11,7 @@ class PromptStyle(Enum):
     instruct = "instruct"
     chat = "chat"
 
+
 class AlpacaPrompter:
     system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
     system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
@@ -50,6 +51,10 @@ class AlpacaPrompter:
         return output.split(self.response_split)[1].strip()
 
 
+class UnpromptedPrompter(AlpacaPrompter):
+    system_prompt = ""
+    system_no_input_prompt = ""
+
 class JeopardyPrompter(AlpacaPrompter):
     prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
 
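`UnpromptedPrompter` gets its behaviour purely by blanking the two class-level system prompt strings it inherits from `AlpacaPrompter`. A standalone mock (not the axolotl implementation) of why that is sufficient:

```python
# Standalone mock, not the axolotl implementation: subclasses only override the
# class-level prompt strings, so the parent's formatting logic is reused as-is.
class Prompter:
    system_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"

    def build_prompt(self, instruction: str) -> str:
        return f"{self.system_prompt}### Instruction:\n{instruction}\n\n### Response:\n"


class Unprompted(Prompter):
    system_prompt = ""  # same trick as UnpromptedPrompter above


print(Unprompted().build_prompt("Summarize the article."))
# ### Instruction:
# Summarize the article.
#
# ### Response:
```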
src/axolotl/utils/data.py CHANGED
@@ -98,6 +98,11 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path):
                 ds = load_dataset("json", data_files=fp, streaming=False, split=None)
             if not ds:
                 raise Exception("unhandled dataset load")
+            # support for using a subset of the data
+            if d.shards:
+                ds = ds.shuffle(seed=42)["train"].shard(
+                    num_shards=cfg.shards, index=0
+                )
             d_type = d.type
             d_type_split = d_type.split(":")
             d_base_type = d_type_split[0]
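The new `d.shards` branch keeps only a deterministic fraction of a dataset. Outside of axolotl, the same `datasets` calls look like this (the data file name is a placeholder):

```python
# Standalone illustration of the subset logic above, using the `datasets` library
# directly. "data.jsonl" is a placeholder path.
from datasets import load_dataset

ds = load_dataset("json", data_files="data.jsonl", split=None)  # returns a DatasetDict
# Shuffle with a fixed seed, pick the train split, then keep shard 0 of 10 (~10% of rows).
subset = ds.shuffle(seed=42)["train"].shard(num_shards=10, index=0)
print(len(subset))
```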
src/axolotl/utils/models.py CHANGED
@@ -6,11 +6,12 @@ from typing import Optional, Tuple, TYPE_CHECKING
 
 import torch
 import transformers
+from torch import nn
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     PreTrainedModel,
-    AutoConfig,
+    AutoConfig, BitsAndBytesConfig,
 )
 
 try:
@@ -81,6 +82,16 @@ def load_model(
         logging.exception(e)
         raise e
 
+    model_kwargs = {}
+    if cfg.adapter == "qlora":
+        model_kwargs["quantization_config"] = BitsAndBytesConfig(
+            load_in_4bit=True,
+            llm_int8_threshold=6.0,
+            llm_int8_has_fp16_weight=False,
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type="nf4",
+        )
     try:
         if cfg.load_4bit and is_llama_derived_model:
             from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
@@ -123,8 +134,10 @@ def load_model(
             model = LlamaForCausalLM.from_pretrained(
                 base_model,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
+                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
+                **model_kwargs,
             )
         # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
         # This is a WIP, still an issue with the backward pass
@@ -156,9 +169,11 @@ def load_model(
             model = getattr(transformers, model_type).from_pretrained(
                 base_model,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
+                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
                 trust_remote_code=True if cfg.trust_remote_code is True else False,
+                **model_kwargs,
             )
         else:
             config = AutoConfig.from_pretrained(
@@ -169,9 +184,11 @@ def load_model(
                 base_model,
                 config=config,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
+                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
                 trust_remote_code=True if cfg.trust_remote_code is True else False,
+                **model_kwargs,
             )
     except Exception as e:
         logging.error(
@@ -184,6 +201,7 @@ def load_model(
             torch_dtype=torch_dtype,
             device_map=cfg.device_map,
             trust_remote_code=True if cfg.trust_remote_code is True else False,
+            **model_kwargs,
         )
 
     if not tokenizer:
@@ -225,7 +243,7 @@ def load_model(
     embeddings_len = math.ceil(len(tokenizer) / 32) * 32
     model.resize_token_embeddings(embeddings_len)
 
-    if cfg.adapter and load_in_8bit and not cfg.load_4bit:
+    if ((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora") and not cfg.load_4bit:
         logging.info("converting PEFT model w/ prepare_model_for_int8_training")
         model = prepare_model_for_int8_training(model)
 
@@ -270,7 +288,7 @@ def load_adapter(model, cfg, adapter):
 
     if adapter is None:
         return model, None
-    if adapter == "lora":
+    if adapter == "lora" or adapter == "qlora":
        return load_lora(model, cfg)
    if adapter == "llama-adapter":
        return load_llama_adapter(model, cfg)
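For reference, the `quantization_config` that the qlora branch above builds maps onto plain transformers usage as follows (a minimal sketch; the model id is a placeholder, not something this PR references):

```python
# Minimal sketch of the same 4-bit NF4 setup outside axolotl.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_compute_dtype=torch.float16,  # dtype used for matmuls on dequantized weights
    bnb_4bit_use_double_quant=True,        # also quantize the quantization constants
    bnb_4bit_quant_type="nf4",             # NormalFloat4, the data type introduced by QLoRA
)

model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",                 # placeholder model id
    quantization_config=bnb_config,
    device_map="auto",
)
```

As the last two hunks show, a qlora run then reuses the existing LoRA plumbing: the 4-bit model still passes through prepare_model_for_int8_training, and load_adapter routes "qlora" to the same load_lora path as "lora".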