From f871ea66a7f30ad6e087e69c359ddb4b0f9373cb Mon Sep 17 00:00:00 2001 From: Quentin18 Date: Sat, 2 Dec 2023 09:33:31 +0100 Subject: [PATCH] docs: add docstrings in train script --- scripts/train.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/scripts/train.py b/scripts/train.py index a6f1fa8..3984297 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -183,6 +183,12 @@ def evaluate( def log_eval_metrics(episode: int, metrics: dict[str, Any]) -> None: + """ + Logs the evaluation metrics for an episode. + + :param episode: Episode number. + :param metrics: Performance measures. + """ logger.info( "episode %d: winning rate = %.2f, mean score = %.2f, max tile = %d", episode, @@ -193,6 +199,12 @@ def log_eval_metrics(episode: int, metrics: dict[str, Any]) -> None: def save_best_policy(out_dir: str, policy: NTupleNetworkBasePolicy) -> None: + """ + Saves the best policy. + + :param out_dir: Path to output directory. + :param policy: Policy to save. + """ best_model_path = os.path.join(out_dir, "best_n_tuple_network_policy.zip") logger.info("new best model saved to %s", best_model_path) policy.save(path=best_model_path) @@ -203,6 +215,13 @@ def save_checkpoint( out_dir: str, policy: NTupleNetworkBasePolicy, ) -> None: + """ + Saves a checkpoint. + + :param episode: Episode number. + :param out_dir: Output directory. + :param policy: Policy to save. + """ checkpoint_path = os.path.join(out_dir, f"checkpoint_episode_{episode}.zip") logger.info("checkpoint saved to %s", checkpoint_path) policy.save(path=checkpoint_path)