Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Shogi] Make core singlefile #1273

Merged
merged 4 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
280 changes: 267 additions & 13 deletions pgx/_src/games/shogi.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,281 @@
from typing import NamedTuple
from functools import partial

import numpy as np
import jax
from jax import Array
import jax.numpy as jnp

from pgx._src.shogi_utils import (
AROUND_IX,
BETWEEN_IX,
CAN_MOVE,
CAN_MOVE_ANY,
INIT_PIECE_BOARD,
LEGAL_FROM_IDX,
NEIGHBOUR_IX,
)
from pgx._src.struct import dataclass
from pgx._src.types import Array

MAX_TERMINATION_STEPS = 512 # From AZ paper

TRUE = jnp.bool_(True)
FALSE = jnp.bool_(False)


EMPTY = -1 # 空白
PAWN = 0 # 歩
LANCE = 1 # 香
KNIGHT = 2 # 桂
SILVER = 3 # 銀
BISHOP = 4 # 角
ROOK = 5 # 飛
GOLD = 6 # 金
KING = 7 # 玉
PRO_PAWN = 8 # と
PRO_LANCE = 9 # 成香
PRO_KNIGHT = 10 # 成桂
PRO_SILVER = 11 # 成銀
HORSE = 12 # 馬
DRAGON = 13 # 龍


# fmt: off
INIT_PIECE_BOARD = jnp.int32([[15, -1, 14, -1, -1, -1, 0, -1, 1], # noqa: E241
[16, 18, 14, -1, -1, -1, 0, 5, 2], # noqa: E241
[17, -1, 14, -1, -1, -1, 0, -1, 3], # noqa: E241
[20, -1, 14, -1, -1, -1, 0, -1, 6], # noqa: E241
[21, -1, 14, -1, -1, -1, 0, -1, 7], # noqa: E241
[20, -1, 14, -1, -1, -1, 0, -1, 6], # noqa: E241
[17, -1, 14, -1, -1, -1, 0, -1, 3], # noqa: E241
[16, 19, 14, -1, -1, -1, 0, 4, 2], # noqa: E241
[15, -1, 14, -1, -1, -1, 0, -1, 1]]).flatten() # noqa: E241
# fmt: on


# Can <piece,14> reach from <from,81> to <to,81> ignoring pieces on board?
def can_move_to(piece, from_, to):
"""Can <piece> move from <from_> to <to>?"""
if from_ == to:
return False
x0, y0 = from_ // 9, from_ % 9
x1, y1 = to // 9, to % 9
dx = x1 - x0
dy = y1 - y0
if piece == PAWN:
if dx == 0 and dy == -1:
return True
else:
return False
elif piece == LANCE:
if dx == 0 and dy < 0:
return True
else:
return False
elif piece == KNIGHT:
if dx in (-1, 1) and dy == -2:
return True
else:
return False
elif piece == SILVER:
if dx in (-1, 0, 1) and dy == -1:
return True
elif dx in (-1, 1) and dy == 1:
return True
else:
return False
elif piece == BISHOP:
if dx == dy or dx == -dy:
return True
else:
return False
elif piece == ROOK:
if dx == 0 or dy == 0:
return True
else:
return False
if piece in (GOLD, PRO_PAWN, PRO_LANCE, PRO_KNIGHT, PRO_SILVER):
if dx in (-1, 0, 1) and dy in (0, -1):
return True
elif dx == 0 and dy == 1:
return True
else:
return False
elif piece == KING:
if abs(dx) <= 1 and abs(dy) <= 1:
return True
else:
return False
elif piece == HORSE:
if abs(dx) <= 1 and abs(dy) <= 1:
return True
elif dx == dy or dx == -dy:
return True
else:
return False
elif piece == DRAGON:
if abs(dx) <= 1 and abs(dy) <= 1:
return True
if dx == 0 or dy == 0:
return True
else:
return False
else:
assert False

Check warning on line 129 in pgx/_src/games/shogi.py

View check run for this annotation

Codecov / codecov/patch

pgx/_src/games/shogi.py#L129

Added line #L129 was not covered by tests


def is_on_the_way(piece, from_, to, point):
if to == point:
return False
if piece not in (LANCE, BISHOP, ROOK, HORSE, DRAGON):
return False

Check warning on line 136 in pgx/_src/games/shogi.py

View check run for this annotation

Codecov / codecov/patch

pgx/_src/games/shogi.py#L136

Added line #L136 was not covered by tests
if not can_move_to(piece, from_, to):
return False
if not can_move_to(piece, from_, point):
return False

x0, y0 = from_ // 9, from_ % 9
x1, y1 = to // 9, to % 9
x2, y2 = point // 9, point % 9
dx1, dy1 = x1 - x0, y1 - y0
dx2, dy2 = x2 - x0, y2 - y0

def sign(d):
if d == 0:
return 0
return d > 0

