From c6962ba04e4b490690f1e48cd03b5e08d0c83d80 Mon Sep 17 00:00:00 2001 From: Quentin18 Date: Mon, 20 Nov 2023 08:54:57 +0100 Subject: [PATCH] update evaluate progress bar --- scripts/train.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/scripts/train.py b/scripts/train.py index ad128d1..a6f1fa8 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -148,7 +148,6 @@ def evaluate( env: gym.Env, policy: NTupleNetworkBasePolicy, eval_episodes: int, - show_progress: bool = False, ) -> dict[str, Any]: """ Evaluates the performance of the current policy. @@ -162,24 +161,19 @@ def evaluate( :param env: Game environment. :param policy: Policy to evaluate. :param eval_episodes: Number of games to play. - :param show_progress: True to show progress, False otherwise. :return: Performance measures. """ winning_rate = 0 total_score = 0 max_tile = 0 - iterator = ( - trange(eval_episodes, desc="Evaluate", unit="episode", leave=False) - if show_progress - else range(eval_episodes) - ) - - for _ in iterator: - info = play_game(env=env, policy=policy) - winning_rate += int(2 ** info["max"] >= 2048) - total_score += info["total_score"] - max_tile = max(max_tile, 2 ** info["max"]) + with trange(eval_episodes, desc="Evaluate", unit="episode", leave=False) as pbar: + for _ in pbar: + info = play_game(env=env, policy=policy) + winning_rate += int(2 ** info["max"] >= 2048) + total_score += info["total_score"] + max_tile = max(max_tile, 2 ** info["max"]) + pbar.set_postfix({"max_tile": max_tile}) return { "winning_rate": winning_rate / eval_episodes, @@ -239,7 +233,6 @@ def train() -> None: env=env, policy=policy, eval_episodes=args.eval_episodes, - show_progress=True, ) log_eval_metrics(episode=e, metrics=metrics)