Skip to content

Commit

Permalink
docs: add docstrings in train script
Browse files Browse the repository at this point in the history
  • Loading branch information
Quentin18 committed Dec 2, 2023
1 parent 0fb3e14 commit f871ea6
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,12 @@ def evaluate(


def log_eval_metrics(episode: int, metrics: dict[str, Any]) -> None:
"""
Logs the evaluation metrics for an episode.
:param episode: Episode number.
:param metrics: Performance measures.
"""
logger.info(
"episode %d: winning rate = %.2f, mean score = %.2f, max tile = %d",
episode,
Expand All @@ -193,6 +199,12 @@ def log_eval_metrics(episode: int, metrics: dict[str, Any]) -> None:


def save_best_policy(out_dir: str, policy: NTupleNetworkBasePolicy) -> None:
"""
Saves the best policy.
:param out_dir: Path to output directory.
:param policy: Policy to save.
"""
best_model_path = os.path.join(out_dir, "best_n_tuple_network_policy.zip")
logger.info("new best model saved to %s", best_model_path)
policy.save(path=best_model_path)
Expand All @@ -203,6 +215,13 @@ def save_checkpoint(
out_dir: str,
policy: NTupleNetworkBasePolicy,
) -> None:
"""
Saves a checkpoint.
:param episode: Episode number.
:param out_dir: Output directory.
:param policy: Policy to save.
"""
checkpoint_path = os.path.join(out_dir, f"checkpoint_episode_{episode}.zip")
logger.info("checkpoint saved to %s", checkpoint_path)
policy.save(path=checkpoint_path)
Expand Down

0 comments on commit f871ea6

Please sign in to comment.