diff --git a/pgx/othello.py b/pgx/othello.py index 449e5633f..bc7183ca7 100644 --- a/pgx/othello.py +++ b/pgx/othello.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import NamedTuple + import jax import jax.numpy as jnp @@ -23,17 +25,8 @@ TRUE = jnp.bool_(True) -@dataclass -class State(core.State): - current_player: Array = jnp.int32(0) - observation: Array = jnp.zeros((8, 8, 2), dtype=jnp.bool_) - rewards: Array = jnp.float32([0.0, 0.0]) - terminated: Array = FALSE - truncated: Array = FALSE - legal_action_mask: Array = jnp.ones(64 + 1, dtype=jnp.bool_) - _step_count: Array = jnp.int32(0) - # --- Othello specific --- - _turn: Array = jnp.int32(0) +class GameState(NamedTuple): + turn: Array = jnp.int32(0) # 8x8 board # [[ 0, 1, 2, 3, 4, 5, 6, 7], # [ 8, 9, 10, 11, 12, 13, 14, 15], @@ -43,8 +36,20 @@ class State(core.State): # [40, 41, 42, 43, 44, 45, 46, 47], # [48, 49, 50, 51, 52, 53, 54, 55], # [56, 57, 58, 59, 60, 61, 62, 63]] - _board: Array = jnp.zeros(64, jnp.int32) # -1(opp), 0(empty), 1(self) - _passed: Array = FALSE + board: Array = jnp.zeros(64, jnp.int32) + passed: Array = jnp.bool_(False) + + +@dataclass +class State(core.State): + current_player: Array = jnp.int32(0) + observation: Array = jnp.zeros((8, 8, 2), dtype=jnp.bool_) + rewards: Array = jnp.float32([0.0, 0.0]) + terminated: Array = FALSE + truncated: Array = FALSE + legal_action_mask: Array = jnp.ones(64 + 1, dtype=jnp.bool_) + _step_count: Array = jnp.int32(0) + _x: GameState = GameState() @property def env_id(self) -> core.EnvId: @@ -107,7 +112,10 @@ def _init(rng: PRNGKey) -> State: current_player = jnp.int32(jax.random.bernoulli(rng)) return State( current_player=current_player, - _board=jnp.zeros(64, dtype=jnp.int32).at[28].set(1).at[35].set(1).at[27].set(-1).at[36].set(-1), + _x=GameState( + turn=0, + board=jnp.zeros(64, dtype=jnp.int32).at[28].set(1).at[35].set(1).at[27].set(-1).at[36].set(-1), + ), legal_action_mask=jnp.zeros(64 + 1, dtype=jnp.bool_) .at[19] .set(TRUE) @@ -121,7 +129,7 @@ def _init(rng: PRNGKey) -> State: def _step(state, action): - board = state._board + board = state._x.board my = board > 0 opp = board < 0 @@ -167,19 +175,21 @@ def _make_legal(i, legal): legal_action = jax.lax.fori_loop(0, 8, _make_legal, jnp.zeros(64, dtype=jnp.bool_)) reward, terminated = jax.lax.cond( - ((jnp.count_nonzero(my | opp) == 64) | ~opp.any() | (state._passed & (action == 64))), + ((jnp.count_nonzero(my | opp) == 64) | ~opp.any() | (state._x.passed & (action == 64))), lambda: (_get_reward(my, opp, state.current_player), TRUE), lambda: (jnp.zeros(2, jnp.float32), FALSE), ) return state.replace( current_player=1 - state.current_player, - _turn=1 - state._turn, + _x=GameState( + turn=1 - state._x.turn, + board=-jnp.where(jnp.int32(opp), -1, jnp.int32(my)), + passed=action == 64, + ), legal_action_mask=state.legal_action_mask.at[:64].set(legal_action).at[64].set(~legal_action.any()), rewards=reward, terminated=terminated, - _board=-jnp.where(jnp.int32(opp), -1, jnp.int32(my)), - _passed=action == 64, ) @@ -208,8 +218,8 @@ def _get_reward(my, opp, curr_player): def _observe(state, player_id) -> Array: board = jax.lax.cond( player_id == state.current_player, - lambda: state._board.reshape((8, 8)), - lambda: (state._board * -1).reshape((8, 8)), + lambda: state._x.board.reshape((8, 8)), + lambda: (state._x.board * -1).reshape((8, 8)), ) def make(color): @@ -219,4 +229,4 @@ def make(color): def _get_abs_board(state): - return jax.lax.cond(state._turn == 0, lambda: state._board, lambda: state._board * -1) + return jax.lax.cond(state._x.turn == 0, lambda: state._x.board, lambda: state._x.board * -1) diff --git a/tests/test_othello.py b/tests/test_othello.py index 1f7b69cde..cae8981e8 100644 --- a/tests/test_othello.py +++ b/tests/test_othello.py @@ -37,7 +37,7 @@ def test_step(): 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) # fmt:on - assert jnp.all(state._board == expected) + assert jnp.all(state._x.board == expected) def test_terminated():