Spaces:
Configuration error
Configuration error
# -*- coding: utf-8 -*- | |
""" | |
@author:XuMing([email protected]) | |
@description: | |
""" | |
import math | |
import os | |
from dataclasses import dataclass, field | |
from glob import glob | |
from typing import Any, List, Union, Optional, Dict | |
import torch | |
from datasets import load_dataset | |
from loguru import logger | |
from peft import LoraConfig, TaskType, get_peft_model, PeftModel, prepare_model_for_int8_training | |
from sklearn.metrics import mean_squared_error, mean_absolute_error | |
from torch.utils.data import Dataset | |
from transformers import ( | |
AutoConfig, | |
PreTrainedTokenizerBase, | |
BloomForSequenceClassification, | |
LlamaForSequenceClassification, | |
LlamaTokenizer, | |
BloomTokenizerFast, | |
AlbertForSequenceClassification, | |
BertForSequenceClassification, | |
BertTokenizer, | |
AutoTokenizer, | |
RobertaForSequenceClassification, | |
AutoModelForSequenceClassification, | |
RobertaTokenizer, | |
HfArgumentParser, | |
Trainer, | |
TrainingArguments, | |
set_seed, | |
) | |
from transformers.trainer import TRAINING_ARGS_NAME | |
# Registry mapping a ``--model_type`` value to the (config, model, tokenizer) classes
# used to load a sequence-classification reward model. Key order determines the order
# in which model types are listed in ModelArguments' help text and error message.
MODEL_CLASSES = dict(
    bert=(AutoConfig, BertForSequenceClassification, BertTokenizer),
    roberta=(AutoConfig, RobertaForSequenceClassification, RobertaTokenizer),
    albert=(AutoConfig, AlbertForSequenceClassification, AutoTokenizer),
    bloom=(AutoConfig, BloomForSequenceClassification, BloomTokenizerFast),
    llama=(AutoConfig, LlamaForSequenceClassification, LlamaTokenizer),
    auto=(AutoConfig, AutoModelForSequenceClassification, AutoTokenizer),
)
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune.

    The reward model is always initialized from a pre-trained checkpoint:
    ``__post_init__`` rejects a missing ``model_name_or_path`` and any ``model_type``
    that is not a key of ``MODEL_CLASSES``.
    """

    # Key into MODEL_CLASSES selecting the config/model/tokenizer classes.
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())}
    )
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
            )
        },
    )
    # When unset, main() falls back to model_name_or_path.
    tokenizer_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The tokenizer for weights initialization.Don't set if you want to train a model from scratch."
            )
        },
    )
    load_in_8bit: bool = field(default=False, metadata={"help": "Whether to load the model in 8bit mode or not."})
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    # main() forces this to True for model_type == "bloom".
    use_fast_tokenizer: bool = field(
        default=False,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    # main() replaces this with {"": LOCAL_RANK} when WORLD_SIZE > 1.
    device_map: Optional[str] = field(
        default="auto",
        metadata={"help": "Device to map model to. If `auto` is passed, the device will be selected automatically. "},
    )
    trust_remote_code: bool = field(
        default=True,
        metadata={"help": "Whether to trust remote code when loading a model from a remote checkpoint."},
    )

    def __post_init__(self):
        """Validate that a known model type and a checkpoint path were supplied.

        Raises:
            ValueError: if ``model_type`` is missing or not in ``MODEL_CLASSES``, or if
                ``model_name_or_path`` is missing.
        """
        # ``None`` is never a key, so this single check covers "missing" and "unknown".
        if self.model_type not in MODEL_CLASSES:
            raise ValueError(
                "You must specify a valid model_type to run training. Available model types are " + ", ".join(
                    MODEL_CLASSES.keys()))
        if self.model_name_or_path is None:
            raise ValueError("You must specify a valid model_name_or_path to run training.")
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Data comes either from the Hub (``dataset_name``) or from local ``*.json`` / ``*.jsonl``
    files found recursively under ``train_file_dir`` / ``validation_file_dir``.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file_dir: Optional[str] = field(default=None, metadata={"help": "The input jsonl data file folder."})
    validation_file_dir: Optional[str] = field(default=None, metadata={"help": "The evaluation jsonl file folder."}, )
    # main() sums max_source_length + max_target_length into the maximum token length of
    # each tokenized chosen/rejected sequence; longer pairs are filtered out.
    max_source_length: Optional[int] = field(default=256, metadata={"help": "Max length of prompt input text"})
    max_target_length: Optional[int] = field(default=256, metadata={"help": "Max length of output text"})
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    # Used as ``train[:N%]`` (validation) / ``train[N%:]`` (train) when no validation split exists.
    validation_split_percentage: Optional[int] = field(
        default=1,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=4,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
@dataclass
class PeftArguments(TrainingArguments):
    """``TrainingArguments`` extended with the LoRA / PEFT options consumed by ``main``."""

    use_peft: bool = field(default=True, metadata={"help": "Whether to use peft"})
    target_modules: Optional[str] = field(
        default="all",
        metadata={"help": "Comma-separated module names to apply LoRA to, or 'all' for every linear layer "
                          "except lm_head/score."},
    )
    lora_rank: Optional[int] = field(default=8, metadata={"help": "LoRA rank r."})
    lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "Dropout probability of the LoRA layers."})
    lora_alpha: Optional[float] = field(default=32.0, metadata={"help": "LoRA scaling factor alpha."})
    modules_to_save: Optional[str] = field(
        default=None,
        metadata={"help": "Comma-separated names of extra modules to train and save alongside the adapters."},
    )
    peft_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path of an existing PEFT adapter to load and continue training."},
    )
def compute_metrics(eval_preds):
    """
    Compute evaluation metrics for the reward model.

    ``RewardTrainer.prediction_step`` returns ``(loss, rewards_chosen, rewards_rejected)``,
    so here ``preds`` are the chosen-response rewards and ``labels`` the rejected ones.

    Returns:
        dict with
        - ``mse`` / ``mae``: mean squared / absolute difference between the two rewards;
        - ``accuracy``: fraction of pairs whose chosen reward is strictly higher than the
          rejected reward.
    """
    import numpy as np

    def _to_numpy(x):
        # NumPy has no bfloat16, so upcast before converting tensors.
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu()
            if x.dtype == torch.bfloat16:
                x = x.float()
            x = x.numpy()
        return np.asarray(x)

    preds, labels = eval_preds
    preds = _to_numpy(preds)
    labels = _to_numpy(labels)
    mse = mean_squared_error(labels, preds)
    mae = mean_absolute_error(labels, preds)
    accuracy = float(np.mean(preds > labels))
    return {"mse": mse, "mae": mae, "accuracy": accuracy}
@dataclass
class RewardDataCollatorWithPadding:
    """
    Collate tokenized preference pairs into one padded batch.

    Each feature must provide ``input_ids_chosen``, ``attention_mask_chosen``,
    ``input_ids_rejected`` and ``attention_mask_rejected``. The chosen and rejected halves
    are padded independently with ``tokenizer.pad`` and returned under those same keys,
    together with the flag ``return_loss=True``.
    """

    # String annotation: only used for typing, so the class can be defined without
    # resolving the tokenizer base class.
    tokenizer: "PreTrainedTokenizerBase"
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def _pad(self, features: List[Dict[str, Any]], suffix: str) -> Dict[str, Any]:
        """Pad the ``*_<suffix>`` half (``chosen`` or ``rejected``) of every pair."""
        half = [
            {
                "input_ids": feature[f"input_ids_{suffix}"],
                "attention_mask": feature[f"attention_mask_{suffix}"],
            }
            for feature in features
        ]
        return self.tokenizer.pad(
            half,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        batch_chosen = self._pad(features, "chosen")
        batch_rejected = self._pad(features, "rejected")
        return {
            "input_ids_chosen": batch_chosen["input_ids"],
            "attention_mask_chosen": batch_chosen["attention_mask"],
            "input_ids_rejected": batch_rejected["input_ids"],
            "attention_mask_rejected": batch_rejected["attention_mask"],
            "return_loss": True,
        }
class RewardTrainer(Trainer):
    """
    Trainer for reward models.

    Define how to compute the reward loss. Use the InstructGPT pairwise logloss: https://arxiv.org/abs/2203.02155
    i.e. ``-log(sigmoid(r_chosen - r_rejected))`` averaged over the batch.
    """

    @staticmethod
    def _pairwise_loss(rewards_chosen, rewards_rejected):
        """Mean pairwise log-sigmoid loss that pushes chosen rewards above rejected ones."""
        return -torch.nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        Score both halves of each pair and return the pairwise loss.

        Args:
            model: the (possibly PEFT-wrapped) model; element 0 of its output is the reward.
            inputs: batch produced by ``RewardDataCollatorWithPadding``.
            return_outputs: when True, also return the chosen/rejected rewards.
            **kwargs: extra keyword arguments that newer ``Trainer`` releases pass
                (e.g. ``num_items_in_batch``); accepted for compatibility and unused.
        """
        rewards_chosen = model(input_ids=inputs["input_ids_chosen"],
                               attention_mask=inputs["attention_mask_chosen"])[0]
        rewards_rejected = model(input_ids=inputs["input_ids_rejected"],
                                 attention_mask=inputs["attention_mask_rejected"])[0]
        loss = self._pairwise_loss(rewards_chosen, rewards_rejected)
        if return_outputs:
            return loss, {"rewards_chosen": rewards_chosen, "rewards_rejected": rewards_rejected}
        return loss

    def evaluate(
            self,
            eval_dataset: Optional[Dataset] = None,
            ignore_keys: Optional[List[str]] = None,
            metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        """Evaluate on ``eval_dataset``, defaulting to the trainer's own eval dataset."""
        if eval_dataset is None:
            eval_dataset = self.eval_dataset
        return super().evaluate(eval_dataset=eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """
        Evaluation step returning ``(loss, rewards_chosen, rewards_rejected)``.

        The Trainer treats the second element as predictions and the third as labels,
        which is how ``compute_metrics`` receives the two reward vectors. The forward
        passes run under ``torch.no_grad()`` so no autograd graph is kept during eval.
        """
        device = model.device
        with torch.no_grad():
            rewards_chosen = model(
                input_ids=inputs["input_ids_chosen"].to(device),
                attention_mask=inputs["attention_mask_chosen"].to(device),
            ).logits.detach()
            rewards_rejected = model(
                input_ids=inputs["input_ids_rejected"].to(device),
                attention_mask=inputs["attention_mask_rejected"].to(device),
            ).logits.detach()
            # Same objective as compute_loss.
            loss = self._pairwise_loss(rewards_chosen, rewards_rejected)
        if prediction_loss_only:
            return (loss, None, None)
        return (loss, rewards_chosen, rewards_rejected)

    def save_model(self, output_dir=None, _internal_call=False):
        """
        Save the (LoRA) model and the training arguments.

        Args:
            output_dir: target directory; defaults to ``self.args.output_dir`` like
                ``Trainer.save_model``. Created if missing.
            _internal_call: accepted for ``Trainer`` compatibility; unused.
        """
        if output_dir is None:
            output_dir = self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
        self.model.save_pretrained(output_dir)
def save_model(output_dir, model, tokenizer, args):
    """
    Write the model weights, tokenizer files and training arguments into ``output_dir``.

    The directory is created when missing. If ``model`` exposes a ``.module`` attribute
    (as parallel wrappers do), the inner module is what gets saved.
    """
    os.makedirs(output_dir, exist_ok=True)
    unwrapped = getattr(model, "module", model)
    unwrapped.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    args_path = os.path.join(output_dir, TRAINING_ARGS_NAME)
    torch.save(args, args_path)
class CastOutputToFloat(torch.nn.Sequential):
    """``nn.Sequential`` wrapper whose output tensor is always converted to ``torch.float32``."""

    def forward(self, x):
        out = super().forward(x)
        return out.float()
def print_trainable_parameters(model):
    """
    Print how many of ``model``'s parameters are trainable.

    Output format: ``trainable params: T || all params: A || trainable%: P``.
    A model without parameters reports ``0.0`` percent instead of dividing by zero.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
    )
def find_all_linear_names(peft_model, int4=False, int8=False):
    """
    Return the sorted, de-duplicated leaf names of the linear layers in ``peft_model``.

    Modules whose qualified name contains ``lm_head`` or ``score`` (the output heads)
    are skipped. With ``int4`` / ``int8`` the matched class becomes the bitsandbytes
    4-bit / 8-bit linear layer instead of ``torch.nn.Linear`` (``int4`` wins if both are set).
    """
    linear_cls = torch.nn.Linear
    if int4 or int8:
        import bitsandbytes as bnb
        linear_cls = bnb.nn.Linear4bit if int4 else bnb.nn.Linear8bitLt

    excluded = ('lm_head', 'score')
    leaf_names = {
        qualified_name.split('.')[-1]
        for qualified_name, module in peft_model.named_modules()
        if isinstance(module, linear_cls) and not any(tag in qualified_name for tag in excluded)
    }
    return sorted(leaf_names)
def _load_model(model_args, training_args):
    """
    Load the reward model: a ``*ForSequenceClassification`` model with ``num_labels=1``.

    Raises:
        ValueError: if ``model_name_or_path`` is not set (the RM must start from a checkpoint).
    """
    if not model_args.model_name_or_path:
        raise ValueError("Error, model_name_or_path is None, RM must be loaded from a pre-trained model")
    config_class, model_class, _ = MODEL_CLASSES[model_args.model_type]
    # "auto" and None are passed through as-is; other names map to torch dtypes.
    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    if world_size > 1:
        # Distributed run: put the whole model on this process's GPU.
        # LOCAL_RANK may be absent for some launchers; fall back to device 0.
        model_args.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
    config = config_class.from_pretrained(
        model_args.model_name_or_path,
        num_labels=1,
        torch_dtype=torch_dtype,
        trust_remote_code=model_args.trust_remote_code,
        cache_dir=model_args.cache_dir
    )
    if model_args.model_type in ['bloom', 'llama']:
        model = model_class.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            torch_dtype=torch_dtype,
            load_in_8bit=model_args.load_in_8bit,
            device_map=model_args.device_map,
            trust_remote_code=model_args.trust_remote_code,
            cache_dir=model_args.cache_dir,
        )
        # Return the scalar reward as fp32 even when the backbone runs in lower precision.
        model.score = CastOutputToFloat(model.score)
    else:
        model = model_class.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            torch_dtype=torch_dtype,
            cache_dir=model_args.cache_dir,
            ignore_mismatched_sizes=True
        )
        model.to(training_args.device)
    return model


def _load_tokenizer(model_args):
    """Load the tokenizer (path defaults to the model path); a missing pad id becomes 0."""
    _, _, tokenizer_class = MODEL_CLASSES[model_args.model_type]
    if model_args.model_type == "bloom":
        model_args.use_fast_tokenizer = True
    tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir,
        "use_fast": model_args.use_fast_tokenizer,
        "trust_remote_code": model_args.trust_remote_code,
    }
    tokenizer_name_or_path = model_args.tokenizer_name_or_path or model_args.model_name_or_path
    tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = 0
    return tokenizer


def _wrap_with_peft(model, model_args, training_args):
    """
    Attach LoRA adapters when ``use_peft`` is set; otherwise keep the model for full fine-tuning.

    Adapters are loaded from ``peft_path`` when given, else created from the ``lora_*``,
    ``target_modules`` and ``modules_to_save`` arguments.
    """
    if not training_args.use_peft:
        logger.info("Full parameters training")
        print_trainable_parameters(model)
        return model
    if model_args.load_in_8bit:
        # Prepare the 8-bit base model BEFORE adapters are attached: this PEFT helper
        # freezes the parameters it finds, so calling it afterwards would also freeze
        # the freshly created LoRA weights and nothing would train.
        model = prepare_model_for_int8_training(model)
    if training_args.peft_path is not None:
        logger.info(f"Peft from pre-trained model: {training_args.peft_path}")
        model = PeftModel.from_pretrained(model, training_args.peft_path, is_trainable=True)
    else:
        logger.info("Init new peft model")
        target_modules = training_args.target_modules.split(',') if training_args.target_modules else None
        if target_modules and 'all' in target_modules:
            target_modules = find_all_linear_names(model, int4=False, int8=model_args.load_in_8bit)
        modules_to_save = training_args.modules_to_save
        if modules_to_save is not None:
            modules_to_save = modules_to_save.split(',')
        logger.info(f"Peft target_modules: {target_modules}")
        logger.info(f"Peft lora_rank: {training_args.lora_rank}")
        peft_config = LoraConfig(
            task_type=TaskType.SEQ_CLS,
            target_modules=target_modules,
            inference_mode=False,
            r=training_args.lora_rank,
            lora_alpha=training_args.lora_alpha,
            lora_dropout=training_args.lora_dropout,
            modules_to_save=modules_to_save)
        model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()
    return model


def _load_raw_datasets(data_args, model_args):
    """
    Load the raw preference data from the Hub (``dataset_name``) or local json/jsonl files.

    If no ``validation`` split exists, the first ``validation_split_percentage`` percent
    of ``train`` becomes validation and the remainder stays as train.
    """
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        load_args = (data_args.dataset_name, data_args.dataset_config_name)
        load_kwargs = {}
    else:
        data_files = {}
        if data_args.train_file_dir is not None and os.path.exists(data_args.train_file_dir):
            train_data_files = glob(f'{data_args.train_file_dir}/**/*.json', recursive=True) + glob(
                f'{data_args.train_file_dir}/**/*.jsonl', recursive=True)
            logger.info(f"train files: {', '.join(train_data_files)}")
            data_files["train"] = train_data_files
        if data_args.validation_file_dir is not None and os.path.exists(data_args.validation_file_dir):
            eval_data_files = glob(f'{data_args.validation_file_dir}/**/*.json', recursive=True) + glob(
                f'{data_args.validation_file_dir}/**/*.jsonl', recursive=True)
            logger.info(f"eval files: {', '.join(eval_data_files)}")
            data_files["validation"] = eval_data_files
        load_args = ('json',)
        load_kwargs = {"data_files": data_files}
    raw_datasets = load_dataset(*load_args, cache_dir=model_args.cache_dir, **load_kwargs)
    # If no validation data is there, validation_split_percentage will be used to divide the dataset.
    if "validation" not in raw_datasets.keys():
        pct = data_args.validation_split_percentage
        raw_datasets["validation"] = load_dataset(
            *load_args, split=f"train[:{pct}%]", cache_dir=model_args.cache_dir, **load_kwargs)
        raw_datasets["train"] = load_dataset(
            *load_args, split=f"train[{pct}%:]", cache_dir=model_args.cache_dir, **load_kwargs)
    return raw_datasets


def _truncate(dataset, max_samples):
    """Return ``(dataset, n)``, keeping only the first ``max_samples`` rows when it is a positive int."""
    n = len(dataset)
    if max_samples is not None and max_samples > 0:
        n = min(n, max_samples)
        dataset = dataset.select(range(n))
    return dataset, n


def _tokenize_pairs(dataset, tokenizer, data_args, full_max_length, shuffle=False, seed=None):
    """
    Tokenize ``question`` / ``response_chosen`` / ``response_rejected`` rows into
    ``input_ids_*`` / ``attention_mask_*`` columns for both halves of each pair.

    Pairs where either side is empty or longer than ``full_max_length`` tokens are dropped.
    """

    def preprocess_reward_function(examples):
        """
        Turn the dataset into pairs of Question + Answer, where input_ids_chosen is the preferred
        question + answer and input_ids_rejected is the other.
        """
        new_examples = {
            "input_ids_chosen": [],
            "attention_mask_chosen": [],
            "input_ids_rejected": [],
            "attention_mask_rejected": [],
        }
        for question, chosen, rejected in zip(examples["question"], examples["response_chosen"],
                                              examples["response_rejected"]):
            tokenized_chosen = tokenizer("Question: " + question + "\n\nAnswer: " + chosen)
            tokenized_rejected = tokenizer("Question: " + question + "\n\nAnswer: " + rejected)
            new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
            new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
            new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
            new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
        return new_examples

    source = dataset.shuffle(seed=seed) if shuffle else dataset
    tokenized_dataset = source.map(
        preprocess_reward_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=dataset.column_names,
        load_from_cache_file=not data_args.overwrite_cache,
        desc="Running tokenizer on dataset",
    )
    return tokenized_dataset.filter(
        lambda x: 0 < len(x['input_ids_rejected']) <= full_max_length and 0 < len(
            x['input_ids_chosen']) <= full_max_length
    )


def main():
    """Parse CLI arguments, then train and/or evaluate the reward model."""
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, PeftArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    logger.info(f"Model args: {model_args}")
    logger.info(f"Data args: {data_args}")
    logger.info(f"Training args: {training_args}")
    logger.info(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Load model and tokenizer
    if not model_args.model_type:
        raise ValueError("Please specify a model_type, e.g. llama, chatglm, bloom, etc.")
    model = _load_model(model_args, training_args)
    tokenizer = _load_tokenizer(model_args)
    if model.config.pad_token_id is None:
        # Decoder-only *ForSequenceClassification heads (e.g. Llama, Bloom) read
        # config.pad_token_id to find each sequence's last real token and reject
        # batches larger than 1 without it; keep it in sync with the tokenizer.
        model.config.pad_token_id = tokenizer.pad_token_id
    model = _wrap_with_peft(model, model_args, training_args)

    # Get reward dataset for tuning the reward model.
    raw_datasets = _load_raw_datasets(data_args, model_args)
    logger.info(f"Raw datasets: {raw_datasets}")

    # Preprocessing the datasets
    full_max_length = data_args.max_source_length + data_args.max_target_length

    train_dataset = None
    max_train_samples = 0
    if training_args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset, max_train_samples = _truncate(raw_datasets['train'], data_args.max_train_samples)
        if len(train_dataset) > 0:
            logger.debug(f"Example train_dataset[0]: {train_dataset[0]}")
        with training_args.main_process_first(desc="Train dataset tokenization"):
            train_dataset = _tokenize_pairs(
                train_dataset, tokenizer, data_args, full_max_length, shuffle=True, seed=training_args.seed
            )
        logger.debug(f"Num train_samples: {len(train_dataset)}")
        if len(train_dataset) > 0:
            logger.debug("Tokenized training example:")
            logger.debug(tokenizer.decode(train_dataset[0]['input_ids_chosen']))

    eval_dataset = None
    max_eval_samples = 0
    if training_args.do_eval:
        if "validation" not in raw_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset, max_eval_samples = _truncate(raw_datasets["validation"], data_args.max_eval_samples)
        if len(eval_dataset) > 0:
            logger.debug(f"Example eval_dataset[0]: {eval_dataset[0]}")
        with training_args.main_process_first(desc="Eval dataset tokenization"):
            eval_dataset = _tokenize_pairs(eval_dataset, tokenizer, data_args, full_max_length)
        logger.debug(f"Num eval_samples: {len(eval_dataset)}")
        if len(eval_dataset) > 0:
            logger.debug("Tokenized eval example:")
            logger.debug(tokenizer.decode(eval_dataset[0]['input_ids_chosen']))

    # Initialize our Trainer
    if training_args.gradient_checkpointing:
        model.gradient_checkpointing_enable()
        model.config.use_cache = False
    else:
        model.config.use_cache = True
    model.enable_input_require_grads()
    if torch.cuda.device_count() > 1:
        # Keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
        model.is_parallelizable = True
        model.model_parallel = True
    trainer = RewardTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        data_collator=RewardDataCollatorWithPadding(
            tokenizer=tokenizer, max_length=full_max_length, padding="max_length"
        ),
    )

    # Training
    if training_args.do_train:
        logger.info("*** Train ***")
        logger.debug(f"Train dataloader example: {next(iter(trainer.get_train_dataloader()))}")
        checkpoint = training_args.resume_from_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)

        metrics = train_result.metrics
        metrics["train_samples"] = max_train_samples
        logger.debug(f"Training metrics: {metrics}")
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
        # Only one process writes the final model to avoid concurrent writes to the same files.
        if trainer.is_world_process_zero():
            logger.info(f"Saving model checkpoint to {training_args.output_dir}")
            save_model(training_args.output_dir, model, tokenizer, training_args)

    # Evaluation
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        # Every rank must call evaluate(): in distributed runs it gathers predictions across
        # processes, so calling it on rank 0 alone would hang. Trainer.log_metrics /
        # save_metrics only act on the main process.
        metrics = trainer.evaluate()
        metrics["eval_samples"] = max_eval_samples
        # exp() of the pairwise loss, kept for continuity of the reported metrics;
        # it is not a language-model perplexity.
        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            perplexity = float("inf")
        metrics["perplexity"] = perplexity
        logger.debug(f"Eval metrics: {metrics}")
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


if __name__ == "__main__":
    main()