Skip to content

Commit

Permalink
style: black format
Browse files Browse the repository at this point in the history
  • Loading branch information
Quentin18 committed Jan 26, 2024
1 parent 582650c commit f99dc94
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 14 deletions.
14 changes: 6 additions & 8 deletions src/gymnasium_2048/agents/ntuple/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,12 @@ def get_all_rectangles_tuples(state: np.ndarray) -> Sequence[Sequence[int]]:
# square
for row in range(state.shape[0] - 1):
for col in range(state.shape[1] - 1):
tuples.append(
(
state[row, col],
state[row, col + 1],
state[row + 1, col + 1],
state[row + 1, col],
)
)
tuples.append((
state[row, col],
state[row, col + 1],
state[row + 1, col + 1],
state[row + 1, col],
))

return tuples

Expand Down
12 changes: 6 additions & 6 deletions src/gymnasium_2048/agents/ntuple/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ def predict(self, state: np.ndarray) -> int:
:param state: The board state.
:return: Next action to play.
"""
return np.argmax([
self.evaluate(state=state, action=action) for action in range(4)
])
return np.argmax(
[self.evaluate(state=state, action=action) for action in range(4)]
)

@abstractmethod
def save(self, path: str | pathlib.Path | io.BufferedIOBase) -> None:
Expand Down Expand Up @@ -207,9 +207,9 @@ def learn(
after_state_tuples = self._get_tuples(state=after_state)
after_state_value = self.net.predict(tuples=after_state_tuples)

next_action = np.argmax([
self.evaluate(state=next_state, action=a) for a in range(4)
])
next_action = np.argmax(
[self.evaluate(state=next_state, action=a) for a in range(4)]
)
next_after_state, next_reward, is_legal = TwentyFortyEightEnv.apply_action(
board=next_state,
action=next_action,
Expand Down

0 comments on commit f99dc94

Please sign in to comment.