diff --git a/bb.py b/bb.py
new file mode 100644
index 000000000..70a605a5a
--- /dev/null
+++ b/bb.py
@@ -0,0 +1,75 @@
+import jax
+from jax import lax
+import jax.numpy as jnp
+from pgx.experimental.chess import from_fen
+
+# Piece definitions
+EMPTY, PAWN, KNIGHT, BISHOP, ROOK, QUEEN, KING = tuple(range(7))  # pieces
+PIECE_TYPES = [EMPTY, PAWN, KNIGHT, BISHOP, ROOK, QUEEN, KING]
+
+# Board size
+BOARD_SIZE = 64
+
+# Shift amounts for the 4-bit-per-square encoding
+SHIFT_PIECE_TYPE = 0  # the piece type sits in the low bits of the nibble (shift of 0)
+SHIFT_COLOR = 3  # the color is the leftmost (highest) bit of the nibble (shift of 3)
+
+@jax.jit
+def to_bitboard(board):  # pack board[0..63] (signed piece codes) into 8 uint32 words, one 4-bit nibble per square
+    bitboard = jnp.zeros(8, dtype=jnp.uint32)  # indexed by rank; 8 nibbles (one per file) fill each 32-bit word
+
+    for idx in range(BOARD_SIZE):
+        piece = board[idx]
+        rank = idx % 8  # selects the word
+        file = idx // 8  # selects the nibble inside the word
+        color = piece < 0  # >=0: us, <0: opp
+        piece_type = jnp.abs(piece)
+        bit_value = (color << SHIFT_COLOR) | piece_type  # nibble: color flag in bit 3, piece type in bits 0-2
+        bit_value = bitboard[rank] | (bit_value << (4 * file))  # OR the nibble into its slot of the rank word
+        bitboard = bitboard.at[rank].set(bit_value)
+
+    return bitboard
+
+
+@jax.jit
+def to_board(bitboard):  # inverse of to_bitboard: decode positions 0..63 by mapping get_bb over them
+    return jax.vmap(get_bb, in_axes=(None, 0))(bitboard, jnp.arange(64))
+
+
+def get_bb(bb, pos):  # read the signed piece code stored for square `pos`
+    rank, file = pos % 8, pos // 8
+    rank_bb = bb[rank]
+    bits = (rank_bb >> (4 * file)) & 0b1111  # isolate this square's nibble
+    color = (bits >> SHIFT_COLOR) & 1
+    piece_type = bits & 0b111
+    return jnp.int32([1, -1])[color] * piece_type  # color bit 0 -> +piece_type, 1 -> -piece_type
+
+# --- Test code ---
+# 8  7 15 23 31 39 47 55 63
+# 7  6 14 22 30 38 46 54 62
+# 6  5 13 21 29 37 45 53 61
+# 5  4 12 20 28 36 44 52 60
+# 4  3 11 19 27 35 43 51 59
+# 3  2 10 18 26 34 42 50 58
+# 2  1  9 17 25 33 41 49 57
+# 1  0  8 16 24 32 40 48 56
+#    a  b  c  d  e  f  g  h
+INIT_BOARD = jnp.int32([4, 1, 0, 0, 0, 0, -1, -4, 2, 1, 0, 0, 0, 0, -1, -2, 3, 1, 0, 0, 0, 0, -1, -3, 5, 1, 0, 0, 0, 0, -1, -5, 6, 1, 0, 0, 0, 0, -1, -6, 3, 1, 0, 0, 0, 0, -1, -3, 2, 1, 0, 0, 0, 0, -1, -2, 4, 1, 0, 0, 0, 0, -1, -4])  # fmt: skip
+
+
+# Board used for the test
+board = INIT_BOARD
+
+# Round-trip test: board -> bitboard -> board
+bitboard = to_bitboard(board)
+reconstructed_board = to_board(bitboard)
+
+print("Original board:")
+print(board.reshape(8, 8))
+print("\nBitboard:") +print(bitboard) +print("\nReconstructed board:") +print(reconstructed_board.reshape(8, 8)) + +# 再構築したboardが元のboardと同じかどうかを確認 +assert jnp.array_equal(board, reconstructed_board) diff --git a/pgx/_src/dwg/chess.py b/pgx/_src/dwg/chess.py index 860955637..4e7788399 100644 --- a/pgx/_src/dwg/chess.py +++ b/pgx/_src/dwg/chess.py @@ -3,6 +3,7 @@ from pgx.chess import State as ChessState from pgx.chess import _flip +from pgx._src.games.chess import to_board def _make_chess_dwg(dwg, state: ChessState, config): @@ -136,7 +137,7 @@ def _set_piece(_x, _y, _type, _dwg, _dwg_g, grid_size): # pieces pieces_g = dwg.g() for i in range(64): - pi = int(state._x.board[i].item()) + pi = int(to_board(state._x.bb)[i].item()) if pi == 0: continue if pi < 0: diff --git a/pgx/_src/games/chess.py b/pgx/_src/games/chess.py index e16df4bf3..c94b5b62e 100644 --- a/pgx/_src/games/chess.py +++ b/pgx/_src/games/chess.py @@ -22,6 +22,65 @@ EMPTY, PAWN, KNIGHT, BISHOP, ROOK, QUEEN, KING = tuple(range(7)) # opponent: -1 * piece MAX_TERMINATION_STEPS = 512 # from AlphaZero paper + +PIECE_TYPES = [EMPTY, PAWN, KNIGHT, BISHOP, ROOK, QUEEN, KING] + +# ボードのサイズ +BOARD_SIZE = 64 + +# 4bitの情報に対応するためのシフト量 +SHIFT_PIECE_TYPE = 0 # piece typeを表すのは0ビット右にずれる +SHIFT_COLOR = 3 # colorは左端のビット (3ビット分ずれる) + + +@jax.jit +def to_bitboard(board): + bitboard = jnp.zeros(8, dtype=jnp.uint32) + + for idx in range(BOARD_SIZE): + piece = board[idx] + rank = idx % 8 + file = idx // 8 + color = piece < 0 # >=0: us, <0: opp + piece_type = jnp.abs(piece) + bit_value = (color << SHIFT_COLOR) | piece_type + bit_value = bitboard[rank] | (bit_value << (4 * file)) + bitboard = bitboard.at[rank].set(bit_value) + + return bitboard + + +@jax.jit +def to_board(bitboard): + return jax.vmap(get_bb, in_axes=(None, 0))(bitboard, jnp.arange(64)) + + +def get_bb(bb, pos): + bb = bb.astype(jnp.uint32) + rank, file = pos % 8, pos // 8 + rank_bb = bb[rank] + bits = (rank_bb >> (4 * file)) & 0b1111 + color = (bits >> 
SHIFT_COLOR) & 1 + piece_type = bits & 0b111 + return jnp.int32([1, -1])[color] * piece_type + + +def set_bb(bb, ix, piece): + bb = bb.astype(jnp.uint32) + rank, file = ix % 8, jnp.uint32(ix // 8) + rank_bb = bb[rank] + color = piece < 0 # >=0: us, <0: opp + piece_type = jnp.abs(piece) + bits = (color << SHIFT_COLOR) | piece_type + rank_bb = rank_bb & ~(0b1111 << (4 * file)) # clear the bits + rank_bb = rank_bb | (bits << (4 * file)) + return bb.at[rank].set(rank_bb) + + +def flip_bb(bb): + return bb ^ jnp.uint32(0b10001000100010001000100010001000) + + # prepare precomputed values here (e.g., available moves, map to label, etc.) # index: a1: 0, a2: 1, ..., h8: 63 @@ -136,10 +195,12 @@ ZOBRIST_EN_PASSANT = jax.random.randint(keys[3], shape=(65, 2), minval=0, maxval=2**31 - 1, dtype=jnp.uint32) INIT_ZOBRIST_HASH = jnp.uint32([1455170221, 1478960862]) +INIT_BB = to_bitboard(INIT_BOARD) + class GameState(NamedTuple): color: Array = jnp.int32(0) # w: 0, b: 1 - board: Array = INIT_BOARD # (64,) + bb: Array = INIT_BB castling_rights: Array = jnp.ones([2, 2], dtype=jnp.bool_) # my queen, my king, opp queen, opp king en_passant: Array = jnp.int32(-1) halfmove_count: Array = jnp.int32(0) # number of moves since the last piece capture or pawn move @@ -231,7 +292,7 @@ def rewards(self, state: GameState) -> Array: def _update_history(state: GameState): board_history = jnp.roll(state.board_history, 64) - board_history = board_history.at[0].set(state.board) + board_history = board_history.at[0].set(to_board(state.bb)) hash_hist = jnp.roll(state.hash_history, 2) hash_hist = hash_hist.at[0].set(_zobrist_hash(state)) return state._replace(board_history=board_history, hash_history=hash_hist) @@ -239,12 +300,12 @@ def _update_history(state: GameState): def has_insufficient_pieces(state: GameState): # uses the same condition as OpenSpiel - num_pieces = (state.board != EMPTY).sum() - num_pawn_rook_queen = ((jnp.abs(state.board) >= ROOK) | (jnp.abs(state.board) == PAWN)).sum() - 2 # two 
kings - num_bishop = (jnp.abs(state.board) == BISHOP).sum() + num_pieces = (to_board(state.bb) != EMPTY).sum() + num_pawn_rook_queen = ((jnp.abs(to_board(state.bb)) >= ROOK) | (jnp.abs(to_board(state.bb)) == PAWN)).sum() - 2 # two kings + num_bishop = (jnp.abs(to_board(state.bb)) == BISHOP).sum() coords = jnp.arange(64).reshape((8, 8)) black_coords = jnp.hstack((coords[::2, ::2].ravel(), coords[1::2, 1::2].ravel())) - num_bishop_on_black = (jnp.abs(state.board[black_coords]) == BISHOP).sum() + num_bishop_on_black = (jnp.abs(to_board(state.bb)[black_coords]) == BISHOP).sum() is_insufficient = False # king vs king is_insufficient |= num_pieces <= 2 @@ -259,28 +320,26 @@ def has_insufficient_pieces(state: GameState): def _apply_move(state: GameState, a: Action) -> GameState: - piece = state.board[a.from_] + piece = get_bb(state.bb, a.from_) # en passant is_en_passant = (state.en_passant >= 0) & (piece == PAWN) & (state.en_passant == a.to) removed_pawn_pos = a.to - 1 - state = state._replace( - board=state.board.at[removed_pawn_pos].set(lax.select(is_en_passant, EMPTY, state.board[removed_pawn_pos])) - ) + new_piece = lax.select(is_en_passant, EMPTY, get_bb(state.bb, removed_pawn_pos)) + state = state._replace(bb=set_bb(state.bb, removed_pawn_pos, new_piece)) is_en_passant = (piece == PAWN) & (jnp.abs(a.to - a.from_) == 2) state = state._replace(en_passant=lax.select(is_en_passant, (a.to + a.from_) // 2, -1)) # update counters - captured = (state.board[a.to] < 0) | is_en_passant + captured = (get_bb(state.bb, a.to) < 0) | is_en_passant state = state._replace( halfmove_count=lax.select(captured | (piece == PAWN), 0, state.halfmove_count + 1), fullmove_count=state.fullmove_count + jnp.int32(state.color == 1), ) # castling - board = state.board is_queen_side_castling = (piece == KING) & (a.from_ == 32) & (a.to == 16) - board = lax.select(is_queen_side_castling, board.at[0].set(EMPTY).at[24].set(ROOK), board) + bb = lax.select(is_queen_side_castling, 
set_bb(set_bb(state.bb, 0, EMPTY), 24, ROOK), state.bb) is_king_side_castling = (piece == KING) & (a.from_ == 32) & (a.to == 48) - board = lax.select(is_king_side_castling, board.at[56].set(EMPTY).at[40].set(ROOK), board) - state = state._replace(board=board) + bb = lax.select(is_king_side_castling, set_bb(set_bb(bb, 56, EMPTY), 40, ROOK), bb) + state = state._replace(bb=bb) # update castling rights cond = jnp.bool_([[(a.from_ != 32) & (a.from_ != 0), (a.from_ != 32) & (a.from_ != 56)], [a.to != 7, a.to != 63]]) state = state._replace(castling_rights=state.castling_rights & cond) @@ -289,7 +348,7 @@ def _apply_move(state: GameState, a: Action) -> GameState: # underpromotion piece = lax.select(a.underpromotion < 0, piece, jnp.int32([ROOK, BISHOP, KNIGHT])[a.underpromotion]) # actually move - state = state._replace(board=state.board.at[a.from_].set(EMPTY).at[a.to].set(piece)) # type: ignore + state = state._replace(bb=set_bb(set_bb(state.bb, a.from_, EMPTY), a.to, piece)) return state @@ -299,7 +358,7 @@ def _flip_pos(x: Array): # e.g., 37 <-> 34, -1 <-> -1 def _flip(state: GameState) -> GameState: return state._replace( - board=-jnp.flip(state.board.reshape(8, 8), axis=1).flatten(), + bb=flip_bb(state.bb)[::-1], color=(state.color + 1) % 2, en_passant=_flip_pos(state.en_passant), castling_rights=state.castling_rights[::-1], @@ -309,14 +368,14 @@ def _flip(state: GameState) -> GameState: def _legal_action_mask(state: GameState) -> Array: def legal_normal_moves(from_): - piece = state.board[from_] + piece = get_bb(state.bb, from_) def legal_label(to): - ok = (from_ >= 0) & (piece > 0) & (to >= 0) & (state.board[to] <= 0) + ok = (from_ >= 0) & (piece > 0) & (to >= 0) & (get_bb(state.bb, to) <= 0) between_ixs = BETWEEN[from_, to] - ok &= CAN_MOVE[piece, from_, to] & ((between_ixs < 0) | (state.board[between_ixs] == EMPTY)).all() + ok &= CAN_MOVE[piece, from_, to] & ((between_ixs < 0) | (jax.vmap(get_bb, in_axes=(None, 0))(state.bb, between_ixs) == EMPTY)).all() c0, c1 = 
from_ // 8, to // 8 - pawn_should = ((c1 == c0) & (state.board[to] == EMPTY)) | ((c1 != c0) & (state.board[to] < 0)) + pawn_should = ((c1 == c0) & (get_bb(state.bb, to) == EMPTY)) | ((c1 != c0) & (get_bb(state.bb, to) < 0)) ok &= (piece != PAWN) | pawn_should return lax.select(ok, Action(from_=from_, to=to)._to_label(), -1) @@ -326,7 +385,7 @@ def legal_en_passants(): to = state.en_passant def legal_labels(from_): - ok = (from_ >= 0) & (from_ < 64) & (to >= 0) & (state.board[from_] == PAWN) & (state.board[to - 1] == -PAWN) + ok = (from_ >= 0) & (from_ < 64) & (to >= 0) & (get_bb(state.bb, from_) == PAWN) & (get_bb(state.bb, to - 1) == -PAWN) a = Action(from_=from_, to=to) return lax.select(ok, a._to_label(), -1) @@ -339,7 +398,7 @@ def is_not_checked(label): def legal_underpromotions(mask): def legal_labels(label): a = Action._from_label(label) - ok = (state.board[a.from_] == PAWN) & (a.to >= 0) + ok = (get_bb(state.bb, a.from_) == PAWN) & (a.to >= 0) ok &= mask[Action(from_=a.from_, to=a.to)._to_label()] return lax.select(ok, label, -1) @@ -347,7 +406,7 @@ def legal_labels(label): return jax.vmap(legal_labels)(labels) # normal move and en passant - possible_piece_positions = jnp.nonzero(state.board > 0, size=16, fill_value=-1)[0] + possible_piece_positions = jnp.nonzero(to_board(state.bb) > 0, size=16, fill_value=-1)[0] a1 = jax.vmap(legal_normal_moves)(possible_piece_positions).flatten() a2 = legal_en_passants() actions = jnp.hstack((a1, a2)) # include -1 @@ -356,11 +415,11 @@ def legal_labels(label): mask = mask.at[actions].set(True) # castling - b = state.board + bb = state.bb can_castle_queen_side = state.castling_rights[0, 0] - can_castle_queen_side &= (b[0] == ROOK) & (b[8] == EMPTY) & (b[16] == EMPTY) & (b[24] == EMPTY) & (b[32] == KING) + can_castle_queen_side &= (get_bb(bb, 0) == ROOK) & (get_bb(bb, 8) == EMPTY) & (get_bb(bb, 16) == EMPTY) & (get_bb(bb, 24) == EMPTY) & (get_bb(bb, 32) == KING) can_castle_king_side = state.castling_rights[0, 1] - 
can_castle_king_side &= (b[32] == KING) & (b[40] == EMPTY) & (b[48] == EMPTY) & (b[56] == ROOK) + can_castle_king_side &= (get_bb(bb, 32) == KING) & (get_bb(bb, 40) == EMPTY) & (get_bb(bb, 48) == EMPTY) & (get_bb(bb, 56) == ROOK) not_checked = ~jax.vmap(_is_attacked, in_axes=(None, 0))(state, jnp.int32([16, 24, 32, 40, 48])) mask = mask.at[2364].set(mask[2364] | (can_castle_queen_side & not_checked[:3].all())) mask = mask.at[2367].set(mask[2367] | (can_castle_king_side & not_checked[2:].all())) @@ -374,16 +433,16 @@ def legal_labels(label): def _is_attacked(state: GameState, pos: Array): def attacked_far(to): - ok = (to >= 0) & (state.board[to] < 0) # should be opponent's - piece = jnp.abs(state.board[to]) + ok = (to >= 0) & (get_bb(state.bb, to) < 0) # should be opponent's + piece = jnp.abs(get_bb(state.bb, to)) ok &= (piece == QUEEN) | (piece == ROOK) | (piece == BISHOP) between_ixs = BETWEEN[pos, to] - ok &= CAN_MOVE[piece, pos, to] & ((between_ixs < 0) | (state.board[between_ixs] == EMPTY)).all() + ok &= CAN_MOVE[piece, pos, to] & ((between_ixs < 0) | (jax.vmap(get_bb, in_axes=(None, 0))(state.bb, between_ixs) == EMPTY)).all() return ok def attacked_near(to): - ok = (to >= 0) & (state.board[to] < 0) # should be opponent's - piece = jnp.abs(state.board[to]) + ok = (to >= 0) & (get_bb(state.bb, to) < 0) # should be opponent's + piece = jnp.abs(get_bb(state.bb, to)) ok &= CAN_MOVE[piece, pos, to] ok &= ~((piece == PAWN) & (to // 8 == pos // 8)) # should move diagonally to capture return ok @@ -394,13 +453,13 @@ def attacked_near(to): def _is_checked(state: GameState): - king_pos = jnp.argmin(jnp.abs(state.board - KING)) + king_pos = jnp.argmin(jnp.abs(to_board(state.bb) - KING)) return _is_attacked(state, king_pos) def _zobrist_hash(state: GameState) -> Array: hash_ = lax.select(state.color == 0, ZOBRIST_SIDE, jnp.zeros_like(ZOBRIST_SIDE)) - to_reduce = ZOBRIST_BOARD[jnp.arange(64), state.board + 6] # 0, ..., 12 (w:pawn, ..., b:king) + to_reduce = 
ZOBRIST_BOARD[jnp.arange(64), to_board(state.bb) + 6] # 0, ..., 12 (w:pawn, ..., b:king) hash_ ^= lax.reduce(to_reduce, 0, lax.bitwise_xor, (0,)) to_reduce = jnp.where(state.castling_rights.reshape(-1, 1), ZOBRIST_CASTLING, 0) hash_ ^= lax.reduce(to_reduce, 0, lax.bitwise_xor, (0,)) diff --git a/pgx/experimental/chess.py b/pgx/experimental/chess.py index 36313d069..cc68fafe7 100644 --- a/pgx/experimental/chess.py +++ b/pgx/experimental/chess.py @@ -2,7 +2,7 @@ import jax.numpy as jnp import numpy as np -from pgx._src.games.chess import Game, GameState, _flip_pos, _legal_action_mask, _update_history +from pgx._src.games.chess import Game, GameState, _flip_pos, _legal_action_mask, _update_history, to_bitboard, to_board from pgx.chess import State TRUE = jnp.bool_(True) @@ -59,7 +59,7 @@ def from_fen(fen: str): if color == "b" and ep >= 0: ep = _flip_pos(ep) x = GameState( - board=jnp.rot90(mat, k=3).flatten(), + bb=to_bitboard(jnp.rot90(mat, k=3).flatten()), color=jnp.int32(0) if color == "w" else jnp.int32(1), castling_rights=castling_rights, en_passant=ep, @@ -103,7 +103,7 @@ def to_fen(state: State): >>> _to_fen(_from_fen("rnbqkbnr/pppppppp/8/8/8/P7/1PPPPPPP/RNBQKBNR b KQkq e3 0 1")) 'rnbqkbnr/pppppppp/8/8/8/P7/1PPPPPPP/RNBQKBNR b KQkq e3 0 1' """ - pb = np.rot90(state._x.board.reshape(8, 8), k=1) + pb = np.rot90(to_board(state._x.bb).reshape(8, 8), k=1) if state._x.color == 1: pb = -np.flip(pb, axis=0) fen = "" diff --git a/tests/test_chess.py b/tests/test_chess.py index ee091622c..902db2d76 100644 --- a/tests/test_chess.py +++ b/tests/test_chess.py @@ -2,7 +2,7 @@ import jax.numpy as jnp import pgx from pgx.chess import State, Chess -from pgx._src.games.chess import GameState, Action, KING, QUEEN, EMPTY, ROOK, PAWN, _legal_action_mask, CAN_MOVE, _zobrist_hash, INIT_ZOBRIST_HASH +from pgx._src.games.chess import GameState, Action, KING, QUEEN, EMPTY, ROOK, PAWN, _legal_action_mask, CAN_MOVE, _zobrist_hash, INIT_ZOBRIST_HASH, to_board from pgx.experimental.utils 
import act_randomly from pgx.experimental.chess import from_fen, to_fen @@ -135,23 +135,23 @@ def test_step(): # normal step state = from_fen("1k6/8/8/8/8/8/1Q6/7K w - - 0 1") state.save_svg("tests/assets/chess/step_001.svg") - assert state._x.board[p("b1")] == EMPTY + assert to_board(state._x.bb)[p("b1")] == EMPTY state = step(state, jnp.int32(672)) state.save_svg("tests/assets/chess/step_002.svg") - assert state._x.board[p("b1", True)] == -QUEEN + assert to_board(state._x.bb)[p("b1", True)] == -QUEEN # promotion state = from_fen("r1r4k/1P6/8/8/8/8/P7/7K w - - 0 1") state.save_svg("tests/assets/chess/step_002.svg") - assert state._x.board[p("b8")] == EMPTY + assert to_board(state._x.bb)[p("b8")] == EMPTY # underpromotion next_state = step(state, jnp.int32(1022)) next_state.save_svg("tests/assets/chess/step_003.svg") - assert next_state._x.board[p("b8", True)] == -ROOK + assert to_board(next_state._x.bb)[p("b8", True)] == -ROOK # promotion to queen next_state = step(state, jnp.int32(p("b7") * 73 + 16)) next_state.save_svg("tests/assets/chess/step_004.svg") - assert next_state._x.board[p("b8", True)] == -QUEEN + assert to_board(next_state._x.bb)[p("b8", True)] == -QUEEN # castling state = from_fen("1k6/8/8/8/8/8/8/R3K2R w KQ - 0 1") @@ -159,24 +159,24 @@ def test_step(): # left next_state = step(state, jnp.int32(p("e1") * 73 + 28)) next_state.save_svg("tests/assets/chess/step_006.svg") - assert next_state._x.board[p("c1", True)] == -KING - assert next_state._x.board[p("d1", True)] == -ROOK # castling - assert next_state._x.board[p("a1", True)] == EMPTY # castling + assert to_board(next_state._x.bb)[p("c1", True)] == -KING + assert to_board(next_state._x.bb)[p("d1", True)] == -ROOK # castling + assert to_board(next_state._x.bb)[p("a1", True)] == EMPTY # castling # right next_state = step(state, jnp.int32(p("e1") * 73 + 31)) next_state.save_svg("tests/assets/chess/step_007.svg") - assert next_state._x.board[p("g1", True)] == -KING - assert next_state._x.board[p("f1", 
True)] == -ROOK # castling - assert next_state._x.board[p("h1", True)] == EMPTY # castling + assert to_board(next_state._x.bb)[p("g1", True)] == -KING + assert to_board(next_state._x.bb)[p("f1", True)] == -ROOK # castling + assert to_board(next_state._x.bb)[p("h1", True)] == EMPTY # castling # en passant state = from_fen("1k6/8/8/8/3pP3/8/8/R3K2R b KQ e3 0 1") state.save_svg("tests/assets/chess/step_008.svg") - assert state._x.board[p("e4", True)] == -PAWN + assert to_board(state._x.bb)[p("e4", True)] == -PAWN next_state = step(state, jnp.int32(p("d4", True) * 73 + 44)) next_state.save_svg("tests/assets/chess/step_009.svg") - assert next_state._x.board[p("e3")] == -PAWN - assert next_state._x.board[p("e4")] == EMPTY + assert to_board(next_state._x.bb)[p("e3")] == -PAWN + assert to_board(next_state._x.bb)[p("e4")] == EMPTY state = from_fen("1k6/8/8/8/3p4/8/4P3/R3K2R w KQ - 0 1") state.save_svg("tests/assets/chess/step_010.svg") next_state = step(state, jnp.int32(p("e2") * 73 + 17)) # UP 2