ValueError: Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32

#19
by Srinivas777 - opened

I am trying to train a 4-bit quantized model. Even though I set `bnb_4bit_compute_dtype=torch.bfloat16` and `bnb_4bit_quant_storage=torch.bfloat16` for uniformity, I was getting:

```
AlgorithmError: ExecuteUserScriptError: ExitCode 1
ErrorMessage "raise ValueError(
ValueError: Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32
Traceback (most recent call last)
  File "/opt/ml/code/run_fsdp_qlora.py", line 262, in
    training_function(script_args, training_args)
  File "/opt/ml/code/run_fsdp_qlora.py", line 210, in training_function
    trainer.train(resume_from_checkpoint=checkpoint)
  File "/opt/conda/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 361, in train
    output = super().train(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 2052, in train
    return inner_training_loop(
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 2194, in _inner_training_loop
    self.model = self.accelerator.prepare(self.model)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/accelerator.py", line 1274, in prepare
    result = tuple(
  File "/opt/conda/lib/python3.10/site-packages/acce
```

`run_fsdp_qlora.py`:

```python
import logging
from dataclasses import dataclass, field
import os
import random
import torch
from datasets import load_dataset
import datetime
from transformers import AutoTokenizer, TrainingArguments
from trl.commands.cli_utils import TrlParser
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    set_seed,
)
from peft import LoraConfig
from datasets import load_from_disk

from liger_kernel.transformers import apply_liger_kernel_to_llama

from trl import SFTTrainer

# Comment in if you want to use the Llama 3 instruct template but make sure to add modules_to_save
# LLAMA_3_CHAT_TEMPLATE="{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"

# Anthropic/Vicuna like template without the need for special tokens
LLAMA_3_CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{% if message['role'] == 'system' %}"
    "{{ message['content'] }}"
    "{% elif message['role'] == 'user' %}"
    "{{ '\n\nHuman: ' + message['content'] + eos_token }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ '\n\nAssistant: ' + message['content'] + eos_token }}"
    "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "{{ '\n\nAssistant: ' }}"
    "{% endif %}"
)

# ACCELERATE_USE_FSDP=1 FSDP_CPU_RAM_EFFICIENT_LOADING=1 torchrun --nproc_per_node=4 ./scripts/run_fsdp_qlora.py --config llama_3_70b_fsdp_qlora.yaml


@dataclass
class ScriptArguments:
    train_dataset_path: str = field(
        default=None,
        metadata={"help": "Path to the dataset, e.g. /opt/ml/input/data/train/"},
    )
    test_dataset_path: str = field(
        default=None,
        metadata={"help": "Path to the dataset, e.g. /opt/ml/input/data/test/"},
    )
    model_id: str = field(
        default=None, metadata={"help": "Model ID to use for SFT training"}
    )
    max_seq_length: int = field(
        default=512, metadata={"help": "The maximum sequence length for SFT Trainer"}
    )


def merge_and_save_model(model_id, adapter_dir, output_dir):
    from peft import PeftModel

    print("Trying to load a Peft model. It might take a while without feedback")
    base_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
    )
    peft_model = PeftModel.from_pretrained(base_model, adapter_dir)
    model = peft_model.merge_and_unload()

    os.makedirs(output_dir, exist_ok=True)
    print(f"Saving the newly created merged model to {output_dir}")
    model.save_pretrained(output_dir, safe_serialization=True)
    base_model.config.save_pretrained(output_dir)


def training_function(script_args, training_args):
    ################
    # Dataset
    ################

    # train_dataset = load_dataset(
    #     "json",
    #     data_files=os.path.join(script_args.train_dataset_path, "dataset.json"),
    #     split="train",
    # )
    # test_dataset = load_dataset(
    #     "json",
    #     data_files=os.path.join(script_args.test_dataset_path, "dataset.json"),
    #     split="train",
    # )

    # Load dataset
    try:
        print(f"Loading dataset from {script_args.train_dataset_path}")
        full_dataset = load_from_disk(script_args.train_dataset_path)
    except Exception as e:
        print(f"Failed to load dataset: {str(e)}")
        raise

    # Split the dataset
    # split_dataset = full_dataset.train_test_split(test_size=0.0001, seed=42)
    # train_dataset = split_dataset['train']
    # test_dataset = split_dataset['test']

    ################
    # Model & Tokenizer
    ################

    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(script_args.model_id, use_fast=True)
    tokenizer.pad_token = tokenizer.eos_token
    # tokenizer.chat_template = LLAMA_3_CHAT_TEMPLATE

    # # template dataset
    # def template_dataset(examples):
    #     return {
    #         "text": tokenizer.apply_chat_template(examples["messages"], tokenize=False)
    #     }

    # train_dataset = train_dataset.map(template_dataset, remove_columns=["messages"])
    # test_dataset = test_dataset.map(template_dataset, remove_columns=["messages"])

    # print random sample on rank 0
    # if training_args.distributed_state.is_main_process:
    #     for index in random.sample(range(len(train_dataset)), 2):
    #         print(train_dataset[index]["text"])
    training_args.distributed_state.wait_for_everyone()  # wait for all processes to print

    # Model
    torch_dtype = torch.bfloat16
    quant_storage_dtype = torch.bfloat16

    # Model - GPT  ##########

    torch_dtype = torch.float32

    quant_storage_dtype = torch.float16

    # Model - GPT  ##########

    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,  # Set to bfloat16 for uniformity
        bnb_4bit_quant_storage=torch.bfloat16,  # Ensure uniformity
    )

    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_id,
        quantization_config=quantization_config,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,  # Ensure this is bfloat16
        use_cache=(
            False if training_args.gradient_checkpointing else True
        ),  # this is needed for gradient checkpointing
    )  # .to('cuda')

    # Convert all model parameters to bfloat16
    model = model.to(torch.bfloat16)

    if training_args.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    ################
    # PEFT
    ################

    # LoRA config based on QLoRA paper & Sebastian Raschka experiment
    peft_config = LoraConfig(
        lora_alpha=16,  # 8,
        lora_dropout=0,
        r=32,  # 32,
        bias="none",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", "embed_tokens", "lm_head"],
        task_type="CAUSAL_LM",
        use_rslora=True,
        modules_to_save=["lm_head", "embed_tokens"],  # add if you want to use the Llama 3 instruct template
    )

    apply_liger_kernel_to_llama()

    ################
    # Training
    ################
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=full_dataset,
        dataset_text_field="text",
        # eval_dataset=test_dataset,
        peft_config=peft_config,
        max_seq_length=script_args.max_seq_length,
        tokenizer=tokenizer,
        packing=True,
        dataset_kwargs={
            "add_special_tokens": False,  # We template with special tokens
            "append_concat_token": False,  # No need to add additional separator token
        },
    )
    if trainer.accelerator.is_main_process:
        trainer.model.print_trainable_parameters()
    if trainer.is_fsdp_enabled:
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")

    ##########################
    # Train model
    ##########################
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    trainer.train(resume_from_checkpoint=checkpoint)

    #########################################
    # SAVE ADAPTER AND CONFIG FOR SAGEMAKER
    #########################################
    # save adapter
    if trainer.is_fsdp_enabled:
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
    trainer.save_model()

    del model
    del trainer
    torch.cuda.empty_cache()  # Clears the cache
    # load and merge
    if training_args.distributed_state.is_main_process:
        merge_and_save_model(
            script_args.model_id, training_args.output_dir, "/opt/ml/model"
        )
        tokenizer.save_pretrained("/opt/ml/model")
    training_args.distributed_state.wait_for_everyone()  # wait for all processes to print


if __name__ == "__main__":
    # setting few env variables as suggested by aws team
    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
    os.environ["NCCL_TIMEOUT"] = "3600"

    # code to print all env variables to check changes reflected
    # Get all environment variables
    env_vars = os.environ

    # Print each environment variable
    for key, value in env_vars.items():
        print(f"{key}: {value}")
    # ========================= #

    parser = TrlParser((ScriptArguments, TrainingArguments))
    script_args, training_args = parser.parse_args_and_config()

    # set use reentrant to False
    if training_args.gradient_checkpointing:
        training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}  # True

    # set seed
    set_seed(training_args.seed)

    # =================== #
    # print statement to check all the training arguments parsed
    print(training_args)
    # =================== #

    # launch training
    training_function(script_args, training_args)
```

What could be the issue?

Can you edit your message to make it easier to read?

You can use triple backticks with the word `python` at the beginning and triple backticks at the end to enclose a Python code block.

```python
like_this = list(range(10))
```

You could also consider using the `ModelConfig` class, which will make your code cleaner and may fix the error. It is used in this example.
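In case it helps while you reformat: this particular `ValueError` comes from FSDP trying to flatten parameters of different dtypes into one bucket, so the usual approach is to keep every dtype knob identical (`torch_dtype`, `bnb_4bit_compute_dtype`, `bnb_4bit_quant_storage`) and then check that nothing downstream reintroduces float32 parameters. Below is a minimal, hedged sketch of that idea, not a drop-in fix for your SageMaker script; the model id is a placeholder and it assumes a working bitsandbytes + flash-attention setup.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# One dtype everywhere so FSDP can flatten uniform-dtype buckets.
compute_dtype = torch.bfloat16

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_quant_storage=compute_dtype,  # should match torch_dtype below for FSDP
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-70B",  # placeholder model id
    quantization_config=bnb_config,
    torch_dtype=compute_dtype,      # keep non-quantized parameters in the same dtype
    attn_implementation="flash_attention_2",
)

# After PEFT / SFTTrainer wrapping, list anything still in float32 --
# those are the tensors FSDP refuses to flatten together with bfloat16 ones.
for name, param in model.named_parameters():
    if param.dtype == torch.float32:
        print(name, param.dtype)
```

Stray float32 parameters often come from `modules_to_save` copies or from helpers that upcast norms and embeddings (e.g. `prepare_model_for_kbit_training`), so it is worth running a check like the loop above on the exact model object that is handed to the trainer. If you switch to `ModelConfig`, the same principle applies: whatever dtype it feeds into the quantization config should match the model's `torch_dtype`.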
