import os
from typing import Optional

import torch
from transformers import Trainer


def maybe_zero_3(param, ignore_status=False, name=None):
    # Imported lazily so the module does not hard-require deepspeed.
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if hasattr(param, "ds_id"):
        # The parameter is managed by DeepSpeed ZeRO-3: its full data may be
        # partitioned across ranks, so gather it before copying to CPU.
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                print(name, "no ignore status")
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        # Plain tensor: just detach and copy to CPU.
        param = param.detach().cpu().clone()
    return param


def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    # Keep only parameters whose names contain one of the match keys
    # (substring match), then gather each from ZeRO-3 shards if needed.
    to_return = {
        k: t
        for k, t in named_params
        if any(key_match in k for key_match in keys_to_match)
    }
    to_return = {
        k: maybe_zero_3(v, ignore_status=True, name=k).cpu()
        for k, v in to_return.items()
    }
    return to_return
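
# Usage sketch for the helper above (a minimal, hedged example; "model" is a
# placeholder for any nn.Module whose parameter names contain "mm_projector",
# which is the same substring the trainer below matches on):
#
#     state = get_mm_adapter_state_maybe_zero_3(
#         model.named_parameters(), ["mm_projector"]
#     )
#     # -> maps names like "model.mm_projector.weight" to CPU tensors,
#     #    gathered from ZeRO-3 shards when deepspeed manages the params.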


class LLaVATrainer(Trainer):
    def _save_checkpoint(self, model, trial, metrics=None):
        if getattr(self.args, "tune_mm_mlp_adapter", False):
            # Only the multimodal adapter is being tuned: save just its
            # weights instead of a full model checkpoint.
            from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

            checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
            run_dir = self._get_output_dir(trial=trial)
            output_dir = os.path.join(run_dir, checkpoint_folder)

            # Only save the adapter (plus the token embeddings when extra
            # image-start/end tokens were added to the vocabulary).
            keys_to_match = ["mm_projector"]
            if getattr(self.args, "use_im_start_end", False):
                keys_to_match.extend(["embed_tokens", "embed_in"])

            weight_to_save = get_mm_adapter_state_maybe_zero_3(
                self.model.named_parameters(), keys_to_match
            )

            # Write from rank 0 only (local_rank == -1 means no distributed run).
            if self.args.local_rank == 0 or self.args.local_rank == -1:
                self.model.config.save_pretrained(output_dir)
                torch.save(
                    weight_to_save, os.path.join(output_dir, "mm_projector.bin")
                )
        else:
            super()._save_checkpoint(model, trial, metrics)
    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        if getattr(self.args, "tune_mm_mlp_adapter", False):
            # Adapter weights were already written by _save_checkpoint above;
            # skip the default full-model save.
            pass
        else:
            super()._save(output_dir, state_dict)
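

# A minimal smoke test of the adapter extraction (a hedged sketch: the dummy
# module below is hypothetical, and running it needs `torch` plus `deepspeed`,
# since maybe_zero_3 imports deepspeed unconditionally when called).
if __name__ == "__main__":
    import torch.nn as nn

    dummy = nn.ModuleDict(
        {
            "mm_projector": nn.Linear(8, 16),
            "lm_head": nn.Linear(16, 32),
        }
    )
    state = get_mm_adapter_state_maybe_zero_3(
        dummy.named_parameters(), ["mm_projector"]
    )
    # Only the projector's weight and bias pass the substring filter.
    print(sorted(state))  # ['mm_projector.bias', 'mm_projector.weight']
    # The trainer persists the same structure via
    #     torch.save(state, "mm_projector.bin")
    # and it can be restored later with torch.load("mm_projector.bin").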