Skip to content

Commit

Permalink
tidy
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk committed Dec 2, 2024
1 parent d644813 commit 85af893
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions pgx/_src/games/hex.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,11 @@ def init(self) -> GameState:
return GameState()

def step(self, state: GameState, action: Array) -> GameState:
x = jax.lax.cond(
return jax.lax.cond(
action != self.size * self.size,
lambda: partial(_step, size=self.size)(state, action),
lambda: partial(_swap, size=self.size)(state),
)
terminated = _is_terminal(x, self.size)
return x._replace(terminated=terminated)

def observe(self, state: GameState, color: Optional[Array] = None) -> Array:
return _observe(state, color, self.size)
Expand Down Expand Up @@ -98,11 +96,15 @@ def merge(i, b):
)

board = jax.lax.fori_loop(0, 6, merge, board)
return state._replace(

state = state._replace(
step_count=state.step_count + 1,
board=board * -1,
)

terminated = _is_terminal(state, size)
return state._replace(terminated=terminated)


def _swap(state: GameState, size: int) -> GameState:
ix = jnp.nonzero(state.board, size=1)[0]
Expand Down

0 comments on commit 85af893

Please sign in to comment.