from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset, concatenate_datasets
from omegaconf import DictConfig, OmegaConf
import hydra
import wandb
import shutil
import os
from functools import partial
from pathlib import Path
from trl import (
    SFTTrainer,
    ModelConfig,
    get_quantization_config,
    get_kbit_device_map,
    get_peft_config,
    DataCollatorForCompletionOnlyLM,
)
from dotenv import load_dotenv
from peft import (
    get_peft_model,
    prepare_model_for_kbit_training,
    AutoPeftModelForSequenceClassification,
)

# from utils import add_metric_to_card

# Load local secrets and settings (HF tokens, HF_HUB_CACHE, WANDB_RUN_GROUP, ...)
# from the repo-level .env file.
loaded = load_dotenv("../.env", override=True)
if not loaded:
    raise ValueError("Failed to load .env file")
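

# Convert one dataset row into chat-formatted token ids. The full conversation
# (user prompt + assistant response) is tokenized here; masking the prompt so that
# loss is only computed on the response happens later in the completion-only data
# collator. Assumes the dataset has "text" (prompt) and "response" columns, as used
# below.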
def tokenize(example, tokenizer):
    ids = tokenizer.apply_chat_template([
        {"role": "user", "content": example["text"]},
        {"role": "assistant", "content": example["response"]},
    ])
    return {
        "input_ids": ids,
    }
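

# Illustrative sketch of the Hydra config this script expects (conf/q7b-4bit.yaml).
# The keys below are inferred from the attribute accesses in this file, not copied
# from the real config, which may contain more:
#
#   DEBUG: false
#   debug_model: ...              # small model substituted when DEBUG is true
#   eval_only: false
#   merge_adapters: false
#   wandb_id: null                # set when resuming or running eval-only
#   time_start: null              # overwritten at runtime from the Hydra run dir
#   pad_token: ...                # string to use as the tokenizer pad token
#   response_template_ids: ...    # token ids marking the start of the assistant turn
#   num_proc: ...
#   train_split_name: train
#   val_split_name: ...
#   fold: ...
#   hub_repo_tags: [...]
#   script_args:
#     dataset_name: ...
#     config: ...
#   training_args: {...}          # forwarded to transformers.TrainingArguments
#   model_config: {...}           # forwarded to trl.ModelConfig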
@hydra.main(config_path="conf", config_name="q7b-4bit")
def main(cfg: DictConfig):
    # Record when/where this run started; with Hydra's default outputs/<date>/<time>
    # run directory, this becomes "<date>_<time>".
    cfg.time_start = "_".join(str(Path.cwd()).rsplit("/", 2)[-2:])

    if cfg.DEBUG:
        cfg.model_config.model_name_or_path = cfg.debug_model

    script_args = cfg.script_args
    training_args = TrainingArguments(**OmegaConf.to_container(cfg.training_args))
    model_config = ModelConfig(**OmegaConf.to_container(cfg.model_config))

    if training_args.process_index == 0:
        # Reuse the existing wandb run when resuming or only evaluating; otherwise
        # start a fresh run logged with the full config.
        if cfg.eval_only or training_args.resume_from_checkpoint is not None:
            wandb_id = cfg.wandb_id
            resume = "must"
            config = None
        else:
            wandb_id = None
            resume = None
            config = OmegaConf.to_container(cfg)

        wandb.init(config=config, id=wandb_id, resume=resume)

        # copy current file to output, so it gets saved to hub
        shutil.copy(
            Path(__file__).resolve(),
            Path(training_args.output_dir) / Path(__file__).name,
        )
        shutil.copy(
            Path(__file__).resolve().parent / "utils.py",
            Path(training_args.output_dir) / "utils.py",
        )
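
    # get_quantization_config builds a bitsandbytes config from the ModelConfig flags
    # (load_in_4bit / load_in_8bit) or returns None, and get_kbit_device_map pins the
    # quantized model to the local device so each rank loads its own copy.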
    quantization_config = get_quantization_config(model_config)
    model_kwargs = dict(
        revision=model_config.model_revision,
        trust_remote_code=model_config.trust_remote_code,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=model_config.torch_dtype,
        # the KV cache is unused (and warned about) when gradient checkpointing is on
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
        cache_dir=os.environ["HF_HUB_CACHE"],
    )
    peft_config = get_peft_config(model_config)
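
    # Liger kernels patch the corresponding Transformers model classes in place with
    # fused Triton ops; all three families are patched here, and which one is actually
    # used depends on the model chosen in the config.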
    if training_args.use_liger_kernel:
        from liger_kernel.transformers import (
            apply_liger_kernel_to_qwen2,
            apply_liger_kernel_to_llama,
            apply_liger_kernel_to_mistral,
        )

        apply_liger_kernel_to_qwen2()
        apply_liger_kernel_to_llama()
        apply_liger_kernel_to_mistral()
    if cfg.eval_only:
        # Eval-only: load a previously trained PEFT adapter, optionally merging it
        # into the base weights.
        model = AutoPeftModelForSequenceClassification.from_pretrained(
            model_config.model_name_or_path,
            **model_kwargs,
            token=os.environ["HF_WRITE_PERSONAL"],
        )
        if cfg.merge_adapters:
            model = model.merge_and_unload()
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_config.model_name_or_path,
            **model_kwargs,
            token=os.environ["HF_GATED"],
        )
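
    # The pad token comes from the Hydra config (some chat models ship without one);
    # left padding is presumably chosen because it is the usual setup for batched
    # causal-LM generation and evaluation.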
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path,
        use_fast=True,
        token=os.environ["HF_GATED"],
    )
    tokenizer.padding_side = "left"
    tokenizer.pad_token = cfg.pad_token
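
    # prepare_model_for_kbit_training upcasts norm layers, enables input gradients,
    # and (optionally) turns on gradient checkpointing so LoRA training on top of
    # quantized weights works correctly; for non-quantized runs only the input
    # gradients are needed when checkpointing is enabled.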
    if not cfg.eval_only and model_config.load_in_4bit:
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=training_args.gradient_checkpointing,
            gradient_checkpointing_kwargs=training_args.gradient_checkpointing_kwargs,
        )
    elif not cfg.eval_only and training_args.gradient_checkpointing:
        model.enable_input_require_grads()

    if not cfg.eval_only:
        model = get_peft_model(model, peft_config)
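
    # main_process_first() lets rank 0 download and tokenize the dataset first so the
    # remaining ranks read from the cache instead of repeating the work.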
    with training_args.main_process_first():
        ds = load_dataset(
            script_args.dataset_name,
            script_args.config,
            token=os.environ["HF_WRITE_PERSONAL"],
        )

        # hack to downsample english squad
        # ds["train"] = concatenate_datasets(
        #     [
        #         ds["train"].select(range(0, 45000)),
        #         ds["train"].select(range(98596, len(ds["train"]))),
        #     ])

        if cfg.DEBUG:
            ds[cfg.train_split_name] = (
                ds[cfg.train_split_name].shuffle().select(range(100))
            )
            # ds[cfg.val_split_name] = ds[cfg.val_split_name].shuffle().select(range(100))

        # if not cfg.eval_only:
        #     ds[cfg.val_split_name] = ds[cfg.val_split_name].shuffle().select(range(500))

        ds = ds.map(
            tokenize,
            fn_kwargs={"tokenizer": tokenizer},
            num_proc=cfg.num_proc,
            remove_columns=ds["train"].column_names,
        )
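
    # Completion-only collation: labels for everything up to and including the
    # response template are set to -100, so loss is only computed on the assistant
    # response. response_template_ids should be the token-id sequence that marks the
    # start of the assistant turn for this tokenizer (defined in the Hydra config).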
    collator = DataCollatorForCompletionOnlyLM(
        tokenizer=tokenizer,
        mlm=False,
        pad_to_multiple_of=16,
        response_template=cfg.response_template_ids,
    )

    if training_args.process_index == 0:
        group = os.environ["WANDB_RUN_GROUP"]
        training_args.hub_model_id = f"nbroad/nbroad-odesia-{group}-{wandb.run.id}"
        training_args.hub_token = os.environ["HF_WRITE_PERSONAL"]

    # Metric-name prefix reflecting the precision/quantization of the evaluated model
    # (only used by the commented-out compute_metrics hook below).
    prefix = ""
    if cfg.eval_only:
        if "awq" in model_config.model_name_or_path.lower():
            prefix = "awq_"
        if model_config.load_in_4bit:
            prefix += "int4_"
        elif model_config.torch_dtype == "bfloat16":
            prefix += "bf16_"
        elif model_config.torch_dtype == "float16":
            prefix += "fp16_"
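
    # The dataset already contains input_ids, so no SFT formatting function is passed;
    # padding and label masking are handled entirely by the collator above.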
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=ds["train"],
        eval_dataset=(
            ds[cfg.val_split_name] if training_args.eval_strategy != "no" else None
        ),
        processing_class=tokenizer,
        data_collator=collator,
        # compute_metrics=partial(compute_metrics, prefix=prefix),
    )

    if training_args.process_index == 0:
        # Stash run metadata in the model config so it is saved with every checkpoint.
        trainer.model.config.update(
            {
                "wandb_id": wandb.run.id,
                "fold": cfg.fold,
                "group": group,
                "dataset": script_args.dataset_name,
            }
        )
    if not cfg.eval_only:
        if training_args.resume_from_checkpoint is not None:
            # When resuming, switch back into the original run directory so relative
            # paths line up (Hydra creates a fresh working directory per launch).
            os.chdir(Path(training_args.resume_from_checkpoint).parent)
        trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
    else:
        metrics = trainer.evaluate()
        # if training_args.process_index == 0:
        #     met = [x for x in metrics if "accuracy" in x][0]
        #     result = add_metric_to_card(
        #         repo=training_args.hub_model_id,
        #         metrics_pretty_name=met,
        #         metrics_value=metrics[met],
        #         dataset_id=script_args.dataset_name,
        #         dataset_split=cfg.val_split_name,
        #         model_path=model_config.model_name_or_path,
        #         model_dtype=model_config.torch_dtype,
        #         token=os.environ["HF_WRITE_PERSONAL"],
        #     )
        #     print(result)

    if not cfg.eval_only:
        # Save and push to hub
        trainer.save_model(training_args.output_dir)
        if training_args.push_to_hub:
            trainer.push_to_hub(
                dataset_name=script_args.dataset_name,
                model_name=model_config.model_name_or_path,
                tags=cfg.hub_repo_tags,
            )

    if training_args.process_index == 0:
        wandb.finish()


if __name__ == "__main__":
    main()