Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Chess] reduce memory usage #1258

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
75 changes: 75 additions & 0 deletions bb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import jax
from jax import lax
import jax.numpy as jnp
from pgx.experimental.chess import from_fen

# ピースの定義
EMPTY, PAWN, KNIGHT, BISHOP, ROOK, QUEEN, KING = tuple(range(7)) # ピース
PIECE_TYPES = [EMPTY, PAWN, KNIGHT, BISHOP, ROOK, QUEEN, KING]

# ボードのサイズ
BOARD_SIZE = 64

# 4bitの情報に対応するためのシフト量
SHIFT_PIECE_TYPE = 0 # piece typeを表すのは0ビット右にずれる
SHIFT_COLOR = 3 # colorは左端のビット (3ビット分ずれる)

@jax.jit
def to_bitboard(board):
bitboard = jnp.zeros(8, dtype=jnp.uint32)

for idx in range(BOARD_SIZE):
piece = board[idx]
rank = idx % 8
file = idx // 8
color = piece < 0 # >=0: us, <0: opp
piece_type = jnp.abs(piece)
bit_value = (color << SHIFT_COLOR) | piece_type
bit_value = bitboard[rank] | (bit_value << (4 * file))
bitboard = bitboard.at[rank].set(bit_value)

return bitboard


@jax.jit
def to_board(bitboard):
return jax.vmap(get_bb, in_axes=(None, 0))(bitboard, jnp.arange(64))


def get_bb(bb, pos):
rank, file = pos % 8, pos // 8
rank_bb = bb[rank]
bits = (rank_bb >> (4 * file)) & 0b1111
color = (bits >> SHIFT_COLOR) & 1
piece_type = bits & 0b111
return jnp.int32([1, -1])[color] * piece_type

# --- テストコード ---
# 8 7 15 23 31 39 47 55 63
# 7 6 14 22 30 38 46 54 62
# 6 5 13 21 29 37 45 53 61
# 5 4 12 20 28 36 44 52 60
# 4 3 11 19 27 35 43 51 59
# 3 2 10 18 26 34 42 50 58
# 2 1 9 17 25 33 41 49 57
# 1 0 8 16 24 32 40 48 56
# a b c d e f g h
INIT_BOARD = jnp.int32([4, 1, 0, 0, 0, 0, -1, -4, 2, 1, 0, 0, 0, 0, -1, -2, 3, 1, 0, 0, 0, 0, -1, -3, 5, 1, 0, 0, 0, 0, -1, -5, 6, 1, 0, 0, 0, 0, -1, -6, 3, 1, 0, 0, 0, 0, -1, -3, 2, 1, 0, 0, 0, 0, -1, -2, 4, 1, 0, 0, 0, 0, -1, -4]) # fmt: skip


# テスト用のboard
board = INIT_BOARD

# board -> bitboard -> board のテスト
bitboard = to_bitboard(board)
reconstructed_board = to_board(bitboard)

print("Original board:")
print(board.reshape(8, 8))
print("\nBitboard:")
print(bitboard)
print("\nReconstructed board:")
print(reconstructed_board.reshape(8, 8))

# 再構築したboardが元のboardと同じかどうかを確認
assert jnp.array_equal(board, reconstructed_board)
3 changes: 2 additions & 1 deletion pgx/_src/dwg/chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from pgx.chess import State as ChessState
from pgx.chess import _flip
from pgx._src.games.chess import to_board


def _make_chess_dwg(dwg, state: ChessState, config):
Expand Down Expand Up @@ -136,7 +137,7 @@ def _set_piece(_x, _y, _type, _dwg, _dwg_g, grid_size):
# pieces
pieces_g = dwg.g()
for i in range(64):
pi = int(state._x.board[i].item())
pi = int(to_board(state._x.bb)[i].item())
if pi == 0:
continue
if pi < 0:
Expand Down
127 changes: 93 additions & 34 deletions pgx/_src/games/chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,65 @@
EMPTY, PAWN, KNIGHT, BISHOP, ROOK, QUEEN, KING = tuple(range(7)) # opponent: -1 * piece
MAX_TERMINATION_STEPS = 512 # from AlphaZero paper


