Update script according to https://github.com/huggingface/transformers/pull/12608
run_mlm_flax.py: +26 -12
@@ -431,7 +431,8 @@ if __name__ == "__main__":
         total_length = len(concatenated_examples[list(examples.keys())[0]])
         # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
         # customize this part to your needs.
-        total_length = (total_length // max_seq_length) * max_seq_length
+        if total_length >= max_seq_length:
+            total_length = (total_length // max_seq_length) * max_seq_length
         # Split by chunks of max_len.
         result = {
             k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
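The guard added above only changes behaviour when the concatenated examples are shorter than max_seq_length: the old code rounded total_length down to 0 and dropped such a batch entirely, while the new code keeps it as a single short example. A minimal standalone sketch with toy values (not taken from the script):

max_seq_length = 8

def chunk(ids):
    total_length = len(ids)
    # Only round down when at least one full chunk exists, so a short batch
    # is kept as one (shorter) example instead of being discarded.
    if total_length >= max_seq_length:
        total_length = (total_length // max_seq_length) * max_seq_length
    return [ids[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]

print(chunk(list(range(19))))  # two chunks of 8; the 3-token remainder is still dropped
print(chunk(list(range(5))))   # one chunk of 5 (the previous code returned [] here)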
@@ -478,7 +479,14 @@
     rng = jax.random.PRNGKey(training_args.seed)
     dropout_rngs = jax.random.split(rng, jax.local_device_count())

-    model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
+    if model_args.model_name_or_path:
+        model = FlaxAutoModelForMaskedLM.from_pretrained(
+            model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        )
+    else:
+        model = FlaxAutoModelForMaskedLM.from_config(
+            config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        )

     # Store some constant
     num_epochs = int(training_args.num_train_epochs)
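This hunk lets the script warm-start from an existing checkpoint when model_args.model_name_or_path is set, and otherwise falls back to the previous behaviour of initializing a fresh model from the config. A minimal sketch of the two paths, using "bert-base-uncased" and "bfloat16" purely as illustrative values:

import jax.numpy as jnp
from transformers import AutoConfig, FlaxAutoModelForMaskedLM

config = AutoConfig.from_pretrained("bert-base-uncased")
dtype = getattr(jnp, "bfloat16")  # the script converts the model_args.dtype string to a jnp dtype this way

# Taken when a model name or path is provided: load pretrained weights.
model = FlaxAutoModelForMaskedLM.from_pretrained("bert-base-uncased", config=config, seed=0, dtype=dtype)

# Otherwise: randomly initialized weights from the config alone.
model = FlaxAutoModelForMaskedLM.from_config(config, seed=0, dtype=dtype)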
@@ -513,17 +521,24 @@
         return traverse_util.unflatten_dict(flat_mask)

     # create adam optimizer
-    adamw = optax.adamw(
-        learning_rate=linear_decay_lr_schedule_fn,
-        b1=training_args.adam_beta1,
-        b2=training_args.adam_beta2,
-        eps=1e-8,
-        weight_decay=training_args.weight_decay,
-        mask=decay_mask_fn,
-    )
+    if training_args.adafactor:
+        # We use the default parameters here to initialize adafactor,
+        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
+        optimizer = optax.adafactor(
+            learning_rate=linear_decay_lr_schedule_fn,
+        )
+    else:
+        optimizer = optax.adamw(
+            learning_rate=linear_decay_lr_schedule_fn,
+            b1=training_args.adam_beta1,
+            b2=training_args.adam_beta2,
+            eps=training_args.adam_epsilon,
+            weight_decay=training_args.weight_decay,
+            mask=decay_mask_fn,
+        )

     # Setup train state
-    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
+    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)

     # Define gradient update step fn
     def train_step(state, batch, dropout_rng):
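The optimizer is now chosen at runtime: optax.adafactor with optax's default settings when training_args.adafactor is set, otherwise the previous optax.adamw, and either one is passed to TrainState.create as tx. A self-contained sketch with a toy parameter tree and an illustrative linear schedule standing in for linear_decay_lr_schedule_fn; the hyperparameter values are examples, and the weight-decay mask used by the script is omitted:

import jax.numpy as jnp
import optax
from flax.training import train_state

# Toy parameters standing in for model.params.
params = {"dense": {"kernel": jnp.ones((4, 4)), "bias": jnp.zeros((4,))}}

# Stand-in for linear_decay_lr_schedule_fn (the script builds warmup + linear decay).
schedule = optax.linear_schedule(init_value=1e-3, end_value=0.0, transition_steps=1_000)

use_adafactor = True  # corresponds to training_args.adafactor
if use_adafactor:
    # Adafactor keeps factored second-moment statistics, so its optimizer state
    # is considerably smaller than AdamW's full first/second moments.
    tx = optax.adafactor(learning_rate=schedule)
else:
    tx = optax.adamw(
        learning_rate=schedule,
        b1=0.9,  # the script reads these values from training_args
        b2=0.999,
        eps=1e-8,
        weight_decay=0.01,
        # the script also passes mask=decay_mask_fn so some parameters skip weight decay
    )

def apply_fn(params, x):
    # Placeholder for model.__call__ in the script.
    return x @ params["dense"]["kernel"] + params["dense"]["bias"]

state = train_state.TrainState.create(apply_fn=apply_fn, params=params, tx=tx)
print(state.step)  # 0; state.opt_state now holds the Adafactor (or AdamW) statistics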
@@ -648,7 +663,6 @@

         # Save metrics
         if has_tensorboard and jax.process_index() == 0:
-            cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
             write_eval_metric(summary_writer, eval_metrics, cur_step)

         if cur_step % training_args.save_steps == 0 and cur_step > 0:
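The removed line overwrote cur_step from the epoch index alone, so eval metrics could be written under a step that lags the global step the loop is actually at; with it gone, the cur_step maintained by the training loop is used consistently for both the eval metrics and the save_steps check below. A sketch of that bookkeeping, assuming the loop derives the global step as epoch * steps_per_epoch + step (the usual pattern in these example scripts):

num_train_samples = 10_000  # hypothetical sizes, for illustration only
train_batch_size = 64
steps_per_epoch = num_train_samples // train_batch_size

for epoch in range(3):
    for step in range(steps_per_epoch):
        cur_step = epoch * steps_per_epoch + step  # global step, shared by logging, eval and saving

# The removed line instead reset cur_step to epoch * steps_per_epoch, i.e. the step
# at which the current epoch began, so eval metrics were logged against a stale index.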