Skip to content

Commit

Permalink
commit from colab
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk committed Oct 16, 2023
1 parent 6b5a6cf commit 81e416d
Showing 1 changed file with 32 additions and 24 deletions.
56 changes: 32 additions & 24 deletions colab/check_chess.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
Expand All @@ -14,6 +15,16 @@
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/sotetsuk/pgx/blob/sotetsuk%2Fcolab%2Fupdate-check-chess/colab/check_chess.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"source": [
Expand All @@ -24,7 +35,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "qSaJpiijsElM",
"outputId": "c28f8b7d-d278-4932-8946-25c1baf85c64"
"outputId": "84a350c9-d84e-4e56-9376-c6e8b098460c"
},
"execution_count": 1,
"outputs": [
Expand All @@ -34,24 +45,24 @@
"text": [
"Collecting open_spiel\n",
" Downloading open_spiel-1.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.4 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.4/5.4 MB\u001b[0m \u001b[31m10.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.4/5.4 MB\u001b[0m \u001b[31m9.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting pgx\n",
" Downloading pgx-1.1.0-py3-none-any.whl (410 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m410.7/410.7 kB\u001b[0m \u001b[31m14.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
" Downloading pgx-1.4.0-py3-none-any.whl (413 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m413.9/413.9 kB\u001b[0m \u001b[31m35.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: pip>=20.0.2 in /usr/local/lib/python3.10/dist-packages (from open_spiel) (23.1.2)\n",
"Requirement already satisfied: attrs>=19.3.0 in /usr/local/lib/python3.10/dist-packages (from open_spiel) (23.1.0)\n",
"Requirement already satisfied: absl-py>=0.10.0 in /usr/local/lib/python3.10/dist-packages (from open_spiel) (1.4.0)\n",
"Requirement already satisfied: numpy>=1.21.5 in /usr/local/lib/python3.10/dist-packages (from open_spiel) (1.22.4)\n",
"Requirement already satisfied: scipy>=1.10.1 in /usr/local/lib/python3.10/dist-packages (from open_spiel) (1.10.1)\n",
"Requirement already satisfied: jax>=0.3.25 in /usr/local/lib/python3.10/dist-packages (from pgx) (0.4.13)\n",
"Requirement already satisfied: numpy>=1.21.5 in /usr/local/lib/python3.10/dist-packages (from open_spiel) (1.23.5)\n",
"Requirement already satisfied: scipy>=1.10.1 in /usr/local/lib/python3.10/dist-packages (from open_spiel) (1.11.3)\n",
"Requirement already satisfied: jax>=0.3.25 in /usr/local/lib/python3.10/dist-packages (from pgx) (0.4.16)\n",
"Collecting svgwrite (from pgx)\n",
" Downloading svgwrite-1.4.3-py3-none-any.whl (67 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m67.1/67.1 kB\u001b[0m \u001b[31m5.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from pgx) (4.7.1)\n",
"Requirement already satisfied: ml-dtypes>=0.1.0 in /usr/local/lib/python3.10/dist-packages (from jax>=0.3.25->pgx) (0.2.0)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m67.1/67.1 kB\u001b[0m \u001b[31m6.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from pgx) (4.5.0)\n",
"Requirement already satisfied: ml-dtypes>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from jax>=0.3.25->pgx) (0.3.1)\n",
"Requirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (from jax>=0.3.25->pgx) (3.3.0)\n",
"Installing collected packages: svgwrite, open_spiel, pgx\n",
"Successfully installed open_spiel-1.3 pgx-1.1.0 svgwrite-1.4.3\n"
"Successfully installed open_spiel-1.3 pgx-1.4.0 svgwrite-1.4.3\n"
]
}
]
Expand All @@ -76,7 +87,7 @@
"height": 53
},
"id": "V1DUFO14sHXC",
"outputId": "18c750e6-adbc-47c3-8007-d1a1e6271b98"
"outputId": "94c1d971-d7eb-4dfd-ec1f-244c542c23ce"
},
"execution_count": 2,
"outputs": [
Expand All @@ -91,7 +102,7 @@
"output_type": "execute_result",
"data": {
"text/plain": [
"'1.1.0'"
"'1.4.0'"
],
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
Expand All @@ -115,11 +126,11 @@
"init = jax.jit(env.init)\n",
"step = jax.jit(env.step)\n",
"\n",
"\n",
"def check(seed):\n",
" np.random.seed(seed)\n",
" spiel_state = game.new_initial_state()\n",
" pgx_state = init(jax.random.PRNGKey(0)) # seed is not related\n",
" action_seq = []\n",
" for _ in range(512): # pgx chess terminates after 512 steps (following AZ paper)\n",
" fen_before = spiel_state.debug_string().splitlines()[0][5:]\n",
"\n",
Expand All @@ -132,44 +143,41 @@
" for a in expected_legal_actions:\n",
" ok = ok and pgx_state.legal_action_mask[a]\n",
"\n",
" if not ok:\n",
" print(\"legal action mask is different\")\n",
" legal_actions = jnp.nonzero(pgx_state.legal_action_mask)[0]\n",
" pgx_state.save_svg(\"failed.svg\")\n",
" assert False, f\"\\n{fen_before}\\n{pgx_state.legal_action_mask.sum()} != {len(expected_legal_actions)}\\nactual:{legal_actions}\\nexpected:{expected_legal_actions}\"\n",
" assert ok, f\"\\n{fen_before}\\n{pgx_state.legal_action_mask.sum()} != {len(expected_legal_actions)}\\nactual:{jnp.nonzero(pgx_state.legal_action_mask)[0]}\\nexpected:{expected_legal_actions}\\naction sequence: {action_seq}\"\n",
"\n",
" # step by OpenSpiel\n",
" action = np.random.choice(expected_legal_actions)\n",
" action_seq.append(action)\n",
" spiel_state.apply_action(action)\n",
" fen_after = spiel_state.debug_string().splitlines()[0][5:]\n",
"\n",
" # step by Pgx\n",
" pgx_state = step(pgx_state, jnp.int32(action))\n",
"\n",
" # check state transition\n",
" assert pgx_state._to_fen() == fen_after, f\"\\n{fen_before}\\nactual:{pgx_state._to_fen()}\\nexpected: {fen_after}\"\n"
" assert pgx_state._to_fen() == fen_after, f\"\\n{fen_before}\\nactual:{pgx_state._to_fen()}\\nexpected: {fen_after}\\naction sequence: {action_seq}\""
]
},
{
"cell_type": "code",
"source": [
"for i in tqdm(range(100)):\n",
"for i in tqdm(range(1000)):\n",
" check(i)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "tfD1p9bmsLhc",
"outputId": "e1bd986c-760c-4ca5-c65a-623f4479b9a7"
"outputId": "5ad9e0ef-4820-440f-df7b-18915410c689"
},
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"100%|██████████| 100/100 [53:31<00:00, 32.11s/it]\n"
"100%|██████████| 1000/1000 [3:05:04<00:00, 11.10s/it]\n"
]
}
]
Expand Down

0 comments on commit 81e416d

Please sign in to comment.