Fix script
Browse files
src/run.sh
CHANGED
|
@@ -27,7 +27,7 @@ export LOGGING_STEPS=500
|
|
| 27 |
export EVAL_STEPS=2500
|
| 28 |
export SAVE_STEPS=2500
|
| 29 |
|
| 30 |
-
python src/run_ed_recipe_nlg.py \
|
| 31 |
--output_dir="$OUTPUT_DIR" \
|
| 32 |
--train_file="$TRAIN_FILE" \
|
| 33 |
--validation_file="$VALIDATION_FILE" \
|
|
|
|
| 27 |
export EVAL_STEPS=2500
|
| 28 |
export SAVE_STEPS=2500
|
| 29 |
|
| 30 |
+
python src/run_recipe_nlg_flax.py \
|
| 31 |
--output_dir="$OUTPUT_DIR" \
|
| 32 |
--train_file="$TRAIN_FILE" \
|
| 33 |
--validation_file="$VALIDATION_FILE" \
|
src/{run_ed_recipe_nlg.py → run_recipe_nlg_flax.py}
RENAMED
|
@@ -779,7 +779,9 @@ def main():
|
|
| 779 |
# Save metrics
|
| 780 |
train_metric = unreplicate(train_metric)
|
| 781 |
train_time += time.time() - train_start
|
|
|
|
| 782 |
if has_tensorboard and jax.process_index() == 0:
|
|
|
|
| 783 |
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
| 784 |
|
| 785 |
epochs.write(
|
|
@@ -789,6 +791,7 @@ def main():
|
|
| 789 |
train_metrics = []
|
| 790 |
|
| 791 |
if cur_step % training_args.eval_steps == 0 and cur_step > 0 and training_args.do_eval:
|
|
|
|
| 792 |
eval_metrics = []
|
| 793 |
eval_preds = []
|
| 794 |
eval_labels = []
|
|
@@ -827,20 +830,27 @@ def main():
|
|
| 827 |
|
| 828 |
# Save metrics
|
| 829 |
if has_tensorboard and jax.process_index() == 0:
|
| 830 |
-
|
|
|
|
| 831 |
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
| 832 |
|
| 833 |
if cur_step % training_args.save_steps == 0 and cur_step > 0:
|
| 834 |
-
|
|
|
|
| 835 |
if jax.process_index() == 0:
|
| 836 |
-
|
| 837 |
-
params = jax.device_get(unreplicate(state.params))
|
| 838 |
model.save_pretrained(
|
| 839 |
training_args.output_dir,
|
| 840 |
params=params,
|
| 841 |
push_to_hub=training_args.push_to_hub,
|
| 842 |
commit_message=f"Saving weights and logs of step {cur_step}",
|
| 843 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 844 |
|
| 845 |
|
| 846 |
if __name__ == "__main__":
|
|
|
|
| 779 |
# Save metrics
|
| 780 |
train_metric = unreplicate(train_metric)
|
| 781 |
train_time += time.time() - train_start
|
| 782 |
+
|
| 783 |
if has_tensorboard and jax.process_index() == 0:
|
| 784 |
+
logger.info(f"*** Writing training summary after {cur_step} steps ***")
|
| 785 |
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
| 786 |
|
| 787 |
epochs.write(
|
|
|
|
| 791 |
train_metrics = []
|
| 792 |
|
| 793 |
if cur_step % training_args.eval_steps == 0 and cur_step > 0 and training_args.do_eval:
|
| 794 |
+
logger.info(f"*** Evaluation after {cur_step} steps ***")
|
| 795 |
eval_metrics = []
|
| 796 |
eval_preds = []
|
| 797 |
eval_labels = []
|
|
|
|
| 830 |
|
| 831 |
# Save metrics
|
| 832 |
if has_tensorboard and jax.process_index() == 0:
|
| 833 |
+
logger.info(f"*** Writing evaluation summary after {cur_step} steps ***")
|
| 834 |
+
# cur_step = epoch * (len(train_dataset) // train_batch_size)
|
| 835 |
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
| 836 |
|
| 837 |
if cur_step % training_args.save_steps == 0 and cur_step > 0:
|
| 838 |
+
logger.info(f"*** Saving checkpoints after {cur_step} steps ***")
|
| 839 |
+
# save a checkpoint every save_steps steps and push the checkpoint to the hub
|
| 840 |
if jax.process_index() == 0:
|
| 841 |
+
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
| 842 |
+
# params = jax.device_get(unreplicate(state.params))
|
| 843 |
model.save_pretrained(
|
| 844 |
training_args.output_dir,
|
| 845 |
params=params,
|
| 846 |
push_to_hub=training_args.push_to_hub,
|
| 847 |
commit_message=f"Saving weights and logs of step {cur_step}",
|
| 848 |
)
|
| 849 |
+
tokenizer.save_pretrained(
|
| 850 |
+
training_args.output_dir,
|
| 851 |
+
push_to_hub=training_args.push_to_hub,
|
| 852 |
+
commit_message=f"Saving tokenizer step {cur_step}",
|
| 853 |
+
)
|
| 854 |
|
| 855 |
|
| 856 |
if __name__ == "__main__":
|