diff --git a/pgx/_mahjong/_mahjong2.py b/pgx/_mahjong/_mahjong2.py index c04d8647a..51995cdcf 100644 --- a/pgx/_mahjong/_mahjong2.py +++ b/pgx/_mahjong/_mahjong2.py @@ -255,7 +255,6 @@ def _init(rng: jax.random.KeyArray) -> State: def _step(state: State, action) -> State: - # TODO # - Actionの処理 # - ron, tsumo @@ -343,6 +342,7 @@ def _make_legal_action_mask(state: State, hand, c_p, new_tile): state.last_draw, state.riichi[c_p], FALSE, + _dora_array(state, state.riichi[c_p]), )[0].any() ) return legal_action_mask @@ -360,6 +360,7 @@ def _make_legal_action_mask_w_riichi(state, hand, c_p, new_tile): state.last_draw, state.riichi[c_p], FALSE, + _dora_array(state, state.riichi[c_p]), )[0].any() ) return legal_action_mask @@ -397,8 +398,9 @@ def _discard(state: State, tile: jnp.ndarray): ) def search(i, tpl): + # iは相対位置 meld_type, pon_player, kan_player, ron_player = tpl - player = (c_p + 1 + i) % 4 + player = (c_p + 1 + i) % 4 # 絶対位置 pon_player, meld_type = jax.lax.cond( Hand.can_pon(state.hand[player], tile), lambda: (i, jnp.max(jnp.array([2, meld_type]))), @@ -412,12 +414,13 @@ def search(i, tpl): ron_player, meld_type = jax.lax.cond( Hand.can_ron(state.hand[player], tile) & Yaku.judge( - state.hand[c_p], - state.melds[c_p], - state.n_meld[c_p], + state.hand[player], + state.melds[player], + state.n_meld[player], state.last_draw, - state.riichi[c_p], + state.riichi[player], FALSE, + _dora_array(state, state.riichi[player]), )[0].any(), lambda: (i, jnp.max(jnp.array([4, meld_type]))), lambda: (ron_player, meld_type), @@ -794,6 +797,7 @@ def _tsumo(state: State): state.target, state.riichi[c_p], is_ron=FALSE, + dora=_dora_array(state, state.riichi[c_p]), ) s1 = score + (-score) % 100 s2 = (score * 2) + (-(score * 2)) % 100 @@ -826,6 +830,7 @@ def _ron(state: State): state.target, state.riichi[c_p], is_ron=TRUE, + dora=_dora_array(state, state.riichi[c_p]), ) score = jax.lax.cond( (state.oya + state._round) % 4 == c_p, @@ -862,6 +867,39 @@ def _observe(state: State, player_id: jnp.ndarray) -> jnp.ndarray: return jnp.int8(0) +def _dora_array(state: State, riichi): + def next(tile): + return jax.lax.cond( + tile < 27, + lambda: tile // 9 + (tile + 1) % 9, + lambda: jax.lax.cond( + tile < 31, + lambda: 27 + (tile + 1) % 4, + lambda: 31 + (tile + 1) % 3, + ), + ) + + dora = jnp.zeros(34, dtype=jnp.bool_) + return jax.lax.cond( + riichi, + lambda: jax.lax.fori_loop( + 0, + state.n_kan + 1, + lambda i, arr: arr.at[next(state.deck[5 + 2 * i])] + .set(TRUE) + .at[next(state.doras[4 + 2 * i])] + .set(TRUE), + dora, + ), + lambda: jax.lax.fori_loop( + 0, + state.n_kan + 1, + lambda i, arr: arr.at[next(state.doras[5 + 2 * i])].set(TRUE), + dora, + ), + ) + + # For debug def _show_legal_action(legal_action): S = ["F", "T"] diff --git a/pgx/_mahjong/_yaku.py b/pgx/_mahjong/_yaku.py index 3aa8ade40..13569037b 100644 --- a/pgx/_mahjong/_yaku.py +++ b/pgx/_mahjong/_yaku.py @@ -72,9 +72,12 @@ def score( last: jnp.ndarray, riichi: jnp.ndarray, is_ron: jnp.ndarray, + dora: jnp.ndarray, ) -> int: """handはlast_tileを加えたもの""" - yaku, fan, fu = Yaku.judge(hand, melds, n_meld, last, riichi, is_ron) + yaku, fan, fu = Yaku.judge( + hand, melds, n_meld, last, riichi, is_ron, dora + ) score = fu << (fan + 2) return jax.lax.cond( fu == 0, @@ -250,10 +253,11 @@ def update( def judge( hand: jnp.ndarray, melds: jnp.ndarray, - n_meld, + n_meld: jnp.ndarray, last, riichi, is_ron, + dora, ): is_menzen = jax.lax.fori_loop( jnp.int8(0), @@ -543,7 +547,11 @@ def _update_yaku(suit, tpl): return jax.lax.cond( jnp.any(yakuman), lambda: (yakuman, 0, 0), - lambda: (yaku_best, jnp.dot(fan, yaku_best), fu_best), + lambda: ( + yaku_best, + jnp.dot(fan, yaku_best) + jnp.dot(flatten, dora), + fu_best, + ), ) @staticmethod diff --git a/tests/test_mahjong.py b/tests/test_mahjong.py index e2b39cf49..025220983 100644 --- a/tests/test_mahjong.py +++ b/tests/test_mahjong.py @@ -8,6 +8,9 @@ import jax from pgx.experimental.utils import act_randomly +TRUE = jnp.bool_(True) +FALSE = jnp.bool_(False) + env = Mahjong() init = jit(env.init) step = jit(env.step) @@ -103,7 +106,10 @@ def test_hand(): def test_score(): - # 平和ツモ + # 平和ツモドラ1 + # 参考: + # tobakushi.net/mahjang/cgi-bin/keisan.cgi?hai=02,03,04,05,06,11,12,13,14,15,16,21,21&naki=,,,&agari=01&dora=06,,,,,,,,,&tsumoron=0&honba=0&jifu=32&bafu=31&reach=0 + # fmt:off hand = jnp.int32([ 1, 1, 1, 1, 1, 1, 0, 0, 0, @@ -120,8 +126,9 @@ def test_score(): last=jnp.int8(0), riichi=jnp.bool_(False), is_ron=jnp.bool_(False), + dora=jnp.zeros(34, dtype=jnp.bool_).at[5].set(TRUE), ) - == 320 + == 640 ) # 国士無双 # fmt:off @@ -141,6 +148,7 @@ def test_score(): last=jnp.int8(33), riichi=jnp.bool_(False), is_ron=jnp.bool_(False), + dora=jnp.zeros(34, dtype=jnp.bool_).at[5].set(TRUE), ) == 8000 ) @@ -163,6 +171,7 @@ def test_score(): last=jnp.int8(27), riichi=jnp.bool_(False), is_ron=jnp.bool_(False), + dora=jnp.zeros(34, dtype=jnp.bool_).at[5].set(TRUE), ) == 800 )