diff --git a/mkdocs.yml b/mkdocs.yml index d64884f..e347015 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -42,7 +42,9 @@ nav: - Lateral Control (state-based): notebooks/lateral_control_state_based_notebook.ipynb - Lateral Control (Riccati): notebooks/lateral_control_riccati_notebook.ipynb - Graph Search: notebooks/a_star_notebook.ipynb - - Decision Making: notebooks/mdp.ipynb + - Decision Making: + - Value Iteration: notebooks/mdp_value_iteration.ipynb + - Q-Learning: notebooks/mdp_q_learning.ipynb - API Documentation (partial): reference/ plugins: diff --git a/notebooks/mdp_q_learning.ipynb b/notebooks/mdp_q_learning.ipynb new file mode 100644 index 0000000..c73ec74 --- /dev/null +++ b/notebooks/mdp_q_learning.ipynb @@ -0,0 +1,238 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from behavior_generation_lecture_python.mdp.mdp import (\n", + " MDP,\n", + " GridMDP,\n", + " expected_utility_of_action,\n", + " derive_policy,\n", + " q_learning,\n", + " GRID_MDP_DICT,\n", + " HIGHWAY_MDP_DICT,\n", + " LC_RIGHT_ACTION,\n", + " STAY_IN_LANE_ACTION,\n", + ")\n", + "from behavior_generation_lecture_python.utils.grid_plotting import (\n", + " make_plot_grid_step_function,\n", + " make_plot_policy_step_function,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## TOY EXAMPLE" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "grid_mdp = GridMDP(**GRID_MDP_DICT)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "computed_utility_history = q_learning(\n", + " mdp=grid_mdp, alpha=0.1, epsilon=0.1, iterations=10000, return_history=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "plot_grid_step = make_plot_grid_step_function(\n", + " columns=4, rows=3, U_over_time=computed_utility_history\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "mkdocs_flag = False # set to true if you are running the notebook locally\n", + "if mkdocs_flag:\n", + " import ipywidgets\n", + " from IPython.display import display\n", + "\n", + " iteration_slider = ipywidgets.IntSlider(\n", + " min=0, max=len(computed_utility_history) - 1, step=1, value=0\n", + " )\n", + " w = ipywidgets.interactive(plot_grid_step, iteration=iteration_slider)\n", + " display(w)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid_step(1000)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## HIGHWAY EXAMPLE" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if False:\n", + " # we will change this to true later on, to see the effect\n", + " HIGHWAY_MDP_DICT[\"transition_probabilities_per_action\"][LC_RIGHT_ACTION] = [\n", + " (0.4, LC_RIGHT_ACTION),\n", + " (0.6, STAY_IN_LANE_ACTION),\n", + " ]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "highway_mdp = GridMDP(**HIGHWAY_MDP_DICT)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "utility_history_highway = q_learning(\n", + " mdp=highway_mdp, alpha=0.1, epsilon=0.1, 
iterations=10000, return_history=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid_step_highway = make_plot_grid_step_function(\n", + " columns=10, rows=4, U_over_time=utility_history_highway\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "if mkdocs_flag:\n", + " iteration_slider = ipywidgets.IntSlider(\n", + " min=0, max=len(utility_history_highway) - 1, step=1, value=0\n", + " )\n", + " w = ipywidgets.interactive(plot_grid_step_highway, iteration=iteration_slider)\n", + " display(w)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid_step_highway(1000)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "policy_array = [\n", + " derive_policy(highway_mdp, utility) for utility in utility_history_highway\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_policy_step_highway = make_plot_policy_step_function(\n", + " columns=10, rows=4, policy_over_time=policy_array\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if mkdocs_flag:\n", + " iteration_slider = ipywidgets.IntSlider(\n", + " min=0, max=len(utility_history_highway) - 1, step=1, value=0\n", + " )\n", + " w = ipywidgets.interactive(plot_policy_step_highway, iteration=iteration_slider)\n", + " display(w)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_policy_step_highway(1000)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.12" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/mdp.ipynb b/notebooks/mdp_value_iteration.ipynb similarity index 55% rename from notebooks/mdp.ipynb rename to notebooks/mdp_value_iteration.ipynb index d5845f6..fc91284 100644 --- a/notebooks/mdp.ipynb +++ b/notebooks/mdp_value_iteration.ipynb @@ -16,101 +16,13 @@ " HIGHWAY_MDP_DICT,\n", " LC_RIGHT_ACTION,\n", " STAY_IN_LANE_ACTION,\n", + ")\n", + "from behavior_generation_lecture_python.utils.grid_plotting import (\n", + " make_plot_grid_step_function,\n", + " make_plot_policy_step_function,\n", ")" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# From https://github.com/aimacode/aima-python\n", - "\"\"\"\n", - "The MIT License (MIT)\n", - "\n", - "Copyright (c) 2016 aima-python contributors\n", - "\n", - "Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the \"Software\"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:\n", - "\n", - "The above copyright notice and this permission notice shall be 
included in all copies or substantial portions of the Software.\n", - "\n", - "THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.\n", - "\"\"\"\n", - "from collections import defaultdict\n", - "import matplotlib.pyplot as plt\n", - "\n", - "\n", - "def make_plot_grid_step_function(columns, rows, U_over_time, show=True):\n", - " \"\"\"ipywidgets interactive function supports single parameter as input.\n", - " This function creates and return such a function by taking as input\n", - " other parameters.\"\"\"\n", - "\n", - " def plot_grid_step(iteration):\n", - " data = U_over_time[iteration]\n", - " data = defaultdict(lambda: 0, data)\n", - " grid = []\n", - " for row in range(rows):\n", - " current_row = []\n", - " for column in range(columns):\n", - " current_row.append(data[(column, row)])\n", - " grid.append(current_row)\n", - " grid.reverse() # output like book\n", - " grid = [[-200 if y is None else y for y in x] for x in grid]\n", - " fig = plt.imshow(grid, cmap=plt.cm.bwr, interpolation=\"nearest\")\n", - "\n", - " plt.axis(\"off\")\n", - " fig.axes.get_xaxis().set_visible(False)\n", - " fig.axes.get_yaxis().set_visible(False)\n", - "\n", - " for col in range(len(grid)):\n", - " for row in range(len(grid[0])):\n", - " magic = grid[col][row]\n", - " fig.axes.text(\n", - " row, col, \"{0:.2f}\".format(magic), va=\"center\", ha=\"center\"\n", - " )\n", - " if show:\n", - " plt.show()\n", - "\n", - " return plot_grid_step" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "\n", - "\n", - "def make_plot_policy_step_function(columns, rows, policy_over_time, show=True):\n", - " def plot_grid_step(iteration):\n", - " data = policy_over_time[iteration]\n", - " for row in range(rows):\n", - " for col in range(columns):\n", - " if not (col, row) in data:\n", - " continue\n", - " x = col + 0.5\n", - " y = row + 0.5\n", - " if data[(col, row)] is None:\n", - " plt.scatter([x], [y], color=\"black\")\n", - " continue\n", - " dx = data[(col, row)][0]\n", - " dy = data[(col, row)][1]\n", - " scaling = np.sqrt(dx**2.0 + dy**2.0) * 2.5\n", - " dx /= scaling\n", - " dy /= scaling\n", - " plt.arrow(x, y, dx, dy)\n", - " plt.axis(\"equal\")\n", - " plt.xlim([0, columns])\n", - " plt.ylim([0, rows])\n", - " if show:\n", - " plt.show()\n", - "\n", - " return plot_grid_step" ] }, { "cell_type": "markdown", "metadata": {}, diff --git a/src/behavior_generation_lecture_python/mdp/mdp.py b/src/behavior_generation_lecture_python/mdp/mdp.py index e36859f..d2d3008 100644 --- a/src/behavior_generation_lecture_python/mdp/mdp.py +++ b/src/behavior_generation_lecture_python/mdp/mdp.py @@ -137,6 +137,19 @@ def get_transitions_with_probabilities( return [(0.0, state)] return self.transition_probabilities[(state, action)] + def sample_next_state(self, state, action) -> Any: + """Randomly sample the next state given the current state and taken action.""" + if self.is_terminal(state): + raise ValueError("No next state for terminal states.") + if action is None: + raise ValueError("Action must not be None.") +
prob_per_transition = self.get_transitions_with_probabilities(state, action) + num_actions = len(prob_per_transition) + choice = np.random.choice( + num_actions, p=[ppa[0] for ppa in prob_per_transition] + ) + return prob_per_transition[choice][1] + class GridMDP(MDP): def __init__( @@ -333,3 +346,137 @@ def value_iteration( return utility_history return utility raise RuntimeError(f"Did not converge in {max_iterations} iterations") + + +def best_action_from_q_table( + *, state: Any, available_actions: Set[Any], q_table: Dict[Tuple[Any, Any], float] +) -> Any: + """Derive the best action from a Q table. + + Args: + state: The state in which to take an action. + available_actions: Set of available actions. + q_table: The Q table, mapping from state-action pair to value estimate. + + Returns: + The best action according to the Q table. + """ + available_actions = list(available_actions) + values = np.array([q_table[(state, action)] for action in available_actions]) + action = available_actions[np.argmax(values)] + return action + + +def random_action(available_actions: Set[Any]) -> Any: + """Derive a random action from the set of available actions. + + Args: + available_actions: Set of available actions. + + Returns: + A random action. + """ + available_actions = list(available_actions) + num_actions = len(available_actions) + choice = np.random.choice(num_actions) + return available_actions[choice] + + +def greedy_value_estimate_for_state( + *, q_table: Dict[Tuple[Any, Any], float], state: Any +) -> float: + """Compute the greedy (best possible) value estimate for a state from the Q table. + + Args: + state: The state for which to estimate the value, when being greedy. + q_table: The Q table, mapping from state-action pair to value estimate. + + Returns: + The value based on the greedy estimate. + """ + available_actions = [ + state_action[1] for state_action in q_table.keys() if state_action[0] == state + ] + return max([q_table[(state, action)] for action in available_actions]) + + +def q_learning( + *, + mdp: MDP, + alpha: float, + epsilon: float, + iterations: int, + return_history: Optional[bool] = False, +) -> Dict[Tuple[Any, Any], float]: + """Derive a value estimate for state-action pairs by means of Q learning. + + Args: + mdp: The underlying MDP. + alpha: Learning rate. + epsilon: Exploration-exploitation threshold. A random action is taken with + probability epsilon, the best action otherwise. + iterations: Number of iterations. + return_history: Whether to return the whole history of value estimates + instead of just the final estimate. + + Returns: + The final value estimate, if return_history is false. The + history of value estimates as list, if return_history is true. 
+ """ + q_table = {} + for state in mdp.get_states(): + for action in mdp.get_actions(state): + q_table[(state, action)] = mdp.get_reward(state) + q_table_history = [] + state = mdp.initial_state + + np.random.seed(1337) + + for _ in range(iterations): + + # available actions: + avail_actions = mdp.get_actions(state) + + # choose action (exploration-exploitation trade-off) + rand = np.random.random() + if rand < (1 - epsilon): + chosen_action = best_action_from_q_table( + state=state, available_actions=avail_actions, q_table=q_table + ) + else: + chosen_action = random_action(avail_actions) + + # interact with environment + next_state = mdp.sample_next_state(state, chosen_action) + + # update Q table + greedy_value_estimate_next_state = greedy_value_estimate_for_state( + q_table=q_table, state=next_state + ) + q_table[(state, chosen_action)] = (1 - alpha) * q_table[ + (state, chosen_action) + ] + alpha * (mdp.get_reward(state) + greedy_value_estimate_next_state) + + if return_history: + q_table_history.append(q_table.copy()) + + if mdp.is_terminal(next_state): + state = mdp.initial_state # restart + else: + state = next_state # continue + + if return_history: + utility_history = [] + for q_tab in q_table_history: + utility_history.append( + { + state: greedy_value_estimate_for_state(q_table=q_tab, state=state) + for state in mdp.get_states() + } + ) + return utility_history + + return { + state: greedy_value_estimate_for_state(q_table=q_table, state=state) + for state in mdp.get_states() + } diff --git a/src/behavior_generation_lecture_python/utils/grid_plotting.py b/src/behavior_generation_lecture_python/utils/grid_plotting.py new file mode 100644 index 0000000..023c378 --- /dev/null +++ b/src/behavior_generation_lecture_python/utils/grid_plotting.py @@ -0,0 +1,80 @@ +from collections import defaultdict +import matplotlib.pyplot as plt +import numpy as np + + +def make_plot_grid_step_function(columns, rows, U_over_time, show=True): + """ipywidgets interactive function supports single parameter as input. + + This function creates and return such a function by taking as input + other parameters. + + from https://github.com/aimacode/aima-python + + The MIT License (MIT) + + Copyright (c) 2016 aima-python contributors + + Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ """ + + def plot_grid_step(iteration): + data = U_over_time[iteration] + data = defaultdict(lambda: 0, data) + grid = [] + for row in range(rows): + current_row = [] + for column in range(columns): + current_row.append(data[(column, row)]) + grid.append(current_row) + grid.reverse() # output like book + grid = [[-200 if y is None else y for y in x] for x in grid] + fig = plt.imshow(grid, cmap=plt.cm.bwr, interpolation="nearest") + + plt.axis("off") + fig.axes.get_xaxis().set_visible(False) + fig.axes.get_yaxis().set_visible(False) + + for col in range(len(grid)): + for row in range(len(grid[0])): + magic = grid[col][row] + fig.axes.text( + row, col, "{0:.2f}".format(magic), va="center", ha="center" + ) + if show: + plt.show() + + return plot_grid_step + + +def make_plot_policy_step_function(columns, rows, policy_over_time, show=True): + """Create a function that allows plotting a policy over time.""" + + def plot_grid_step(iteration): + data = policy_over_time[iteration] + for row in range(rows): + for col in range(columns): + if not (col, row) in data: + continue + x = col + 0.5 + y = row + 0.5 + if data[(col, row)] is None: + plt.scatter([x], [y], color="black") + continue + dx = data[(col, row)][0] + dy = data[(col, row)][1] + scaling = np.sqrt(dx**2.0 + dy**2.0) * 2.5 + dx /= scaling + dy /= scaling + plt.arrow(x, y, dx, dy) + plt.axis("equal") + plt.xlim([0, columns]) + plt.ylim([0, rows]) + if show: + plt.show() + + return plot_grid_step diff --git a/tests/test_mdp.py b/tests/test_mdp.py index 49f2d6c..1b6189b 100644 --- a/tests/test_mdp.py +++ b/tests/test_mdp.py @@ -1,3 +1,5 @@ +import pytest + from behavior_generation_lecture_python.mdp.mdp import ( GRID_MDP_DICT, MDP, @@ -6,6 +8,10 @@ derive_policy, expected_utility_of_action, value_iteration, + best_action_from_q_table, + random_action, + greedy_value_estimate_for_state, + q_learning, ) @@ -95,3 +101,37 @@ def test_value_iteration_history(): for state in true_utility_1.keys(): assert abs(true_utility_1[state] - computed_utility_history[1][state]) < epsilon + + +def test_best_action_from_q_table(): + q_table = {("A", 1): 0.5, ("A", 2): 0.6, ("B", 1): 0.7, ("B", 2): 0.8} + avail_actions = {1, 2} + assert ( + best_action_from_q_table( + state="A", available_actions=avail_actions, q_table=q_table + ) + == 2 + ) + + +def test_random_action(): + avail_actions = {1, 2} + for _ in range(10): + assert random_action(available_actions=avail_actions) in avail_actions + + +def test_greedy_value_estimate_for_state(): + q_table = {("A", 1): 0.5, ("A", 2): 0.6, ("B", 1): 0.7, ("B", 2): 0.8} + assert greedy_value_estimate_for_state(q_table=q_table, state="A") == 0.6 + assert greedy_value_estimate_for_state(q_table=q_table, state="B") == 0.8 + + +@pytest.mark.parametrize("return_history", (True, False)) +def test_q_learning(return_history): + assert q_learning( + mdp=GridMDP(**GRID_MDP_DICT), + alpha=0.1, + epsilon=0.1, + iterations=10000, + return_history=return_history, + ) diff --git a/tests/utils/test_grid_plotting.py b/tests/utils/test_grid_plotting.py new file mode 100644 index 0000000..78a3fdc --- /dev/null +++ b/tests/utils/test_grid_plotting.py @@ -0,0 +1,47 @@ +import matplotlib + +from behavior_generation_lecture_python.utils.grid_plotting import ( + make_plot_grid_step_function, + make_plot_policy_step_function, +) +from behavior_generation_lecture_python.mdp.mdp import ( + GRID_MDP_DICT, + GridMDP, + derive_policy, +) + +TRUE_UTILITY_GRID_MDP = { + (0, 0): 0.705, + (0, 1): 0.762, + (0, 2): 0.812, + (1, 0): 0.655, + (1, 2): 0.868, + 
(2, 0): 0.611, + (2, 1): 0.660, + (2, 2): 0.918, + (3, 0): 0.388, + (3, 1): -1.0, + (3, 2): 1.0, +} + + +def test_make_plot_grid_step_function(): + matplotlib.use("Agg") + + plot_grid_step = make_plot_grid_step_function( + columns=4, rows=3, U_over_time=[TRUE_UTILITY_GRID_MDP] + ) + plot_grid_step(0) + + +def test_make_plot_policy_step_function(): + matplotlib.use("Agg") + + policy_array = [ + derive_policy(GridMDP(**GRID_MDP_DICT), utility) + for utility in [TRUE_UTILITY_GRID_MDP] + ] + plot_policy_step = make_plot_policy_step_function( + columns=4, rows=3, policy_over_time=policy_array + ) + plot_policy_step(0)
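
Usage sketch (not part of the diff itself): the snippet below strings the new pieces together the way the notebooks and tests do — run `q_learning` on the toy grid, derive a greedy policy from the learned per-state value estimates, and render one frame with the relocated plotting helpers. The implemented update is the undiscounted rule `Q(s, a) <- (1 - alpha) * Q(s, a) + alpha * (R(s) + max_a' Q(s', a'))`, using the reward of the current state. All names referenced exist in this change; the headless `Agg` backend and the hyperparameters (copied from the notebook) are incidental choices for a local smoke test, not requirements.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the example also runs without a display

from behavior_generation_lecture_python.mdp.mdp import (
    GRID_MDP_DICT,
    GridMDP,
    derive_policy,
    q_learning,
)
from behavior_generation_lecture_python.utils.grid_plotting import (
    make_plot_grid_step_function,
    make_plot_policy_step_function,
)

# toy 4x3 grid world from the lecture
grid_mdp = GridMDP(**GRID_MDP_DICT)

# final greedy value estimate per state (no history)
utility = q_learning(
    mdp=grid_mdp, alpha=0.1, epsilon=0.1, iterations=10000, return_history=False
)

# greedy policy with respect to the learned value estimates
policy = derive_policy(grid_mdp, utility)

# single-frame rendering, mirroring tests/utils/test_grid_plotting.py
plot_grid_step = make_plot_grid_step_function(columns=4, rows=3, U_over_time=[utility])
plot_grid_step(0)

plot_policy_step = make_plot_policy_step_function(
    columns=4, rows=3, policy_over_time=[policy]
)
plot_policy_step(0)
```

The highway example follows the same pattern with `GridMDP(**HIGHWAY_MDP_DICT)` and `columns=10, rows=4`, as in `notebooks/mdp_q_learning.ipynb`.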