adjust batch_size and update run_mlm_flax script
Files changed:
- events.out.tfevents.1625581996.t1v-n-bf8aeee7-w-0.10129.3.v2  +0 -0
- run.sh  +2 -2
- run_mlm_flax.py  +38 -38
events.out.tfevents.1625581996.t1v-n-bf8aeee7-w-0.10129.3.v2
DELETED
Binary file (883 kB)
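This is a TensorBoard event log left over from an earlier training run.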
run.sh
CHANGED
@@ -9,8 +9,8 @@ python3 run_mlm_flax.py \
     --dataset_config_name="unshuffled_deduplicated_th" \
     --max_seq_length="128" \
     --preprocessing_num_workers="64" \
-    --per_device_train_batch_size="
-    --per_device_eval_batch_size="
+    --per_device_train_batch_size="64" \
+    --per_device_eval_batch_size="64" \
     --learning_rate="2e-4" \
     --warmup_steps="1000" \
     --overwrite_output_dir \
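Both flags now request 64 sequences per device; the old values are truncated in this view. For context, the Flax example scripts derive the global batch size by multiplying the per-device flag by the local device count. A minimal sketch of that relationship (the count of 8 is an assumption for a single TPU v3-8 host, matching the TPU VM hostname in the deleted event file):

import jax

# Sketch of the per-device -> global batch size relationship used by the
# Flax example scripts. Assumes 8 local devices (e.g. one TPU v3-8 host).
per_device_train_batch_size = 64
train_batch_size = per_device_train_batch_size * jax.device_count()
print(train_batch_size)  # 512 with 8 devices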
run_mlm_flax.py
CHANGED
@@ -606,7 +606,7 @@ if __name__ == "__main__":
             state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
             train_metrics.append(train_metric)

-            cur_step = epoch * num_train_samples + step
+            cur_step = epoch * (num_train_samples // train_batch_size) + step

             if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                 # Save metrics
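The counter fix matters because `num_train_samples` counts examples, not optimizer steps: after the first epoch the old formula jumped `cur_step` by the full dataset size, so the `logging_steps`, `eval_steps`, and `save_steps` checks fired on the wrong schedule. Dividing by `train_batch_size` turns the offset into steps per epoch. A small illustration with hypothetical sizes:

# Hypothetical sizes, to show the scale of the old off-by-batch-size counter.
num_train_samples = 1_000_000  # examples in the train split
train_batch_size = 512         # global batch size (64 per device x 8 devices)
steps_per_epoch = num_train_samples // train_batch_size  # 1953

epoch, step = 1, 0  # first step of the second epoch
old_cur_step = epoch * num_train_samples + step  # 1000000, inflated ~512x
new_cur_step = epoch * steps_per_epoch + step    # 1953, the steps actually taken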
@@ -621,43 +621,43 @@ if __name__ == "__main__":

                 train_metrics = []

-        # ======================== Evaluating ==============================
-        num_eval_samples = len(tokenized_datasets["validation"])
-        eval_samples_idx = jnp.arange(num_eval_samples)
-        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
-
-        eval_metrics = []
-        for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
-            samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
-            model_inputs = data_collator(samples, pad_to_multiple_of=16)
-
-            # Model forward
-            model_inputs = shard(model_inputs.data)
-            metrics = p_eval_step(state.params, model_inputs)
-            eval_metrics.append(metrics)
-
-        # normalize eval metrics
-        eval_metrics = get_metrics(eval_metrics)
-        eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
-        eval_normalizer = eval_metrics.pop("normalizer")
-        eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
-
-        # Update progress bar
-        epochs.desc = (
-            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
-        )
-
-        # Save metrics
-        if has_tensorboard and jax.process_index() == 0:
-            cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
-            write_eval_metric(summary_writer, eval_metrics, cur_step)
-
-        # save checkpoint after each epoch and push checkpoint to the hub
-        if jax.process_index() == 0:
-            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
-            model.save_pretrained(
-                training_args.output_dir,
-                params=params,
-                push_to_hub=training_args.push_to_hub,
-                commit_message=f"Saving weights and logs of epoch {epoch + 1}",
-            )
+                if cur_step % training_args.eval_steps == 0 and cur_step > 0:
+                    # ======================== Evaluating ==============================
+                    num_eval_samples = len(tokenized_datasets["validation"])
+                    eval_samples_idx = jnp.arange(num_eval_samples)
+                    eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
+
+                    eval_metrics = []
+                    for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
+                        samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
+                        model_inputs = data_collator(samples, pad_to_multiple_of=16)
+
+                        # Model forward
+                        model_inputs = shard(model_inputs.data)
+                        metrics = p_eval_step(state.params, model_inputs)
+                        eval_metrics.append(metrics)
+
+                    # normalize eval metrics
+                    eval_metrics = get_metrics(eval_metrics)
+                    eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
+                    eval_normalizer = eval_metrics.pop("normalizer")
+                    eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
+
+                    # Update progress bar
+                    epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
+
+                    # Save metrics
+                    if has_tensorboard and jax.process_index() == 0:
+                        cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
+                        write_eval_metric(summary_writer, eval_metrics, cur_step)
+
+                if cur_step % training_args.save_steps == 0 and cur_step > 0:
+                    # save checkpoint after each epoch and push checkpoint to the hub
+                    if jax.process_index() == 0:
+                        params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+                        model.save_pretrained(
+                            training_args.output_dir,
+                            params=params,
+                            push_to_hub=training_args.push_to_hub,
+                            commit_message=f"Saving weights and logs of step {cur_step}",
+                        )
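Net effect: evaluation and checkpointing move from once per epoch into the training loop, gated on `training_args.eval_steps` and `training_args.save_steps`, and the progress bar now reports the global step instead of the epoch. The normalization block sums a `normalizer` entry (the label count each batch contributed) across eval batches and divides the remaining metrics by it. A minimal sketch of that pattern, with made-up values:

import jax
import jax.numpy as jnp

# Sketch of the eval-metric normalization above, with made-up numbers.
# Each eval step returns summed loss/accuracy plus a "normalizer" (label
# count), so the global mean is global sum / global normalizer.
eval_metrics = [
    {"loss": jnp.array(12.0), "accuracy": jnp.array(20.0), "normalizer": jnp.array(32.0)},
    {"loss": jnp.array(10.0), "accuracy": jnp.array(24.0), "normalizer": jnp.array(32.0)},
]
totals = jax.tree_map(lambda *xs: sum(xs), *eval_metrics)
normalizer = totals.pop("normalizer")
means = jax.tree_map(lambda x: x / normalizer, totals)
print(means)  # accuracy: 44/64 = 0.6875, loss: 22/64 = 0.34375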