|
from transformers import Trainer |
|
import torch.nn.functional as F |
|
from typing import Optional |
|
import os |
|
import torch |
|
from transformers.utils import WEIGHTS_NAME |
|
import json |
|
|
|
class VideoBaseTrainer(Trainer):
    """``Trainer`` subclass with a custom checkpoint-writing routine.

    Overrides ``_save`` so that each save directory contains the model config
    (annotated with the model's class name), the model weights, and the
    training arguments.
    """

    def _save(self, output_dir: Optional[str] = None, state_dict=None) -> None:
        """Write the model config, weights and training arguments to disk.

        Args:
            output_dir: Directory to write into. Falls back to
                ``self.args.output_dir`` when ``None``; created if missing.
            state_dict: Weights to save. Falls back to
                ``self.model.state_dict()`` when ``None``.

        Files written into ``output_dir``:
            - ``config.json``: ``self.model.config.to_dict()`` plus a
              ``"model"`` key holding the model's class name.
            - ``WEIGHTS_NAME``: ``state_dict`` serialized with ``torch.save``.
            - ``training_args.bin``: ``self.args`` serialized with ``torch.save``.
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)

        if state_dict is None:
            state_dict = self.model.state_dict()

        # Store the model's class name alongside its configuration.
        model_config = self.model.config.to_dict()
        model_config["model"] = self.model.__class__.__name__

        # Dump the annotated copy (not a fresh ``to_dict()``) so the
        # "model" key is actually persisted.
        config_path = os.path.join(output_dir, "config.json")
        with open(config_path, "w", encoding="utf-8") as file:
            json.dump(model_config, file, indent=2)

        torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
|
|