Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk committed Nov 2, 2023
1 parent d89aeeb commit 5104320
Show file tree
Hide file tree
Showing 9 changed files with 151 additions and 151 deletions.
4 changes: 2 additions & 2 deletions tests/experimental/test_full_mahjong.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_deck():


def test_hand():
hand = np.zeros(34, dtype=np.uint8)
hand = np.zeros(34, dtype=np.uint32)
red = np.full(3, False)
hand, red = Hand.add(hand, red, 0)
assert Hand.can_ron(hand, 0)
Expand Down Expand Up @@ -458,7 +458,7 @@ def score(
is_ron: bool = False
) -> int:
hand, red = Hand.from_str(hand_s)
dora = np.zeros(34, dtype=np.uint8)
dora = np.zeros(34, dtype=np.uint32)
melds = np.zeros(4, dtype=np.int32)
n_meld=0
for s in melds_s.split(","):
Expand Down
2 changes: 1 addition & 1 deletion tests/experimental/test_mini_mahjong.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_deck():


def test_hand():
hand = jnp.zeros(34, dtype=jnp.uint8)
hand = jnp.zeros(34, dtype=jnp.uint32)
hand = Hand.add(hand, 0)
assert Hand.can_ron(hand, 0)
assert not Hand.can_ron(hand, 1)
Expand Down
62 changes: 31 additions & 31 deletions tests/test_backgammon.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@