PIECE_TYPES = [EMPTY, PAWN, KNIGHT, BISHOP, ROOK, QUEEN, KING]

# ボードのサイズ
BOARD_SIZE = 64

# 4bitの情報に対応するためのシフト量
SHIFT_PIECE_TYPE = 0 # piece typeを表すのは0ビット右にずれる
SHIFT_COLOR = 3 # colorは左端のビット (3ビット分ずれる)


@jax.jit
def to_bitboard(board):
bitboard = jnp.zeros(8, dtype=jnp.uint32)

for idx in range(BOARD_SIZE):
piece = board[idx]
rank = idx % 8
file = idx // 8
color = piece < 0 # >=0: us, <0: opp
piece_type = jnp.abs(piece)
bit_value = (color << SHIFT_COLOR) | piece_type
bit_value = bitboard[rank] | (bit_value << (4 * file))
bitboard = bitboard.at[rank].set(bit_value)

return bitboard


@jax.jit
def to_board(bitboard):
return jax.vmap(get_bb, in_axes=(None, 0))(bitboard, jnp.arange(64))


def get_bb(bb, pos):
bb = bb.astype(jnp.uint32)
rank, file = pos % 8, pos // 8
rank_bb = bb[rank]
bits = (rank_bb >> (4 * file)) & 0b1111
color = (bits >> SHIFT_COLOR) & 1
piece_type = bits & 0b111
return jnp.int32([1, -1])[color] * piece_type


