Skip to content

Commit

Permalink
update evaluate progress bar
Browse files (browse the repository at this point in the history)
  • Loading branch information
Quentin18 committed Nov 20, 2023
1 parent 2bdf3c7 commit c6962ba
Showing 1 changed file with 7 additions and 14 deletions.
21 changes: 7 additions & 14 deletions scripts/train.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -148,7 +148,6 @@ def evaluate(
env: gym.Env,
policy: NTupleNetworkBasePolicy,
eval_episodes: int,
show_progress: bool = False,
) -> dict[str, Any]:
"""
Evaluates the performance of the current policy.
Expand All @@ -162,24 +161,19 @@ def evaluate(
:param env: Game environment.
:param policy: Policy to evaluate.
:param eval_episodes: Number of games to play.
:param show_progress: True to show progress, False otherwise.
:return: Performance measures.
"""
winning_rate = 0
total_score = 0
max_tile = 0

iterator = (
trange(eval_episodes, desc="Evaluate", unit="episode", leave=False)
if show_progress
else range(eval_episodes)
)

for _ in iterator:
info = play_game(env=env, policy=policy)
winning_rate += int(2 ** info["max"] >= 2048)
total_score += info["total_score"]
max_tile = max(max_tile, 2 ** info["max"])
with trange(eval_episodes, desc="Evaluate", unit="episode", leave=False) as pbar:
for _ in pbar:
info = play_game(env=env, policy=policy)
winning_rate += int(2 ** info["max"] >= 2048)
total_score += info["total_score"]
max_tile = max(max_tile, 2 ** info["max"])
pbar.set_postfix({"max_tile": max_tile})

return {
"winning_rate": winning_rate / eval_episodes,
Expand Down Expand Up @@ -239,7 +233,6 @@ def train() -> None:
env=env,
policy=policy,
eval_episodes=args.eval_episodes,
show_progress=True,
)
log_eval_metrics(episode=e, metrics=metrics)

Expand Down

0 comments on commit c6962ba

Please sign in to comment.