import os
from typing import Optional

import torch
from transformers import Trainer

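
# Copy a parameter to the CPU, transparently handling DeepSpeed ZeRO-3:
# partitioned parameters (marked by a `ds_id` attribute) are gathered
# across ranks via `zero.GatheredParameters` before being cloned.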
|
def maybe_zero_3(param, ignore_status=False, name=None):
    # Imported lazily so the module does not require DeepSpeed unless
    # ZeRO-3 parameters are actually encountered.
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if hasattr(param, "ds_id"):
        # Optionally report parameters whose full weights are not yet
        # available on this rank.
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                print(name, "no ignore status")
        # Gather the full parameter from all ranks before copying it.
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param
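
# Extract the multimodal-adapter weights from `named_params` as a CPU state
# dict: keep entries whose names contain any substring in `keys_to_match`,
# materializing ZeRO-3 partitioned tensors first.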
|
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    to_return = {
        k: t
        for k, t in named_params
        if any(key_match in k for key_match in keys_to_match)
    }
    to_return = {
        k: maybe_zero_3(v, ignore_status=True, name=k).cpu()
        for k, v in to_return.items()
    }
    return to_return
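
# Trainer subclass for LLaVA: when only the multimodal MLP adapter is being
# tuned (args.tune_mm_mlp_adapter), checkpoints store just the adapter
# weights instead of the full model state.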
|
class LLaVATrainer(Trainer):
    def _save_checkpoint(self, model, trial, metrics=None):
        if getattr(self.args, "tune_mm_mlp_adapter", False):
            from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

            checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

            run_dir = self._get_output_dir(trial=trial)
            output_dir = os.path.join(run_dir, checkpoint_folder)

            # Save only the adapter (projector) weights, plus the token
            # embeddings when image start/end tokens are in use.
            keys_to_match = ["mm_projector"]
            if getattr(self.args, "use_im_start_end", False):
                keys_to_match.extend(["embed_tokens", "embed_in"])

            weight_to_save = get_mm_adapter_state_maybe_zero_3(
                self.model.named_parameters(), keys_to_match
            )

            # Write from the main process only (rank 0, or -1 when not
            # running distributed).
            if self.args.local_rank == 0 or self.args.local_rank == -1:
                self.model.config.save_pretrained(output_dir)
                torch.save(
                    weight_to_save, os.path.join(output_dir, "mm_projector.bin")
                )
        else:
            super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics)
|

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        if getattr(self.args, "tune_mm_mlp_adapter", False):
            # The adapter weights are already written by `_save_checkpoint`;
            # skip the default full-model save.
            pass
        else:
            super(LLaVATrainer, self)._save(output_dir, state_dict)
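
# A minimal usage sketch (hypothetical wiring; `model`, `training_args`, and
# `data_module` would come from a training script, e.g. a train.py, with
# `training_args` carrying the `tune_mm_mlp_adapter` flag checked above):
#
#     trainer = LLaVATrainer(model=model, args=training_args, **data_module)
#     trainer.train()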