Skip to content

Commit

Permalink
fix load tdl-small policy
Browse files Browse the repository at this point in the history
  • Loading branch information
Quentin18 committed Oct 24, 2023
1 parent 9eb4cf0 commit d09c78c
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/gymnasium_2048/agents/ntuple/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,3 +249,9 @@ def _get_tuples(state: np.ndarray) -> Sequence[Sequence[int]]:
*get_all_straight_3_tuples(state=state),
*get_all_corners_3_tuples(state=state),
]

@classmethod
def load(cls, path: str | bytes | os.PathLike) -> NTupleNetworkBasePolicy:
policy = NTupleNetworkTDPolicySmall()
policy.net = NTupleNetwork.load(path=path)
return policy

0 comments on commit d09c78c

Please sign in to comment.