def set_bb(bb, ix, piece):
bb = bb.astype(jnp.uint32)
rank, file = ix % 8, jnp.uint32(ix // 8)
rank_bb = bb[rank]
color = piece < 0 # >=0: us, <0: opp
piece_type = jnp.abs(piece)
bits = (color << SHIFT_COLOR) | piece_type
rank_bb = rank_bb & ~(0b1111 << (4 * file)) # clear the bits
rank_bb = rank_bb | (bits << (4 * file))
return bb.at[rank].set(rank_bb)


def flip_bb(bb):
return bb ^ jnp.uint32(0b10001000100010001000100010001000)


# prepare precomputed values here (e.g., available moves, map to label, etc.)

# index: a1: 0, a2: 1, ..., h8: 63
Expand Down Expand Up @@ -136,10 +195,12 @@
ZOBRIST_EN_PASSANT = jax.random.randint(keys[3], shape=(65, 2), minval=0, maxval=2**31 - 1, dtype=jnp.uint32)
INIT_ZOBRIST_HASH = jnp.uint32([1455170221, 1478960862])

INIT_BB = to_bitboard(INIT_BOARD)


class GameState(NamedTuple):
color: Array = jnp.int32(0) # w: 0, b: 1
board: Array = INIT_BOARD # (64,)
bb: Array = INIT_BB
castling_rights: Array = jnp.ones([2, 2], dtype=jnp.bool_) # my queen, my king, opp queen, opp king
en_passant: Array = jnp.int32(-1)
halfmove_count: Array = jnp.int32(0) # number of moves since the last piece capture or pawn move
Expand Down Expand Up @@ -231,20 +292,20 @@ def rewards(self, state: GameState) -> Array:

def _update_history(state: GameState):
board_history = jnp.roll(state.board_history, 64)
board_history = board_history.at[0].set(state.board)
board_history = board_history.at[0].set(to_board(state.bb))
hash_hist = jnp.roll(state.hash_history, 2)
hash_hist = hash_hist.at[0].set(_zobrist_hash(state))
return state._replace(board_history=board_history, hash_history=hash_hist)


def has_insufficient_pieces(state: GameState):
# uses the same condition as OpenSpiel
num_pieces = (state.board != EMPTY).sum()
num_pawn_rook_queen = ((jnp.abs(state.board) >= ROOK) | (jnp.abs(state.board) == PAWN)).sum() - 2 # two kings
num_bishop = (jnp.abs(state.board) == BISHOP).sum()
num_pieces = (to_board(state.bb) != EMPTY).sum()
num_pawn_rook_queen = ((jnp.abs(to_board(state.bb)) >= ROOK) | (jnp.abs(to_board(state.bb)) == PAWN)).sum() - 2 # two kings
num_bishop = (jnp.abs(to_board(state.bb)) == BISHOP).sum()
coords = jnp.arange(64).reshape((8, 8))
black_coords = jnp.hstack((coords[::2, ::2].ravel(), coords[1::2, 1::2].ravel()))
num_bishop_on_black = (jnp.abs(state.board[black_coords]) == BISHOP).sum()
num_bishop_on_black = (jnp.abs(to_board(state.bb)[black_coords]) == BISHOP).sum()
is_insufficient = False
# king vs king
is_insufficient |= num_pieces <= 2
Expand All @@ -259,28 +320,26 @@ def has_insufficient_pieces(state: GameState):


def _apply_move(state: GameState, a: Action) -> GameState:
piece = state.board[a.from_]
piece = get_bb(state.bb, a.from_)
# en passant
is_en_passant = (state.en_passant >= 0) & (piece == PAWN) & (state.en_passant == a.to)
removed_pawn_pos = a.to - 1
state = state._replace(
board=state.board.at[removed_pawn_pos].set(lax.select(is_en_passant, EMPTY, state.board[removed_pawn_pos]))
)
new_piece = lax.select(is_en_passant, EMPTY, get_bb(state.bb, removed_pawn_pos))
state = state._replace(bb=set_bb(state.bb, removed_pawn_pos, new_piece))
is_en_passant = (piece == PAWN) & (jnp.abs(a.to - a.from_) == 2)
state = state._replace(en_passant=lax.select(is_en_passant, (a.to + a.from_) // 2, -1))
# update counters
captured = (state.board[a.to] < 0) | is_en_passant
captured = (get_bb(state.bb, a.to) < 0) | is_en_passant
state = state._replace(
halfmove_count=lax.select(captured | (piece == PAWN), 0, state.halfmove_count + 1),
fullmove_count=state.fullmove_count + jnp.int32(state.color == 1),
)
# castling
board = state.board
is_queen_side_castling = (piece == KING) & (a.from_ == 32) & (a.to == 16)
board = lax.select(is_queen_side_castling, board.at[0].set(EMPTY).at[24].set(ROOK), board)
bb = lax.select(is_queen_side_castling, set_bb(set_bb(state.bb, 0, EMPTY), 24, ROOK), state.bb)
is_king_side_castling = (piece == KING) & (a.from_ == 32) & (a.to == 48)
board = lax.select(is_king_side_castling, board.at[56].set(EMPTY).at[40].set(ROOK), board)
state = state._replace(board=board)
bb = lax.select(is_king_side_castling, set_bb(set_bb(bb, 56, EMPTY), 40, ROOK), bb)
state = state._replace(bb=bb)
# update castling rights
cond = jnp.bool_([[(a.from_ != 32) & (a.from_ != 0), (a.from_ != 32) & (a.from_ != 56)], [a.to != 7, a.to != 63]])
state = state._replace(castling_rights=state.castling_rights & cond)
Expand All @@ -289,7 +348,7 @@ def _apply_move(state: GameState, a: Action) -> GameState:
# underpromotion
piece = lax.select(a.underpromotion < 0, piece, jnp.int32([ROOK, BISHOP, KNIGHT])[a.underpromotion])
# actually move
state = state._replace(board=state.board.at[a.from_].set(EMPTY).at[a.to].set(piece)) # type: ignore
state = state._replace(bb=set_bb(set_bb(state.bb, a.from_, EMPTY), a.to, piece))
return state


Expand All @@ -299,7 +358,7 @@ def _flip_pos(x: Array): # e.g., 37 <-> 34, -1 <-> -1

def _flip(state: GameState) -> GameState:
return state._replace(
board=-jnp.flip(state.board.reshape(8, 8), axis=1).flatten(),
bb=flip_bb(state.bb)[::-1],
color=(state.color + 1) % 2,
en_passant=_flip_pos(state.en_passant),
castling_rights=state.castling_rights[::-1],
Expand All @@ -309,14 +368,14 @@ def _flip(state: GameState) -> GameState:

def _legal_action_mask(state: GameState) -> Array:
def legal_normal_moves(from_):
piece = state.board[from_]
piece = get_bb(state.bb, from_)

def legal_label(to):
ok = (from_ >= 0) & (piece > 0) & (to >= 0) & (state.board[to] <= 0)
ok = (from_ >= 0) & (piece > 0) & (to >= 0) & (get_bb(state.bb, to) <= 0)
between_ixs = BETWEEN[from_, to]
ok &= CAN_MOVE[piece, from_, to] & ((between_ixs < 0) | (state.board[between_ixs] == EMPTY)).all()
ok &= CAN_MOVE[piece, from_, to] & ((between_ixs < 0) | (jax.vmap(get_bb, in_axes=(None, 0))(state.bb, between_ixs) == EMPTY)).all()
c0, c1 = from_ // 8, to // 8
pawn_should = ((c1 == c0) & (state.board[to] == EMPTY)) | ((c1 != c0) & (state.board[to] < 0))
pawn_should = ((c1 == c0) & (get_bb(state.bb, to) == EMPTY)) | ((c1 != c0) & (get_bb(state.bb, to) < 0))
ok &= (piece != PAWN) | pawn_should
return lax.select(ok, Action(from_=from_, to=to)._to_label(), -1)

Expand All @@ -326,7 +385,7 @@ def legal_en_passants():
to = state.en_passant

def legal_labels(from_):
ok = (from_ >= 0) & (from_ < 64) & (to >= 0) & (state.board[from_] == PAWN) & (state.board[to - 1] == -PAWN)
ok = (from_ >= 0) & (from_ < 64) & (to >= 0) & (get_bb(state.bb, from_) == PAWN) & (get_bb(state.bb, to - 1) == -PAWN)
a = Action(from_=from_, to=to)
return lax.select(ok, a._to_label(), -1)

Expand All @@ -339,15 +398,15 @@ def is_not_checked(label):
def legal_underpromotions(mask):
def legal_labels(label):
a = Action._from_label(label)
ok = (state.board[a.from_] == PAWN) & (a.to >= 0)
ok = (get_bb(state.bb, a.from_) == PAWN) & (a.to >= 0)
ok &= mask[Action(from_=a.from_, to=a.to)._to_label()]
return lax.select(ok, label, -1)

labels = jnp.int32([from_ * 73 + i for i in range(9) for from_ in [6, 14, 22, 30, 38, 46, 54, 62]])
return jax.vmap(legal_labels)(labels)

# normal move and en passant
possible_piece_positions = jnp.nonzero(state.board > 0, size=16, fill_value=-1)[0]
possible_piece_positions = jnp.nonzero(to_board(state.bb) > 0, size=16, fill_value=-1)[0]
a1 = jax.vmap(legal_normal_moves)(possible_piece_positions).flatten()
a2 = legal_en_passants()
actions = jnp.hstack((a1, a2)) # include -1
Expand All @@ -356,11 +415,11 @@ def legal_labels(label):
mask = mask.at[actions].set(True)

# castling
b = state.board
bb = state.bb
can_castle_queen_side = state.castling_rights[0, 0]
can_castle_queen_side &= (b[0] == ROOK) & (b[8] == EMPTY) & (b[16] == EMPTY) & (b[24] == EMPTY) & (b[32] == KING)
can_castle_queen_side &= (get_bb(bb, 0) == ROOK) & (get_bb(bb, 8) == EMPTY) & (get_bb(bb, 16) == EMPTY) & (get_bb(bb, 24) == EMPTY) & (get_bb(bb, 32) == KING)
can_castle_king_side = state.castling_rights[0, 1]
can_castle_king_side &= (b[32] == KING) & (b[40] == EMPTY) & (b[48] == EMPTY) & (b[56] == ROOK)
can_castle_king_side &= (get_bb(bb, 32) == KING) & (get_bb(bb, 40) == EMPTY) & (get_bb(bb, 48) == EMPTY) & (get_bb(bb, 56) == ROOK)
not_checked = ~jax.vmap(_is_attacked, in_axes=(None, 0))(state, jnp.int32([16, 24, 32, 40, 48]))
mask = mask.at[2364].set(mask[2364] | (can_castle_queen_side & not_checked[:3].all()))
mask = mask.at[2367].set(mask[2367] | (can_castle_king_side & not_checked[2:].all()))
Expand All @@ -374,16 +433,16 @@ def legal_labels(label):

def _is_attacked(state: GameState, pos: Array):
def attacked_far(to):
ok = (to >= 0) & (state.board[to] < 0) # should be opponent's
piece = jnp.abs(state.board[to])
ok = (to >= 0) & (get_bb(state.bb, to) < 0) # should be opponent's
piece = jnp.abs(get_bb(state.bb, to))
ok &= (piece == QUEEN) | (piece == ROOK) | (piece == BISHOP)
between_ixs = BETWEEN[pos, to]
ok &= CAN_MOVE[piece, pos, to] & ((between_ixs < 0) | (state.board[between_ixs] == EMPTY)).all()
ok &= CAN_MOVE[piece, pos, to] & ((between_ixs < 0) | (jax.vmap(get_bb, in_axes=(None, 0))(state.bb, between_ixs) == EMPTY)).all()
return ok

def attacked_near(to):
ok = (to >= 0) & (state.board[to] < 0) # should be opponent's
piece = jnp.abs(state.board[to])
ok = (to >= 0) & (get_bb(state.bb, to) < 0) # should be opponent's
piece = jnp.abs(get_bb(state.bb, to))
ok &= CAN_MOVE[piece, pos, to]
ok &= ~((piece == PAWN) & (to // 8 == pos // 8)) # should move diagonally to capture
return ok
Expand All @@ -394,13 +453,13 @@ def attacked_near(to):


def _is_checked(state: GameState):
king_pos = jnp.argmin(jnp.abs(state.board - KING))
king_pos = jnp.argmin(jnp.abs(to_board(state.bb) - KING))
return _is_attacked(state, king_pos)


def _zobrist_hash(state: GameState) -> Array:
hash_ = lax.select(state.color == 0, ZOBRIST_SIDE, jnp.zeros_like(ZOBRIST_SIDE))
to_reduce = ZOBRIST_BOARD[jnp.arange(64), state.board + 6] # 0, ..., 12 (w:pawn, ..., b:king)
to_reduce = ZOBRIST_BOARD[jnp.arange(64), to_board(state.bb) + 6] # 0, ..., 12 (w:pawn, ..., b:king)
hash_ ^= lax.reduce(to_reduce, 0, lax.bitwise_xor, (0,))
to_reduce = jnp.where(state.castling_rights.reshape(-1, 1), ZOBRIST_CASTLING, 0)
hash_ ^= lax.reduce(to_reduce, 0, lax.bitwise_xor, (0,))
Expand Down
6 changes: 3 additions & 3 deletions pgx/experimental/chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp
import numpy as np

from pgx._src.games.chess import Game, GameState, _flip_pos, _legal_action_mask, _update_history
from pgx._src.games.chess import Game, GameState, _flip_pos, _legal_action_mask, _update_history, to_bitboard, to_board
from pgx.chess import State

TRUE = jnp.bool_(True)
Expand Down Expand Up @@ -59,7 +59,7 @@ def from_fen(fen: str):
if color == "b" and ep >= 0:
ep = _flip_pos(ep)
x = GameState(
board=jnp.rot90(mat, k=3).flatten(),
bb=to_bitboard(jnp.rot90(mat, k=3).flatten()),
color=jnp.int32(0) if color == "w" else jnp.int32(1),
castling_rights=castling_rights,
en_passant=ep,
Expand Down Expand Up @@ -103,7 +103,7 @@ def to_fen(state: State):
>>> _to_fen(_from_fen("rnbqkbnr/pppppppp/8/8/8/P7/1PPPPPPP/RNBQKBNR b KQkq e3 0 1"))
'rnbqkbnr/pppppppp/8/8/8/P7/1PPPPPPP/RNBQKBNR b KQkq e3 0 1'
"""
pb = np.rot90(state._x.board.reshape(8, 8), k=1)
pb = np.rot90(to_board(state._x.bb).reshape(8, 8), k=1)
if state._x.color == 1:
pb = -np.flip(pb, axis=0)
fen = ""
Expand Down
Loading
Loading