if (sign(dx1) != sign(dx2)) or (sign(dy1) != sign(dy2)):
return False

return abs(dx2) <= abs(dx1) and abs(dy2) <= abs(dy1)


CAN_MOVE = np.zeros((14, 81, 81), dtype=jnp.bool_)
for piece in range(14):
for from_ in range(81):
for to in range(81):
CAN_MOVE[piece, from_, to] = can_move_to(piece, from_, to)

assert CAN_MOVE.sum() == 8228
CAN_MOVE = jnp.array(CAN_MOVE)


# When <lance/bishop/rook/horse/dragon,5> moves from <from,81> to <to,81>,
# is <point,81> on the way between two points?
BETWEEN = np.zeros((5, 81, 81, 81), dtype=np.bool_)
for i, piece in enumerate((LANCE, BISHOP, ROOK, HORSE, DRAGON)):
for from_ in range(81):
for to in range(81):
for p in range(81):
BETWEEN[i, from_, to, p] = is_on_the_way(piece, from_, to, p)

BETWEEN = jnp.array(BETWEEN)
assert BETWEEN.sum() == 10564


# Give <dir,10> and <to,81>, return the legal <from> idx
# E.g. LEGAL_FROM_IDX[Up, to=19] = [20, 21, ..., -1] (filled by -1)
# Used for computing dlshogi action
#
# dir, to, from
# (10, 81, 81)
#
# 0 Up
# 1 Up left
# 2 Up right
# 3 Left
# 4 Right
# 5 Down
# 6 Down left
# 7 Down right
# 8 Up2 left
# 9 Up2 right

LEGAL_FROM_IDX = -np.ones((10, 81, 8), dtype=jnp.int32) # type: ignore

for dir_ in range(10):
for to in range(81):
x, y = to // 9, to % 9
if dir_ == 0: # Up
dx, dy = 0, +1
elif dir_ == 1: # Up left
dx, dy = -1, +1
elif dir_ == 2: # Up right
dx, dy = +1, +1
elif dir_ == 3: # Left
dx, dy = -1, 0
elif dir_ == 4: # Right
dx, dy = +1, 0
elif dir_ == 5: # Down
dx, dy = 0, -1
elif dir_ == 6: # Down left
dx, dy = -1, -1
elif dir_ == 7: # Down right
dx, dy = +1, -1
elif dir_ == 8: # Up2 left
dx, dy = -1, +2
elif dir_ == 9: # Up2 right
dx, dy = +1, +2
for i in range(8):
x += dx
y += dy
if x < 0 or 8 < x or y < 0 or 8 < y:
break
LEGAL_FROM_IDX[dir_, to, i] = x * 9 + y
if dir_ == 8 or dir_ == 9:
break

LEGAL_FROM_IDX = jnp.array(LEGAL_FROM_IDX) # type: ignore


@jax.jit
@jax.vmap
def can_move_any_ix(from_):
return jnp.nonzero(
(CAN_MOVE[:, from_, :] | CAN_MOVE[:, :, from_]).any(axis=0),
size=36,
fill_value=-1,
)[0]


@jax.jit
@jax.vmap
def neighbour_ix(from_):
return jnp.nonzero(
(CAN_MOVE[7, from_, :] | CAN_MOVE[2, :, from_]),
size=10,
fill_value=-1,
)[0]


NEIGHBOUR_IX = neighbour_ix(jnp.arange(81))


def between_ix(p, from_, to):
return jnp.nonzero(BETWEEN[p, from_, to], size=8, fill_value=-1)[0]


BETWEEN_IX = jax.jit(
jax.vmap(
jax.vmap(jax.vmap(between_ix, (None, None, 0)), (None, 0, None)),
(0, None, None),
)
)(jnp.arange(5), jnp.arange(81), jnp.arange(81))


CAN_MOVE_ANY = can_move_any_ix(jnp.arange(81)) # (81, 36)


def _around(c):
x, y = c // 9, c % 9
dx = jnp.int32([-1, -1, 0, +1, +1, +1, 0, -1])
dy = jnp.int32([0, -1, -1, -1, 0, +1, +1, +1])

def f(i):
new_x, new_y = x + dx[i], y + dy[i]
return jax.lax.select(
(new_x < 0) | (new_x >= 9) | (new_y < 0) | (new_y >= 9),
-1,
new_x * 9 + new_y,
)

return jax.vmap(f)(jnp.arange(8))


AROUND_IX = jax.vmap(_around)(jnp.arange(81))


EMPTY = jnp.int32(-1) # 空白
PAWN = jnp.int32(0) # 歩
LANCE = jnp.int32(1) # 香
Expand Down Expand Up @@ -94,8 +349,7 @@
return _legal_action_mask(state)


@dataclass
class Action:
class Action(NamedTuple):
is_drop: Array
piece: Array
to: Array
Expand Down
Loading
Loading