From 81f967016a71b1dbbedce8bad403832a639b9f27 Mon Sep 17 00:00:00 2001
From: Quentin18
Date: Sat, 2 Dec 2023 09:24:47 +0100
Subject: [PATCH] feat: add expectimax search

---
 scripts/enjoy.py                           |  2 +
 scripts/evaluate.py                        |  2 +
 src/gymnasium_2048/agents/ntuple/search.py | 68 ++++++++++++++++++++++
 3 files changed, 72 insertions(+)
 create mode 100644 src/gymnasium_2048/agents/ntuple/search.py

diff --git a/scripts/enjoy.py b/scripts/enjoy.py
index 064e9c7..13ce7e0 100644
--- a/scripts/enjoy.py
+++ b/scripts/enjoy.py
@@ -11,6 +11,7 @@
     NTupleNetworkTDPolicy,
     NTupleNetworkTDPolicySmall,
 )
+from gymnasium_2048.agents.ntuple.search import ExpectimaxSearch
 
 
 def parse_args() -> argparse.Namespace:
@@ -117,6 +118,7 @@ def enjoy() -> None:
 
     env = gym.make(args.env, render_mode="human")
     policy = make_policy(algo=args.algo, trained_agent=args.trained_agent)
+    policy = ExpectimaxSearch(policy=policy)
 
     for _ in trange(args.n_episodes, desc="Enjoy"):
         play_game(env=env, policy=policy)
diff --git a/scripts/evaluate.py b/scripts/evaluate.py
index 9b1654b..5f69940 100644
--- a/scripts/evaluate.py
+++ b/scripts/evaluate.py
@@ -11,6 +11,7 @@
     NTupleNetworkTDPolicy,
     NTupleNetworkTDPolicySmall,
 )
+from gymnasium_2048.agents.ntuple.search import ExpectimaxSearch
 
 plt.style.use("ggplot")
 
@@ -181,6 +182,7 @@ def evaluate() -> None:
     env = make_env(env_id=args.env)
     if args.algo is not None and args.trained_agent is not None:
         policy = make_policy(algo=args.algo, trained_agent=args.trained_agent)
+        policy = ExpectimaxSearch(policy=policy)
     else:
         policy = None
 
diff --git a/src/gymnasium_2048/agents/ntuple/search.py b/src/gymnasium_2048/agents/ntuple/search.py
new file mode 100644
index 0000000..05f21fd
--- /dev/null
+++ b/src/gymnasium_2048/agents/ntuple/search.py
@@ -0,0 +1,68 @@
+import numpy as np
+
+from gymnasium_2048.agents.ntuple.policy import NTupleNetworkBasePolicy
+from gymnasium_2048.envs import TwentyFortyEightEnv
+
+
+class ExpectimaxSearch:
+    def __init__(
+        self,
+        policy: NTupleNetworkBasePolicy,
+        max_depth: int = 3,
+    ) -> None:
+        self.policy = policy
+        self.max_depth = max_depth
+        self.min_value = 0.0
+
+    def _evaluate(self, state: np.ndarray) -> tuple[float, int]:
+        values = [
+            self.policy.evaluate(state=state, action=action) for action in range(4)
+        ]
+        max_action = np.argmax(values)
+        return max(self.min_value, values[max_action]), max_action
+
+    def _maximize(self, state: np.ndarray, depth: int) -> tuple[float, int]:
+        if depth >= self.max_depth:
+            return self._evaluate(state=state)
+
+        max_value = self.min_value
+        max_action = 0
+
+        for action in range(4):
+            after_state, _, is_legal = TwentyFortyEightEnv.apply_action(
+                board=state,
+                action=action,
+            )
+            if not is_legal:
+                continue
+
+            value = self._chance(after_state=after_state, depth=depth + 1)
+            if value > max_value:
+                max_value = value
+                max_action = action
+
+        return max_value, max_action
+
+    def _chance(self, after_state: np.ndarray, depth: int) -> float:
+        if depth >= self.max_depth:
+            return self._evaluate(state=after_state)[0]
+
+        values, weights = [], []
+
+        for row in range(after_state.shape[0]):
+            for col in range(after_state.shape[1]):
+                if after_state[row, col] != 0:
+                    continue
+
+                for value, prob in ((1, 0.9), (2, 0.1)):
+                    after_state[row, col] = value
+                    values.append(self._maximize(state=after_state, depth=depth + 1)[0])
+                    weights.append(prob)
+                    after_state[row, col] = 0
+
+        return np.average(values, weights=weights)
+
+    def predict(self, state: np.ndarray) -> int:
+        value, action = self._maximize(state=state, depth=0)
+        print(value, action)
+        return action
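
Usage sketch (illustrative, not part of the patch): the wrapper only relies on a policy object exposing evaluate(state, action) and on the static TwentyFortyEightEnv.apply_action helper used above to expand max nodes, so it can be driven outside the scripts as below. The DummyPolicy stand-in (which scores a move by the empty cells it leaves) and the sample board are placeholders; in practice the policy comes from make_policy as in scripts/enjoy.py, and the board appears to store tile exponents, as the (1, 0.9)/(2, 0.1) spawn values in _chance suggest.

# Illustrative sketch only: DummyPolicy and the sample board are placeholders,
# not part of the patch; a trained n-tuple network policy would be used instead.
import numpy as np

from gymnasium_2048.agents.ntuple.search import ExpectimaxSearch
from gymnasium_2048.envs import TwentyFortyEightEnv


class DummyPolicy:
    """Stand-in policy: scores a move by the number of empty cells it leaves."""

    def evaluate(self, state: np.ndarray, action: int) -> float:
        after_state, _, is_legal = TwentyFortyEightEnv.apply_action(
            board=state,
            action=action,
        )
        return float(np.count_nonzero(after_state == 0)) if is_legal else 0.0


# Exponent-coded board: 0 is empty, 1 is the 2 tile, 2 is the 4 tile, ...
board = np.array(
    [
        [1, 1, 2, 0],
        [0, 2, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
    ]
)

search = ExpectimaxSearch(policy=DummyPolicy(), max_depth=3)
action = search.predict(state=board)  # one of the four move directions, 0..3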