from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset, concatenate_datasets
from omegaconf import DictConfig, OmegaConf
import hydra
import wandb
import shutil
import os
from functools import partial
from pathlib import Path
from trl import (
    SFTTrainer,
    ModelConfig,
    get_quantization_config,
    get_kbit_device_map,
    get_peft_config,
    DataCollatorForCompletionOnlyLM,
)
from dotenv import load_dotenv
from peft import (
    get_peft_model,
    prepare_model_for_kbit_training,
    AutoPeftModelForSequenceClassification,
)

# from utils import add_metric_to_card

loaded = load_dotenv("../.env", override=True)
if not loaded:
    raise ValueError("Failed to load .env file")


def tokenize(example, tokenizer):
    # Build chat-formatted input ids from a (text, response) pair.
    ids = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": example["text"]},
            {"role": "assistant", "content": example["response"]},
        ]
    )
    return {
        "input_ids": ids,
    }


@hydra.main(config_path="conf", config_name="q7b-4bit")
def main(cfg: DictConfig):
    cfg.time_start = "_".join(str(Path.cwd()).rsplit("/", 2)[-2:])

    if cfg.DEBUG:
        cfg.model_config.model_name_or_path = cfg.debug_model

    script_args = cfg.script_args
    training_args = TrainingArguments(**OmegaConf.to_container(cfg.training_args))
    model_config = ModelConfig(**OmegaConf.to_container(cfg.model_config))

    if training_args.process_index == 0:
        if cfg.eval_only or training_args.resume_from_checkpoint is not None:
            wandb_id = cfg.wandb_id
            resume = "must"
            config = None
        else:
            wandb_id = None
            resume = None
            config = OmegaConf.to_container(cfg)

        wandb.init(config=config, id=wandb_id, resume=resume)

        # copy current file to output, so it gets saved to hub
        shutil.copy(
            Path(__file__).resolve(),
            Path(training_args.output_dir) / Path(__file__).name,
        )
        shutil.copy(
            Path(__file__).resolve().parent / "utils.py",
            Path(training_args.output_dir) / "utils.py",
        )

    quantization_config = get_quantization_config(model_config)
    model_kwargs = dict(
        revision=model_config.model_revision,
        trust_remote_code=model_config.trust_remote_code,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=model_config.torch_dtype,
        use_cache=not training_args.gradient_checkpointing,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
        cache_dir=os.environ["HF_HUB_CACHE"],
    )
    peft_config = get_peft_config(model_config)

    if training_args.use_liger_kernel:
        # Patch supported architectures with Liger kernels for faster, lower-memory training.
        from liger_kernel.transformers import (
            apply_liger_kernel_to_qwen2,
            apply_liger_kernel_to_llama,
            apply_liger_kernel_to_mistral,
        )

        apply_liger_kernel_to_qwen2()
        apply_liger_kernel_to_llama()
        apply_liger_kernel_to_mistral()

    if cfg.eval_only:
        model = AutoPeftModelForSequenceClassification.from_pretrained(
            model_config.model_name_or_path,
            **model_kwargs,
            token=os.environ["HF_WRITE_PERSONAL"],
        )
        if cfg.merge_adapters:
            model = model.merge_and_unload()
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_config.model_name_or_path,
            **model_kwargs,
            token=os.environ["HF_GATED"],
        )

    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path,
        use_fast=True,
        token=os.environ["HF_GATED"],
    )
    tokenizer.padding_side = "left"
    tokenizer.pad_token = cfg.pad_token

    if not cfg.eval_only and model_config.load_in_4bit:
        # Prepare the quantized model for k-bit (QLoRA) training.
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=training_args.gradient_checkpointing,
            gradient_checkpointing_kwargs=training_args.gradient_checkpointing_kwargs,
        )
    elif not cfg.eval_only and training_args.gradient_checkpointing:
        model.enable_input_require_grads()
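    # Attach LoRA adapters for training (QLoRA when the base model is loaded in 4-bit);
    # in eval-only runs the adapter weights are already loaded by AutoPeftModel* above.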
    if not cfg.eval_only:
        model = get_peft_model(model, peft_config)

    with training_args.main_process_first():
        ds = load_dataset(
            script_args.dataset_name,
            script_args.config,
            token=os.environ["HF_WRITE_PERSONAL"],
        )

        # hack to downsample english squad
        # ds["train"] = concatenate_datasets(
        #     [
        #         ds["train"].select(range(0, 45000)),
        #         ds["train"].select(range(98596, len(ds["train"]))),
        #     ]
        # )

        if cfg.DEBUG:
            ds[cfg.train_split_name] = (
                ds[cfg.train_split_name].shuffle().select(range(100))
            )
            # ds[cfg.val_split_name] = ds[cfg.val_split_name].shuffle().select(range(100))

        # if not cfg.eval_only:
        #     ds[cfg.val_split_name] = ds[cfg.val_split_name].shuffle().select(range(500))

        ds = ds.map(
            tokenize,
            fn_kwargs={"tokenizer": tokenizer},
            num_proc=cfg.num_proc,
            remove_columns=ds["train"].column_names,
        )

    # Mask everything before the response template so loss is computed on completions only.
    collator = DataCollatorForCompletionOnlyLM(
        tokenizer=tokenizer,
        mlm=False,
        pad_to_multiple_of=16,
        response_template=cfg.response_template_ids,
    )

    if training_args.process_index == 0:
        group = os.environ["WANDB_RUN_GROUP"]
        training_args.hub_model_id = f"nbroad/nbroad-odesia-{group}-{wandb.run.id}"
        training_args.hub_token = os.environ["HF_WRITE_PERSONAL"]

    # Metric prefix recording the precision/quantization used for eval-only runs.
    prefix = ""
    if cfg.eval_only:
        if "awq" in model_config.model_name_or_path.lower():
            prefix = "awq_"
        if model_config.load_in_4bit:
            prefix += "int4_"
        elif model_config.torch_dtype == "bfloat16":
            prefix += "bf16_"
        elif model_config.torch_dtype == "float16":
            prefix += "fp16_"

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=ds["train"],
        eval_dataset=(
            ds[cfg.val_split_name] if training_args.eval_strategy != "no" else None
        ),
        processing_class=tokenizer,
        data_collator=collator,
        # compute_metrics=partial(compute_metrics, prefix=prefix),
    )

    if training_args.process_index == 0:
        trainer.model.config.update(
            {
                "wandb_id": wandb.run.id,
                "fold": cfg.fold,
                "group": group,
                "dataset": script_args.dataset_name,
            }
        )

    if not cfg.eval_only:
        if training_args.resume_from_checkpoint is not None:
            os.chdir(Path(training_args.resume_from_checkpoint).parent)
        trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
    else:
        metrics = trainer.evaluate()

        # if training_args.process_index == 0:
        #     met = [x for x in metrics if "accuracy" in x][0]
        #     result = add_metric_to_card(
        #         repo=training_args.hub_model_id,
        #         metrics_pretty_name=met,
        #         metrics_value=metrics[met],
        #         dataset_id=script_args.dataset_name,
        #         dataset_split=cfg.val_split_name,
        #         model_path=model_config.model_name_or_path,
        #         model_dtype=model_config.torch_dtype,
        #         token=os.environ["HF_WRITE_PERSONAL"],
        #     )
        #     print(result)

    if not cfg.eval_only:
        # Save and push to hub
        trainer.save_model(training_args.output_dir)
        if training_args.push_to_hub:
            trainer.push_to_hub(
                dataset_name=script_args.dataset_name,
                model_name=model_config.model_name_or_path,
                tags=cfg.hub_repo_tags,
            )

    if training_args.process_index == 0:
        wandb.finish()


if __name__ == "__main__":
    main()
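
# Example launches (a sketch: the filename "sft.py" and the override values shown are
# assumptions; only the conf/q7b-4bit Hydra config and the cfg keys come from this script).
# A ../.env file providing HF_HUB_CACHE, HF_GATED, HF_WRITE_PERSONAL, and WANDB_RUN_GROUP
# is required, since those variables are read via os.environ above.
#
#   python sft.py DEBUG=true                        # quick smoke test on 100 train samples
#   python sft.py eval_only=true wandb_id=<run_id>  # evaluation-only run resuming a W&B run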