Skip to content

Commit

Permalink
[Othello] Extract game specific attributes (#1296)
Browse files (browse the repository at this point in the history)
  • Loading branch information
sotetsuk authored Dec 6, 2024
1 parent 17ea895 commit 7932e97
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 23 deletions.
54 changes: 32 additions & 22 deletions pgx/othello.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import NamedTuple

import jax
import jax.numpy as jnp

Expand All @@ -23,17 +25,8 @@
TRUE = jnp.bool_(True)


@dataclass
class State(core.State):
current_player: Array = jnp.int32(0)
observation: Array = jnp.zeros((8, 8, 2), dtype=jnp.bool_)
rewards: Array = jnp.float32([0.0, 0.0])
terminated: Array = FALSE
truncated: Array = FALSE
legal_action_mask: Array = jnp.ones(64 + 1, dtype=jnp.bool_)
_step_count: Array = jnp.int32(0)
# --- Othello specific ---
_turn: Array = jnp.int32(0)
class GameState(NamedTuple):
turn: Array = jnp.int32(0)
# 8x8 board
# [[ 0, 1, 2, 3, 4, 5, 6, 7],
# [ 8, 9, 10, 11, 12, 13, 14, 15],
Expand All @@ -43,8 +36,20 @@ class State(core.State):
# [40, 41, 42, 43, 44, 45, 46, 47],
# [48, 49, 50, 51, 52, 53, 54, 55],
# [56, 57, 58, 59, 60, 61, 62, 63]]
_board: Array = jnp.zeros(64, jnp.int32) # -1(opp), 0(empty), 1(self)
_passed: Array = FALSE
board: Array = jnp.zeros(64, jnp.int32)
passed: Array = jnp.bool_(False)


@dataclass
class State(core.State):
current_player: Array = jnp.int32(0)
observation: Array = jnp.zeros((8, 8, 2), dtype=jnp.bool_)
rewards: Array = jnp.float32([0.0, 0.0])
terminated: Array = FALSE
truncated: Array = FALSE
legal_action_mask: Array = jnp.ones(64 + 1, dtype=jnp.bool_)
_step_count: Array = jnp.int32(0)
_x: GameState = GameState()

@property
def env_id(self) -> core.EnvId:
Expand Down Expand Up @@ -107,7 +112,10 @@ def _init(rng: PRNGKey) -> State:
current_player = jnp.int32(jax.random.bernoulli(rng))
return State(
current_player=current_player,
_board=jnp.zeros(64, dtype=jnp.int32).at[28].set(1).at[35].set(1).at[27].set(-1).at[36].set(-1),
_x=GameState(
turn=0,
board=jnp.zeros(64, dtype=jnp.int32).at[28].set(1).at[35].set(1).at[27].set(-1).at[36].set(-1),
),
legal_action_mask=jnp.zeros(64 + 1, dtype=jnp.bool_)
.at[19]
.set(TRUE)
Expand All @@ -121,7 +129,7 @@ def _init(rng: PRNGKey) -> State:


def _step(state, action):
board = state._board
board = state._x.board
my = board > 0
opp = board < 0

Expand Down Expand Up @@ -167,19 +175,21 @@ def _make_legal(i, legal):
legal_action = jax.lax.fori_loop(0, 8, _make_legal, jnp.zeros(64, dtype=jnp.bool_))

reward, terminated = jax.lax.cond(
((jnp.count_nonzero(my | opp) == 64) | ~opp.any() | (state._passed & (action == 64))),
((jnp.count_nonzero(my | opp) == 64) | ~opp.any() | (state._x.passed & (action == 64))),
lambda: (_get_reward(my, opp, state.current_player), TRUE),
lambda: (jnp.zeros(2, jnp.float32), FALSE),
)

return state.replace(
current_player=1 - state.current_player,
_turn=1 - state._turn,
_x=GameState(
turn=1 - state._x.turn,
board=-jnp.where(jnp.int32(opp), -1, jnp.int32(my)),
passed=action == 64,
),
legal_action_mask=state.legal_action_mask.at[:64].set(legal_action).at[64].set(~legal_action.any()),
rewards=reward,
terminated=terminated,
_board=-jnp.where(jnp.int32(opp), -1, jnp.int32(my)),
_passed=action == 64,
)


Expand Down Expand Up @@ -208,8 +218,8 @@ def _get_reward(my, opp, curr_player):
def _observe(state, player_id) -> Array:
board = jax.lax.cond(
player_id == state.current_player,
lambda: state._board.reshape((8, 8)),
lambda: (state._board * -1).reshape((8, 8)),
lambda: state._x.board.reshape((8, 8)),
lambda: (state._x.board * -1).reshape((8, 8)),
)

def make(color):
Expand All @@ -219,4 +229,4 @@ def make(color):


def _get_abs_board(state):
return jax.lax.cond(state._turn == 0, lambda: state._board, lambda: state._board * -1)
return jax.lax.cond(state._x.turn == 0, lambda: state._x.board, lambda: state._x.board * -1)
2 changes: 1 addition & 1 deletion tests/test_othello.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_step():
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0])
# fmt:on
assert jnp.all(state._board == expected)
assert jnp.all(state._x.board == expected)


def test_terminated():
Expand Down

0 comments on commit 7932e97

Please sign in to comment.