diff --git a/scripts/random_policy.py b/scripts/random_policy.py index ffd8cd7..725a88c 100644 --- a/scripts/random_policy.py +++ b/scripts/random_policy.py @@ -1,15 +1,24 @@ import argparse import gymnasium as gym +from tqdm import trange def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Random policy") + parser = argparse.ArgumentParser( + description="Random policy", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--env", + default="gymnasium_2048:gymnasium_2048/TwentyFortyEight-v0", + help="environment id", + ) parser.add_argument( "--seed", type=int, default=42, - help="seed", + help="random generator seed", ) parser.add_argument( "--n-timesteps", @@ -21,16 +30,14 @@ def parse_args() -> argparse.Namespace: return args -def main() -> None: +def random_policy() -> None: args = parse_args() - env = gym.make( - "gymnasium_2048:gymnasium_2048/TwentyFortyEight-v0", - render_mode="human", - ) + + env = gym.make(args.env, render_mode="human") env.reset(seed=args.seed) - for _ in range(args.n_timesteps): + for _ in trange(args.n_timesteps, desc="Random policy"): action = env.action_space.sample() _, _, terminated, truncated, _ = env.step(action) @@ -41,4 +48,4 @@ def main() -> None: if __name__ == "__main__": - main() + random_policy()