def make_test_boad():
board: jnp.ndarray = jnp.zeros(28, dtype=jnp.int8)
board: jnp.ndarray = jnp.zeros(28, dtype=jnp.int32)
# 黒
board = board.at[19].set(5)
board = board.at[20].set(1)
Expand Down Expand Up @@ -101,7 +101,7 @@ def make_test_state(

def test_flip_board():
test_board = make_test_boad()
board: jnp.ndarray = jnp.zeros(28, dtype=jnp.int8)
board: jnp.ndarray = jnp.zeros(28, dtype=jnp.int32)
board = board.at[4].set(-5)
board = board.at[3].set(-1)
board = board.at[2].set(-2)
Expand Down Expand Up @@ -135,10 +135,10 @@ def test_is_turn_end():
# white dance
board: jnp.ndarray = make_test_boad()
state = make_test_state(
current_player=jnp.int8(1),
current_player=jnp.int32(1),
rng=rng,
board=board,
turn=jnp.int8(1),
turn=jnp.int32(1),
dice=jnp.array([2, 2], dtype=jnp.int16),
playable_dice=jnp.array([-1, -1, -1, -1], dtype=jnp.int16),
played_dice_num=jnp.int16(0),
Expand All @@ -148,10 +148,10 @@ def test_is_turn_end():
# No playable dice
board: jnp.ndarray = make_test_boad()
state = make_test_state(
current_player=jnp.int8(1),
current_player=jnp.int32(1),
rng=rng,
board=board,
turn=jnp.int8(1),
turn=jnp.int32(1),
dice=jnp.array([2, 2], dtype=jnp.int16),
playable_dice=jnp.array([-1, -1, -1, -1], dtype=jnp.int16),
played_dice_num=jnp.int16(2),
Expand All @@ -166,7 +166,7 @@ def test_change_turn():
assert state._turn == (_turn + 1) % 2

test_board: jnp.ndarray = make_test_boad()
board: jnp.ndarray = jnp.zeros(28, dtype=jnp.int8)
board: jnp.ndarray = jnp.zeros(28, dtype=jnp.int32)
board = board.at[4].set(-5)
board = board.at[3].set(-1)
board = board.at[2].set(-2)
Expand All @@ -177,17 +177,17 @@ def test_change_turn():
board = board.at[1].set(3)
board = board.at[24].set(4)
state = make_test_state(
current_player=jnp.int8(0),
current_player=jnp.int32(0),
rng=rng,
board=test_board,
turn=jnp.int8(0),
turn=jnp.int32(0),
dice=jnp.array([2, 2], dtype=jnp.int16),
playable_dice=jnp.array([-1, -1, -1, -1], dtype=jnp.int16),
played_dice_num=jnp.int16(2),
)
state = _change_turn(state)
print(state._board, board)
assert state._turn == jnp.int8(1) # Turn changed
assert state._turn == jnp.int32(1) # Turn changed
assert (state._board == board).all() # Flipped.

def test_no_op():
Expand All @@ -196,17 +196,17 @@ def test_no_op():
board, jnp.array([0, 1, -1, -1], dtype=jnp.int16)
)
state = make_test_state(
current_player=jnp.int8(1),
current_player=jnp.int32(1),
rng=rng,
board=board,
turn=jnp.int8(1),
turn=jnp.int32(1),
dice=jnp.array([0, 1], dtype=jnp.int16),
playable_dice=jnp.array([0, 1, -1, -1], dtype=jnp.int16),
played_dice_num=jnp.int16(0),
legal_action_mask=legal_action_mask,
)
state = step(state, 0) # execute no-op action
assert state._turn == jnp.int8(0) # Turn changes after no-op.
assert state._turn == jnp.int32(0) # Turn changes after no-op.


def test_step():
Expand All @@ -217,10 +217,10 @@ def test_step():
board, jnp.array([0, 1, -1, -1], dtype=jnp.int16)
)
state = make_test_state(
current_player=jnp.int8(1),
current_player=jnp.int32(1),
rng=rng,
board=board,
turn=jnp.int8(1),
turn=jnp.int32(1),
dice=jnp.array([0, 1], dtype=jnp.int16),
playable_dice=jnp.array([0, 1, -1, -1], dtype=jnp.int16),
played_dice_num=jnp.int16(0),
Expand Down Expand Up @@ -270,10 +270,10 @@ def test_step():
board, jnp.array([4, 5, -1, -1], dtype=jnp.int16)
)
state = make_test_state(
current_player=jnp.int8(0),
current_player=jnp.int32(0),
rng=rng,
board=board,
turn=jnp.int8(0),
turn=jnp.int32(0),
dice=jnp.array([4, 5], dtype=jnp.int16),
playable_dice=jnp.array([4, 5, -1, -1], dtype=jnp.int16),
played_dice_num=jnp.int16(0),
Expand Down Expand Up @@ -301,61 +301,61 @@ def test_observe():

# current_player = white, playable_dice = (1, 2)
state = make_test_state(
current_player=jnp.int8(1),
current_player=jnp.int32(1),
rng=rng,
board=board,
turn=jnp.int8(1),
turn=jnp.int32(1),
dice=jnp.array([0, 1], dtype=jnp.int16),
playable_dice=jnp.array([0, 1, -1, -1], dtype=jnp.int16),
played_dice_num=jnp.int16(0),
)
expected_obs = jnp.concatenate(
(board, jnp.array([1, 1, 0, 0, 0, 0])), axis=None
)
assert (observe(state, jnp.int8(1)) == expected_obs).all()
assert (observe(state, jnp.int32(1)) == expected_obs).all()

state = make_test_state(
current_player=jnp.int8(1),
current_player=jnp.int32(1),
rng=rng,
board=board,
turn=jnp.int8(1),
turn=jnp.int32(1),
dice=jnp.array([0, 1], dtype=jnp.int16),
playable_dice=jnp.array([1, 1, 1, 1], dtype=jnp.int16),
played_dice_num=jnp.int16(0),
)
expected_obs = jnp.concatenate(
(board, jnp.array([0, 4, 0, 0, 0, 0])), axis=None
)
assert (observe(state, jnp.int8(1)) == expected_obs).all()
assert (observe(state, jnp.int32(1)) == expected_obs).all()

# current_player = black, playabl_dice = (2)
state = make_test_state(
current_player=jnp.int8(1),
current_player=jnp.int32(1),
rng=rng,
board=board,
turn=jnp.int8(-1),
turn=jnp.int32(-1),
dice=jnp.array([0, 1], dtype=jnp.int16),
playable_dice=jnp.array([-1, 1, -1, -1], dtype=jnp.int16),
played_dice_num=jnp.int16(0),
)
expected_obs = jnp.concatenate(
(board, jnp.array([0, 1, 0, 0, 0, 0])), axis=None
)
assert (observe(state, jnp.int8(1)) == expected_obs).all()
assert (observe(state, jnp.int32(1)) == expected_obs).all()

state = make_test_state(
current_player=jnp.int8(1),
current_player=jnp.int32(1),
rng=rng,
board=board,
turn=jnp.int8(-1),
turn=jnp.int32(-1),
dice=jnp.array([0, 1], dtype=jnp.int16),
playable_dice=jnp.array([-1, 1, -1, -1], dtype=jnp.int16),
played_dice_num=jnp.int16(0),
)
expected_obs = jnp.concatenate(
(1 * board, jnp.array([0, 0, 0, 0, 0, 0])), axis=None
)
assert (observe(state, jnp.int8(0)) == expected_obs).all()
assert (observe(state, jnp.int32(0)) == expected_obs).all()


def test_is_open():
Expand Down Expand Up @@ -397,7 +397,7 @@ def test_is_all_on_home_boad():

def test_rear_distance():
board = make_test_boad()
turn = jnp.int8(-1)
turn = jnp.int32(-1)
# Black
assert _rear_distance(board) == 5
# White
Expand All @@ -408,7 +408,7 @@ def test_rear_distance():
def test_distance_to_goal():
board = make_test_boad()
# Black
turn = jnp.int8(-1)
turn = jnp.int32(-1)
src = 23
assert _distance_to_goal(src) == 1
src = 10
Expand Down
Loading

0 comments on commit 5104320

Please sign in to comment.