Skip to content

Commit

Permalink
enhance is_terminal by using last action
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk committed Dec 2, 2024
1 parent 85af893 commit 4d9ca17
Showing 1 changed file with 4 additions and 9 deletions.
13 changes: 4 additions & 9 deletions pgx/_src/games/hex.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,14 @@ def is_terminal(self, state: GameState) -> Array:
# ...


def _is_terminal(state: GameState, size: int) -> Array:
def _is_terminal(state: GameState, action: Array, size: int) -> Array:
top, bottom = jax.lax.cond(
state.color == 0,
lambda: (state.board[::size], state.board[size - 1 :: size]),
lambda: (state.board[:size], state.board[-size:]),
)

def check_same_id_exist(_id):
return (_id < 0) & (_id == bottom).any()

return jax.vmap(check_same_id_exist)(top).any()
target_id = state.board[action] # target_id != 0
return (top == target_id).any() & (bottom == target_id).any()


def _step(state: GameState, action: Array, size: int) -> GameState:
Expand All @@ -101,9 +98,7 @@ def merge(i, b):
step_count=state.step_count + 1,
board=board * -1,
)

terminated = _is_terminal(state, size)
return state._replace(terminated=terminated)
return state._replace(terminated=_is_terminal(state, action, size))


def _swap(state: GameState, size: int) -> GameState:
Expand Down

0 comments on commit 4d9ca17

Please sign in to comment.