diff --git a/scripts/evaluate.py b/scripts/evaluate.py
index 4d323cb..9b1654b 100644
--- a/scripts/evaluate.py
+++ b/scripts/evaluate.py
@@ -12,6 +12,8 @@
     NTupleNetworkTDPolicySmall,
 )
 
+plt.style.use("ggplot")
+
 
 def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser(
@@ -46,6 +48,16 @@ def parse_args() -> argparse.Namespace:
         default=42,
         help="random generator seed",
     )
+    parser.add_argument(
+        "-t",
+        "--title",
+        help="figure title",
+    )
+    parser.add_argument(
+        "-o",
+        "--output-path",
+        help="path to the output PNG file",
+    )
     args = parser.parse_args()
 
     return args
@@ -57,6 +69,13 @@ def make_env(env_id: str) -> gym.Env:
 
 
 def make_policy(algo: str, trained_agent: str) -> NTupleNetworkBasePolicy:
+    """
+    Makes the policy to evaluate.
+
+    :param algo: Name of the algorithm.
+    :param trained_agent: Path to a trained agent.
+    :return: Policy.
+    """
     algo_policy_map = {
         "ql": NTupleNetworkQLearningPolicy,
         "tdl": NTupleNetworkTDPolicy,
@@ -66,45 +85,66 @@ def make_policy(algo: str, trained_agent: str) -> NTupleNetworkBasePolicy:
     return policy.load(trained_agent)
 
 
-def evaluate() -> None:
-    args = parse_args()
-
-    np.random.seed(args.seed)
-    env = make_env(env_id=args.env)
-    if args.algo is not None and args.trained_agent is not None:
-        policy = make_policy(algo=args.algo, trained_agent=args.trained_agent)
-    else:
-        policy = None
-
+def run_episodes(
+    env: gym.Env,
+    policy: NTupleNetworkBasePolicy | None,
+    n_episodes: int,
+) -> tuple[list[int], list[int], list[int], list[int]]:
+    """
+    Runs episodes and records statistics.
+
+    :param env: Game environment.
+    :param policy: Policy, or None for a random policy.
+    :param n_episodes: Number of episodes.
+    :return: Episode lengths, rewards, max tiles and total scores.
+    """
     lengths = []
     rewards = []
     max_tiles = []
     total_score = []
 
-    # Run episodes
-    for _ in trange(args.n_episodes, desc="Episode"):
+    for _ in trange(n_episodes, desc="Episode", unit="episode"):
         _observation, info = env.reset()
         terminated = truncated = False
+
         while not terminated and not truncated:
             if policy is None:
                 action = env.action_space.sample()
             else:
                 state = info["board"]
                 action = policy.predict(state=state)
+
             _observation, _reward, terminated, truncated, info = env.step(action)
 
         lengths.extend(info["episode"]["l"])
         rewards.extend(info["episode"]["r"])
         max_tiles.append(info["max"])
         total_score.append(info["total_score"])
-        env.reset()
-
-    env.close()
 
-    # Plot results
-    plt.style.use("ggplot")
-
+        env.reset()
+
+    return lengths, rewards, max_tiles, total_score
+
+
+def plot_statistics(
+    lengths: list[int],
+    rewards: list[int],
+    max_tiles: list[int],
+    total_score: list[int],
+    title: str | None = None,
+) -> plt.Figure:
+    """
+    Plots episode statistics.
+
+    :param lengths: Episode lengths.
+    :param rewards: Episode rewards.
+    :param max_tiles: Maximum tiles reached.
+    :param total_score: Total game scores.
+    :param title: Figure title. Defaults to None.
+    :return: Figure with statistics.
+    """
     fig, axs = plt.subplots(2, 2)
+
     axs[0, 0].hist(lengths)
     axs[0, 0].set_xlabel("Length")
     axs[0, 0].set_ylabel("Count")
@@ -128,9 +168,40 @@ def evaluate() -> None:
     axs[1, 1].set_ylabel("Count")
     axs[1, 1].set_title("Score")
 
+    fig.suptitle(title)
     fig.tight_layout()
+
+    return fig
+
+
+def evaluate() -> None:
+    args = parse_args()
+
+    np.random.seed(args.seed)
+    env = make_env(env_id=args.env)
+    if args.algo is not None and args.trained_agent is not None:
+        policy = make_policy(algo=args.algo, trained_agent=args.trained_agent)
+    else:
+        policy = None
+
+    lengths, rewards, max_tiles, total_score = run_episodes(
+        env=env,
+        policy=policy,
+        n_episodes=args.n_episodes,
+    )
+    env.close()
+    fig = plot_statistics(
+        lengths=lengths,
+        rewards=rewards,
+        max_tiles=max_tiles,
+        total_score=total_score,
+        title=args.title,
+    )
 
     fig.show()
+    if args.output_path is not None:
+        fig.savefig(args.output_path)
+
 
 if __name__ == "__main__":
     evaluate()