diff --git a/dataset analysis/Dataset analysis.ipynb b/code/Dataset analysis.ipynb similarity index 100% rename from dataset analysis/Dataset analysis.ipynb rename to code/Dataset analysis.ipynb diff --git a/code/main.ipynb b/code/main.ipynb index 10b40a0..0edb2f3 100644 --- a/code/main.ipynb +++ b/code/main.ipynb @@ -1,264 +1,280 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "Fg5GvKa0qXkT" - }, - "source": [ - "# Установка нужных библиотек" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "e51DLLWEqXkW", - "outputId": "27094984-95e4-4301-937e-2f3d1bd7f9b7", - "colab": { - "base_uri": "https://localhost:8080/" - } - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\u001b[33m DEPRECATION: A future pip version will change local packages to be built in-place without first copying to a temporary directory. We recommend you use --use-feature=in-tree-build to test your packages with this new behavior before it becomes the default.\n", - " pip 21.3 will remove support for this functionality. You can find discussion regarding this at https://github.com/pypa/pip/issues/7555.\u001b[0m\n", - " Building wheel for mylib (setup.py) ... \u001b[?25l\u001b[?25hdone\n" - ] - } - ], - "source": [ - "import warnings\n", - "warnings.filterwarnings(\"ignore\")\n", - "\n", - "try:\n", - " import google.colab\n", - " IN_COLAB = True\n", - "except:\n", - " IN_COLAB = False\n", - " \n", - "if IN_COLAB:\n", - " !git clone -qq https://github.com/Intelligent-Systems-Phystech/ProjectTemplate.git /tmp/repo\n", - " !python3 -m pip install -qq /tmp/repo/src/ && rm -rf /tmp/repo" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7P4TWOOmqXkY" - }, - "source": [ - "# Импорт библиотек" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "4EVJmkwOqXkY" - }, - "outputs": [], - "source": [ - "import os\n", - "\n", - "from sklearn.linear_model import LogisticRegression\n", - "import matplotlib.pyplot as plt\n", - "\n", - "from mylib.train import cv_parameters, Trainer, SyntheticBernuliDataset" - ] - }, - { - "cell_type": "markdown", - "source": [ - "# Настройка окружения" - ], - "metadata": { - "id": "stLbGQHDq6lS" - } - }, - { - "cell_type": "code", - "source": [ - "if IN_COLAB:\n", - " figures = '.'\n", - "else:\n", - " figures = '../figures'" - ], - "metadata": { - "id": "0TbwjK9Qq5Pg" - }, - "execution_count": 3, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "f2HeCQ89qXkZ" - }, - "source": [ - "# Работа с данными" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "dJJn3rfVqXka" - }, - "source": [ - "## Генерация синтетической выборки" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "id": "OSQfsmRrqXka" - }, - "outputs": [], - "source": [ - "dataset = SyntheticBernuliDataset(n=10, m=100, seed=42)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "KBgjk1tvqXkb" - }, - "source": [ - "# Эксперимент с логистической регрессией" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "19nb_usNqXkc" - }, - "source": [ - "## Обучение одной модели" - ] - }, - { - "cell_type": "code", - "source": [ - "trainer = Trainer(\n", - " LogisticRegression(penalty='l1', solver='saga', C=1.0),\n", - " dataset.X, dataset.y,\n", - ")\n", - "\n", - "trainer.train()\n", - "print(trainer.eval())" - ], - "metadata": { - "id": "ZMK7mqNQZPXJ", - "outputId": "a95524d6-db85-4a34-9c36-fa2befec2f34", - "colab": { - 
"base_uri": "https://localhost:8080/" - } - }, - "execution_count": 5, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - " precision recall f1-score support\n", - "\n", - " 0 1.00 0.91 0.95 11\n", - " 1 0.93 1.00 0.97 14\n", - "\n", - " accuracy 0.96 25\n", - " macro avg 0.97 0.95 0.96 25\n", - "weighted avg 0.96 0.96 0.96 25\n", - "\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "g1Mq2ylfqXkd" - }, - "source": [ - "## Зависимость весов параметров от параметров регуляризации" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "HvHCcNPwqXkd" - }, - "outputs": [], - "source": [ - "Cs, accuracy, parameters = cv_parameters(dataset.X, dataset.y)" - ] - }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "Fg5GvKa0qXkT" + }, + "source": [ + "# Установка нужных библиотек" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7P4TWOOmqXkY" + }, + "source": [ + "# Импорт библиотек" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "4EVJmkwOqXkY", + "ExecuteTime": { + "end_time": "2024-04-01T17:21:48.882165Z", + "start_time": "2024-04-01T17:21:45.522872Z" + } + }, + "outputs": [ { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "id": "LQL1mX1VqXke", - "outputId": "0868006d-d6c1-4504-9e66-d66af0fb56e9", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 283 - } - }, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/plain": [ - "
" - ], - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYoAAAEKCAYAAAAMzhLIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdd3hUZdrH8e8zM6kz6b0QQidAqIHQiyAgIujacMXuuuoWd9d1++u6Vd1mXVddda1rFxBsVOkt9FBCaAkJ6X0yk8mU5/0jkQ0hCRCTTMD7c125TGbOnHMnXsxvzlOV1hohhBCiNQZvFyCEEKJ7k6AQQgjRJgkKIYQQbZKgEEII0SYJCiGEEG0yebuAjhYZGamTk5O9XYYQQlxUduzYUaq1jmrpuUsuKJKTk8nIyPB2GUIIcVFRSuW09pw0PQkhhGiTBIUQQog2SVAIIYRokwSFEEKINklQCCGEaJMEhRBCiDZJUAghhGiT14JCKdVDKbVGKXVAKbVfKfVAC8copdTTSqkjSqm9SqmRnVVPlaOK5/c8z/6y/Z11CSGEuCh5c8KdC3hQa71TKRUE7FBKrdBaH2hyzBVAv8avdOBfjf/tcAZl4LndzwEwOGJwZ1xCCCEuSl67o9BaF2itdzZ+XwMcBBKaHTYfeF032AKEKqXiOqOeIN8geoX0Yl/pvs44vRBCXLS6RR+FUioZGAFsbfZUAnCyyc95nB0mKKXuUUplKKUySkpK2l1HamQqmaWZyK5/QgjxP14PCqWUBfgQ+JHWuro959Bav6i1TtNap0VFtbim1XlJjUylvK6cfGt+u88hhBCXGq8GhVLKh4aQeEtr/VELh+QDPZr8nNj4WKdIjUoFkOYnIYRowpujnhTwMnBQa/2PVg77GLi1cfTTWKBKa13QWTX1C+uHn9FPgkIIIZrw5qinCcAtwD6l1O7Gx34FJAForZ8HPgXmAEcAG3BHZxbkY/AhJTyFfSUSFEII8RWvBYXWegOgznGMBr7XNRU1SI1K5b2s93B6nPgYfLry0kII0S15vTO7u0mNTMXhdpBdke3tUoQQoluQoGgmNbKhQzuzNNPLlQghRPcgQdFMgiWBML8w9pbs9XYpQgjRLUhQNKOUIjUqVUY+CSFEIwmKFqRGpnK86jg19TXeLkUIIbxOgqIFqZGpaLSsJCuEEEhQtGhI5BAA9hTv8XIlQgjhfRIULQjxCyElPIVNpzZ5uxQhhPA6CYpG2qNxFttwW+sBmJgwkT0le6iub9c6hUIIccmQoGjkrnZQ9I8d2PeVAjApcRJu7Wbzqc1erkwIIbxLgqKRMcQPg9mH+jwr0NChHeQbxIb8DV6uTAghvEuCopFSCp8EC878hiGxJoOJ8fHj2ZC/AY/2eLk6IYTwHgmKJnwTLDiLbWinG2jopyi1l5JVnuXlyoQQwnskKJrwTbCAB+oLaoGGoACk+UkI8Y0mQdGET6IFAGd+Qz9FZEAkKeEpEhRCiG80CYomGjq0Tac7tOF/w2SrHFVerEwIIbxHgqKJhg7toNN3FNBkmGyBDJMVQnwzSVA009ChXXu6Qzs1MpUQvxDW5K7xcmVCCOEdEhTNNO/QNhlMzOw5kzUn12Bz2rxcnRBCdD0Jimaad2gDzO09F7vLzqrcVd4qSwghvEaCopmWOrSHRw8nwZLAsmPLvFiZEEJ4hwRFMy11aBuUgTm95rClYAslthIvVieEEF1PgqIFzTu0Aeb2mYtHe/js+GderEwIIbqeBEULmndoA/QO6c3giMHS/CSE+MaRoGhBSx3a0NCpfbD8IEcrj3qjLCGE8AqvBoVS6hWlVLFSKrOV56cqpaqUUrsbvx7uirqMIX4Yg31xHDtzNvbsXrMxKiNLjy7tijKEEKJb8PYdxavA7HMcs15rPbzx6/ddUBNKKfz6hVF3pBLt0acfjwyIZFLiJBYdWYTD7eiKUoQQwuu8GhRa63VAuTdraI1//1C03XVW89PNKTdTXlcundpCiG8Mb99RnI9xSqk9SqnPlFKDWzpAKXWPUipDKZVRUtIxw1f9+oaBgrrDFWc8nh6bTt/Qvrx18C201q28WgghLh3dPSh2Aj211sOAZ4DFLR2ktX5Ra52mtU6LiorqkAsbzT74xFuoyz4zKJRSfDvl2xwqP8SOoh0dci0hhOjOunVQaK2rtdbWxu8/BXyUUpFddX3/fmHU51bjqXOd8fjc3nMJ8QvhrYNvdVUpQgjhNd06KJRSsUop1fj9GBrqLeuq6/v1CwUPOI6eOfopwBTAtf2uZfXJ1eRb87uqHCGE8ApvD499G9gMDFBK5Sml7lJK3auUurfxkOuATKXUHuBpYIHuwo4Bv57BKF/DWc1PADcNvAmF4u2Db3dVOUII4RUmb15ca33TOZ5/Fni2i8qhzu1BAwHGhvxUJgN+vUNxtBAUseZYZibP5L3D73HHkDuICIjoqjKFEKJLdeump66UY3eQsmEfi4vPDAW/fqG4yupwldnPes19w+7D4XbwcubLXVWmEEJ0OQmKRkn+vgSZjHxZXnPG4/79wgCoy6486zW9Qnoxr8883j30LoW1hV1SpxBCdDUJiq84nUx0O/iytAp3k24QU1QAxnB/7Ada7kO/d9i9ePDw4t4Xu6pSIYToUhIUjVxlZQx+8Z9UeTS7q/+35alSisChUTiOVOC21p/1ugRLAtf2u5ZF2Ys4WXOyK0sWQoguIUHRyCcujnH2GpTWrGnW/BQ4PAo8YN9X2uJr7xl6D0aDkX/t/ldXlCqEEF1KgqKJ+JEjGJh7jDWlZ86b8Ik14xMbiG13y8uDRAdGszBlIUuPLWV38e6uKFUIIbqMBEUT5gnjGZ25m11WGxXOM2djBwyPpj6nGld5XYuvvWfoPcQExvCnrX/C7XG3eIwQQlyMJCiaMI8ezZis/XhQrKto1vw0tGENKduelu8qAn0C+enon3Ko/BDvHX6v02sVQoiuIkHRhMFsZkRYEEF1dtaUnRkUpnB/fHsGY9td3OrrZ/WcRXpcOs/seoYye5etNCKEEJ1KgqKZkPHjGbl/D2tKq85aRjxweBSuIhvOwtoWX6uU4ldjfoXdaeeJHU90RblCCNHpJCiaMU+YwOgDeyhyuTlUe2Z/REBqJBigdlfrdxW9Q3tz2+DbWHJ0CZvyN3V2uUII0ekkKJrxH5TC2JPHAVhRVn3Gc0aLL/4DI7BlFKGdnlbPcd/w++gV0otHNj+Ctd7a6nFCCHExkKBoRhmNJA9JYUjuMd4rLD+r+ckyPh5PrRPbntbvKvyMfvxhwh8oshXxjx3/6OyShRCiU0lQtMAyYQJXfLmCIzYHGU1maQP49QnBFBOIdeOpNrdCHRY1jFsH3cr7h99n86nNnV2yEEJ0GgmKFpjHj2fqzi0EeNy8XXDm6CWlFJYJ8TgLaqk/XtXKGRp8b/j3SA5O5rebfkt1fXWbxwohRHclQdECn/h4wvv0ZvqhfSwprqTWdeYEusDh0RgCTVg3nmrzPP4mf/488c+U2Er43abf
tXkHIoQQ3ZUERSuC513FrE8+otbtYUnJmUuMG3yNmMfEYj9Q1upM7a+kRqXyvRHfY3nOchYdWdSZJQshRKeQoGhF8Jw5DD5xlOQ6G+8UlJ/1vHlsPCiwbm77rgLgziF3kh6bzmPbHuNY1bHOKFcIITqNBEUrfKKjsYwbx5yNq9lWVUt2szkVplA/AlKjqN1a0OLy400ZlIE/T/ozfkY/fr7u5zjcjs4sXQghOpQERRtC5l3F9OXLMKF549TZS3IEz0hCOz3UrM0757miA6P544Q/cqj8EH/d/tfOKFcIITqFBEUbgmbMIMJZz6yiPN44VUZJvfOM532iAgkcEY11cwHu6nPfJUzpMYU7Bt/Bu1nv8umxTzurbCGE6FASFG0wmM0EXXYZN73xEg6Ph+dPnr1ybPD0JPB4qF5zfrvb/WDkDxgRPYLfbf4dx6uOd3TJQgjR4SQoziFk3lUkHD3MldrBf/JLKas/c58KU0QA5lGx1G4rxFXZ9ggoAB+DD3+Z/Bf8jH48uPZBbE7bOV8jhBDeJEFxDubx4zFGRrLws4+wuz28cPLspTuCpvcAoGb1+d1VxJpjeWzSYxytPMpvNv4Gj2593SghhPA2CYpzUD4+hC1YQPTSj5kbaOLl/FLKm+1+Zwr1xzI2jtrthdTnn98igOMTxvOTUT9hRc4KXtjzQmeULoQQHUKC4jyEffsmlK8vt679glq3h+dzz76rCJ7RE0OgD5UfH0V7zm8G9q2DbmVen3k8t+c5VuSs6OiyhRCiQ3g1KJRSryilipVSma08r5RSTyuljiil9iqlRnZ1jQCm8HBC5s8n6q03uCY0kBfySjhuO3OUkyHARMgVydTnVGNrY7+KppRSPDzuYYZGDeXXG35NZmmLfwYhhPAqb99RvArMbuP5K4B+jV/3AP/qgppaFH7brWiHgx9sX4uPUvwqO+/sHfBGxuDbI4iqz47jqXO1cqYz+Rn9eGraU4T7h/O9Vd8jpzqnM8oXQoh282pQaK3XAWevj/E/84HXdYMtQKhSKq5rqjuTX9++mCdPwvTaazzUI4o15TV8Wnrm6rHKoAid3wdPrZPqFef/hh8ZEMnzM55Ha813V3yXUntpR5cvhBDt5u07inNJAJoOJcprfOwMSql7lFIZSqmMkpKz5zp0lIg77sBdVsa1OzcxyOzPw9n51LrPXFnWNzEI85hYrJtO4cg5/6XFk0OSeXb6s5TXlXP/yvupqa/p6PKFEKJduntQnBet9Yta6zStdVpUVFSnXSdw7Fj8UlKofOFF/tw7lnyHkydOFJ11XMgVvTCG+FHx/mE89e4WztSyoVFD+duUv5Fdkc39K++XORZCiG6huwdFPtCjyc+JjY95hVKKqAd+iDM3lwErP2dBbDjPnyzmoNV+xnEGfxNh1/XHVWqn+osTF3SNyYmTeXzy4+wr3cf3V38fu8t+7hcJIUQn6u5B8TFwa+Pop7FAlda6wJsFWaZMIWDUKEr++U9+Ex9GsMnIQ1kn8TTr2PbvG4p5XFxDE9SxtnfCa25m8kz+NPFPZBRm8MPVP6TOde4Z30II0Vm8PTz2bWAzMEAplaeUukspda9S6t7GQz4FjgFHgH8D93up1NOUUkQ/+BPcJaXwzts80jeBjGpbi6vLhszuhTHMn/IPDuOxn98oqK9c2ftK/jDhD2wt2MoPV/9Q7iyEEF6jLrXtOdPS0nRGRkanX+fkffdjy8igz/IvWHCijL1WG+vHpBDj53PGcY6cakpe2Iv/wHAibklBKXVB11lyZAn/t/H/GB07mmcue4ZAn8CO/DWEEAIApdQOrXVaS89196anbivqRz/CY7VS9u+X+MuAHjg8ml+3MLfCr2cwIXN6UXegDOv6C+9emd93Pn+e9GcyijK4b+V91DprO+pXEEKI8yJB0U7+A/oTcs01lL/+OvEnT/DT5FiWlVTxYVHFWcdaJsQTMCSCqs+P4zh+Yf0VAHN7z+XxyY+zp2QP31n+HaocF34OIYRoLwmKryH6oZ9iDA6m4Df/x30JEaSHmPnl4TxO1p25NapSirDr+mMKD6Dsv4dwVV74Vqizk2fzxNQnyCrP4vbPb6fE1nnzRYQQoikJiq/BFBZGzK9+Rd2+fVS/9RZPpyShgR8cyMHdrAnK4G8iYmEKut5N6X8yz3uJj6amJU3juRnPkW/N59bPbuVkzfktay6EEF+HBEUjm83G6tWrKSi4sNG3wVfOwTxlMsVPPU1cWSl/6pfIlqpanmthhVmfWDMRt6TgKrFT9uZBtOvC96FIj0vnpZkvUV1fzcJPF7K/dP8Fn0MIIS6EBEUjg8HAhg0b2Ldv3wW9TilF3G9/C0Dhww9zfXQIc6NCeOx4AZsqzt6bwr9vGGHX9sNxpJKKj7LP6vw+H0OjhvLGnDcIMAVwxxd3sPbk2gs+hxBCnC8Jikb+/v707t2bgwcPXvCbt098PDE/e4jaTZuo+M9/+MfAJJL9/bhn/wlONeuvADCPiiF4RhK2ncVUfXq8XWHRO6Q3b855k14hvfjhmh/y34P/bdd5hBDiXCQomkhJSaGiooKiorPXbzqX0BtvJGjmTIqffAqfzH38J7UXdo+HuzJPUOc+u4kpaHpSw8zt9fnnvYVqc5EBkfxn1n+YnDCZR7c9yu+3/B6n29mucwkhRGskKJoYMGAAAAcPHrzg1yqliPvjH/CJiSH/wZ/Sx13PMylJ7Kqx8csW5lcopQi9qg+BI6OpXpFDzcb2LWEV6BPIk9Oe5O7Uu/ng8AfcvfxuyuxnzxIXQoj2kqBowmKxkJSUxKFDh9r1emNwMAl//xvOoiIKfv0brogM4cc9Y3i7oJync87u3FYGRdi1/fEfHEHV0mNYN59q33UNRh4Y+QCPT3qc/WX7uXHZjewu3t2ucwkhRHMSFI201hw4VU1S734UFRVRVta+T+UBw4cT/ZOfULNiBWXPP89DvWK5NiaMR48X8F7h2Xs0KaMi4qaB+A+KoHLJUWrW5bX7d5jTew5vXPEGPgYf7vj8Dt488Kb0WwghvjYJikZ5FXbmPL2eY85QgHbfVQCE33E7IfPnUfLU01hXruSJgT2YGGrhJ4dyWVt+9oZEymQg4uaBBAyNpOrT41Svym33tVMiUnj3qneZlDiJx7c/zoNrH5SZ3EKIr0WColGP8ED6RJlZn2sjNja2Xf0UX1FKEfv73+M/dCinfv4LPNnZvJLai36B/tyZeZyd1Wev16SMBsJvHEjgiIY+i8plx9Ce9t0NBPsG89S0p3hw1IOsyV3DdUuvY0fRjnb/PkKIbzYJiiamDYhm67Fy+vYfQF5eHtXV57+VaXMGPz8Sn30Go8XCyfvuI6C0hP8O602kj4mb9hwjs+bs3euUURF2ff+G0VAb8ql4L6tdk/KgIaxuH3I7b855E1+DL3d+cSdP73xaRkUJIS6YBEUT0wZGU+/2UBsQC7Rv9FNTPtHRJP7rOTxV1eTedTdRtlreH94Hs9HADXuOklV79oZEyqAIndeH4Fk9se0uofS1/XgcF77cx1cGRw7mvave46reV/Hvff/mpk9uIqs
86+v8WkKIbxgJiiZGJ4dj9jWytcBJXFwc27Ztw+Np3yf6rwQMHkziv57DmZfHybu/Q4LbyQfD+2JSiht2H+FwS2GhFMHTkhpmcB+tpPi5PbjK2r9xkdnHzB8n/pGnpz1Nqb2UBZ8s4IU9L8jdhRDivEhQNOFrMjCxXyRfZpUwbtw4ysrKyM7O/trnNY8ZQ8KTT1CXlUXevffR0+PiveF98ADX7DrCfmvLIWAeHUvkHUNwV9dT/M/d1B2t/Fp1TEuaxqL5i5iRNINndz/LDctuYG/J3q91TiHEpU+CoplpA6I5VVWHT0QSISEhbNq0qUPOGzRtGvGPP4Zt1y5yb7+dvg47i0f0xc+guHbXEXZVn91nAeDfL4zo7w3HYPGh9OV91Kw/e/LehQjzD+OvU/7KM5c9Q019DQs/XcijWx+lpv7s0VhCCAESFGeZOiAagC+zyxg7diw5OTnk5bV/bkNTIVdeSeIzz+DIzibn5oUkVZazaERfgk1Grt99hA0VLb9Z+0QGEH3/cPwHRlD1yXHKXj+Ax/b1mo2m9pjK4vmLWTBwAW8fept5i+ex7NgymXchhDjLOYNCKeXfFYV0F7Eh/gyKC2ZNVjEjR47Ez8+PzZs3d9j5gy6bRtLLL+EqLeXETd8mJvcEi0f0JcHfl5v2HOODFiblQeN+FrekEDK3N3WHKyh6eheOE19vfoTF18Kv0n/F23PfJs4cxy/X/5I7v7iTQ+Xtn0MihLj0nM8dxTal1N+VUn07vZpuYtrAKHbkVFDnMTBq1CgOHDhARcXZW5y2V2BaGj3ffAOAnG/fTPD2rXw8oi+jQ8x8/2AuT+cUtfjJXilF0MQEou8dBgZFyQt7qVp+At3CooMXYnDEYN6c8yYPj3uYI5VHuGHpDTyy6RFK7aVf67xCiEvD+QTFcOBL4Aml1CdKqblKKdW5ZXnXtAHRuD2a9dklpKeno5Riw4YNHXoN/wEDSH73HXySkjh57314PvyAt4f15proUP58rIAHDuW2uOosgG+PIGJ+OILAkTHUrD5J8b/24CxuuY/jfBmUgev7X88n3/qEhYMWsuTIEuYumsuLe1/E7mr/iCshxMXvfIIiFNgP/A74CPgLcKwzi/K2EUlhRFr8WLzrFCEhIYwePZqdO3de8O535+ITG0vPN97AMnEihY/8jvLf/Z5n+sbxYHIM7xVW8K3dRyh0tNwXYfA3EX59f8JvTsFdXkfR07uoWZfX7tncXwn2DeZno3/GR/M/Ij02nWd2PcPcj+by4eEPcXnaP59DCHHxOp+gKAXeAG4A4oEXgT90ZlHeZjQobkhLZPWhIk5V2pk6dSoBAQF89tlnHd7Za7SYSfzns0R85ztUvvsuJxfewo/8NC8PSeZQbR2zMrLYWnn2TnlfCUyNJObHo/DvH0bVp8cpeWEvzpKvd3cB0CukF09d9hSvzX6NWEssj2x+hGuWXMNnxz/Do79eU5cQ4uJyPkGRBhwGUoEDwNNa61c6tapu4KYxSWjgne0nCQgIYMaMGeTm5l7wVqnnQ5lMRD/4ExKeeZr6Y8c4fu11TDmwh09G9iPAaOCaXUf4x4lC3K2ElDHIl4hbUgi7cQDOYhtFT+2kelVuu5f/aGpkzEjevOJNnpr2FCaDiZ+t+xnXL72elTkrJTCE+IY4Z1BorXdqre8AFgJ9gXVKqV91emVe1iM8kMn9onh3ey4ut4fhw4cTHx/P8uXLcTgcnXLN4MsvJ/mD9zFFRXHyu/cS/sxTfDG0F1fHhPGX44Vcv/so+S1srQoNHd3mEdHE/mQUAYMiqF6RQ9HTO3Ec//orxyqluCzpMj6c9yGPT3och9vBj7/8MdctvY4vTnwhgSHEJe58hseuVUplAOuB22jos7iuIy6ulJqtlMpSSh1RSv2ihedvV0qVKKV2N37d3RHXPV83pydRVO1g1aFiDAYDc+bMwWq1smbNmk67pl+vXiS//x5hN99M+WuvUb5wIf/wqefJgT3YXWNjyrZDvJZfiqetu4tvpxBx+2B0vYeSF/ZS/l4W7pqWA+ZCGJSBOb3nsGT+Eh6d9ChOt5Ofrv0pVy+5msVHFuP0yJIgQlyK1Lna3JVSPYFKoFp3YAO9UspIQ5PW5UAesB24SWt9oMkxtwNpWuvvn+9509LSdEZGRofU6HJ7mPj4GvrHBvH6nWMAWLZsGRkZGdxyyy306dOnQ67TmppVqyj49W/w1NYS+f3vU3vzQn565BTrK6yMCzXz9wFJ9A70a/X1nno3NatPUrM+D+VjIPjynljGxqGMHTPP0u1xsyJnBS/te4msiizizHHcOOBGru57NREBER1yDSFE11BK7dBap7X03Pk0PeVoras6MiQajQGOaK2Paa3rgXeA+R18ja/FZDRw4+gerM8uIbesoYN45syZREVFsWjRImprz95XoiMFTZ9O72VLsUybRskTT+C57VZeM9XxjwE92G+1M3XbIf5yvAB7K8NoDb5GQmYnE/OjkfgmBlG19BhFT+2k7nDHzAkxGozM7jWb9696n39O/yeJQYk8ufNJZnwwg4fWPsSWgi3SLCXEJcCbS3gkACeb/JzX+Fhz1yql9iqlPlBK9WjpREqpe5RSGUqpjJKSkg4tcsGYHijg9c0nAPD19eW6667DbrezePHiTl/ywhQZSeLTT5Hw5JM4CwvJuf4Gpv37WdYOjGdudCj/OFHEtO2HWF5a1WotPlGBRN41hIhbBqHdmtJXMil9dT/Ooo4JOqUUkxMn88qsV1gyfwkLBixg46mNfGf5d5jz0Rye3/M8BdaOHVoshOg652x66rQLK3UdMFtrfXfjz7cA6U2bmZRSEYBVa+1QSn0XuFFrfVlb5+3Ipqev/Pjd3XyWWcDah6YRE9ywosnWrVv57LPPmDVrFuPGjevQ67XGXV1NyTPPUvHWWxiDg4n60QNkzpjNr44WkG1zMDUsiEf6xTPQHNDqObTLg3VjPtWrT6Lr3ZhHxxJ8eU+MQb4dWmudq45VuatYlL2IrYVbUSjGxI1hfp/5TE+aTqBPYIdeTwjx9bTV9OTNoBgHPKK1ntX48y8BtNaPtnK8ESjXWoe0dd7OCIqcslqm/30tC8b04I9Xp9JYJ++++y5ZWVksWLCAAQMGdOg121KXlUXRH/6ILSMDv359CXvo53zQawB/O1GI1e3m5rgIHkyOJcbPp9VzuGud1KzKxbqlAGVSWCYkEDQlEYO/qcPrzavJY+nRpSw5uoR8az4BpgBmJM1gbu+5pMelYzQYO/yaQogL012DwkRDZ/Z0IJ+Gzuxva633NzkmTmtd0Pj9NcDPtdZj2zpvZwQFwK8X7ePd7SdZ/eBUkiIaPg07HA5effVVSktLuf3220lIaKnlrHNoralZuZLiv/4NZ24u5vHjMf34xzzrF8rrp0rxUQbu6RHF/T2iCPFp/c3fWWqnekUO9j0lqAATwVMTMY+Lx+Db8W/eHu1hZ9FOlh1bxvITy6lx1hAZEMms5FnM6TWH1MhULvHVYYTotrplUAAopeYATwJG4BWt9Z+UUr8HMrTWHyulHgXmAS6gHLhPa9
3m0qadFRRF1XVM/ssarkyN4x83Dj/9eE1NDS+99BIul4u7776bsLCwDr92Wzz19VT897+U/et53FVVBM+ZQ+393+MJh4FFxZWEmIx8t0cUdydGEWxq/c2/Pt9K9fIT1GVVYLD4EDSlB5axsSifzvm073A7WJe3jk+Pfcq6vHXUe+pJsCQwK3kWs5NnMzB8oISGEF2o2wZFZ+isoAB49NODvLj+GJ8/MJkBsUGnHy8uLuaVV14hMDCQ22+/neDg4E65flvcNTWUvfwy5a+9jq6vJ2T+fArvvJun7ZrPS6sJNRn5TmIUdyZGEtbGHYbjRBXVK3NxHKnEEORD0OQemNNjO+UO4ys19TWsyl3F58c/Z0vBFtzaTc/gnlze83Iu73k5KeEpEhpCdDIJig5SUVvP5L+sYXhSKK/fOeaMN6/c3FzefPNNLBaL18ICwFVaStm/X6LinXfQbjch8+dRcNudPO1QfLV8lkUAACAASURBVFFajdlo4Jb4CO7tEU1sG30YjmOVDYFxrKrhDmNSAuaxcRj8Or4Po6mKugpW5a5i+YnlbCvchlu7SbQkMqPnDKYnTWdo1FAMSvbbEqKjSVB0oNc2neC3H+/nHzcM41sjE894rruEBYCzqJiyf/+byvffRzudBM+ZQ8kdd/GiCmBxUQVGpbg2Jox7k6LaHCXlOFFF9apcHNmVqAATlnFxWCYkYDS3HjIdpbKuktUnV7MiZwVbCrbg8riICohiao+pXJZ0GWNix+Br7NjRWkJ8U0lQdCCPR3P9C5s5WmJl5U+mEGk5c2Z007BYuHAh4eHhnVbL+XCVlFD26qtUvP0O2mbDMmUKtXfexethsbxdUI7do7ksPIjv9ohmcpil1Sae+pM1VK85Sd2BMpSPAfPoWCwTEzCFd80GiDX1NazNW8vq3NVsyN+A3WXH7GNmQvwEpvaYyuTEyYT4tTkgTgjRBgmKDnakuIY5T21g5uAYnv32yLOeP3nyJG+99RZGo5GFCxcSFxfXqfWcD1dFBRVvv03FG2/irqjAf+hQ1G238+HAofynoIJSp4uBZn/uSYzimpgwAlpZ5sNZVEvN2jxse0pAawJSowialIBvYlCLx3cGh9vB1oKtrM5dzdq8tZTaSzEoA8OjhjOlxxQmJ0ymT2gf6dcQ4gJIUHSCZ1Zl8/cVh3nxllHMHBx71vMlJSW88cYb1NXVsWDBAnr37t3pNZ0PT10dVYsWUf7qa9Tn5GCKjcWycCFrps7kpfJaDtTWEWYyclNcBLclRNAzoOW1pFxVDqwb86ndWoh2uPFNDiZoYgL+gyJQhq57g/ZoD/tL97Pm5BrW568/vd93vDmeSYmTmJQwidGxo2WCnxDnIEHRCepdHuY9u4GSGgefPjDp9IztpqqqqnjrrbcoLS1lzpw5pKW1+P/AK7THg/XLtZS/9hq2rVtRfn4EXTWX7Otv4k1jIJ+VVuHRMDU8iFviI7g8IgSfFgLAU+eiNqMI68Z83BUOTJEBWCYlYB4Zg/Lp+k7nwtpC1uWtY0P+BrYUbMHusuNr8GVUzCgmJkxkYuJEegX3krsNIZqRoOgk2UU1zHt2I8N6hPDW3WMxtvBGarfb+fDDDzly5AhpaWnMnj0bk6lzRw5dqLqsLCrefIuqpUvRdXUEDBtG3U03s2TwcN4pqabA4STG18QNseEsiAunT+DZoajdGvv+UmrW5eHMs2Kw+GAZH485Pa5LOr5bUu+uJ6Mogw35G9iQv4HjVccBiDPHMSFhAuPjxzMmdoz0bQiBBEWn+mBHHj99fw8/nN6Pn1zev8VjPB4Pq1atYuPGjSQlJXHdddd5dURUa9xVVVQtXkzFu+9Rf+wYhuBgzFfNY/e8a3hP+bOqvBq3hvQQMzfEhnNVdOhZk/i01jiOVVGzNg/H4QqUj4HAtBiCJiZgimh9dFVXyLfmszF/I5tObWJLwRZqnbUYlIEhEUMYGz+WcXHjGBY1DB+jd4JNCG+SoOhkP31/Dx/uzOONO9OZ2C+y1eP27t3L0qVLMZlMXH311V26PtSF0Fpj276dynffo2bFCnR9Pf5DhlB/3fV8MSKd9yptHLE58DcoZkWGcG1MGFPDg/A1nNnU5CyspWZ9PrbdxeDR+KdENHR8Jwd7venH6XGSWZrJplOb2HRqE5mlmXi0hwBTAGkxaaTHpTM2biz9wvrJvA3xjSBB0cls9S7mP7uRUquDRfdPIDnS3OqxpaWlfPDBBxQWFjJ69GhmzpyJj0/3/QTrrqyk6uOlVH74IY6sLJSPD+bp08mZdw2fxCezuKSKCpebMJORudGhXBkVwvhQyxmh4a6ux7r5FLVbC/DYXPgkWLCMjydwWBTK1D3ehKvrq9lesJ3NBZvZUrCFnOocAML8wkiLTSM9Np3RsaPpFSL9G+LSJEHRBU6U1nLNcxsJC/Tlo/vHExrY+kQwl8vFypUr2bJlC5GRkXzrW98iPj6+C6ttn7qDB6n8aBHVS5firqzEGBZGwJVXsnfWlXwSGMrnZdXY3B6CjAYuiwhmZkQw0yKCCW9cMsRT78a2qxjrxnxcxXYMFh/M6XFY0uMwBneviXOFtYVsKdjC9sLtbC3YSpGtCIBw/3BGxYxiVMwoRkaPpH9Yf1n9VlwSJCi6yPYT5dz8762MSArljbvS8T3Hp+WjR4+yePFiamtrmTJlChMnTsRo7P5vOtrpxLp+A1Uff4x19Wp0fT0+iYn4zr2KzOkzWeUTyPKyakrqXRiAtBAzMyKCmR4RzCBzQ0e4I7sS68Z86rIqwKAISI3EMj4e36SgbveJXWvNyZqTZBRlkFGYQUZRBgW1DRsxmX3MDI8azojoEYyIHkFqVCoBJu/2xQjRHhIUXWjJ7nweeGc314xI4O/XD8NwjjkFdrudTz75hMzMTGJjY5k/f363mKB3vtw1NdSsXEX1smXUbt4MHg/+gwcTNH8eJ6bOYLVbsbK0mr1WOwBxfj5MCw9iangwk8MsWKqcDc1SGUVohxufeDOWcfEEDIvq1IUIv64CawE7i3eys2gnu0p2caTiCBqNSZkYGD6Q4dHDGRkzkhHRI4gMaL3fSojuQoKiiz27Opu/LT/Mt9OT+NPVQ87rE/KBAwf45JNPsNvtTJgwgcmTJ3frvouWuEpLqf7kE6qWfEzdgQNgNGIeP56QeVdhmziZL+tcrCqrZl1FDdUuDwZgZHAgU8KDmGwxM+ColbrNBbiKbKgAE+ZRMZjHxuET2f0/oVc5qthTsoddxbvYVbyLzNJMHG4HAImWREZEj2horooZSXJwcre7axJCgqKLaa356xdZPPflUW4fn8xvrxp0Xm8MNpuN5cuXs3v3bsLDw7nyyivp06dPF1Tc8RzZ2VQtXUb1smU4T51C+fsTdNk0gq+8Er+JE9lT52J1WTVrK2rYXW3DAwQZDYwPtTDOY2Lk4RridlegPBq/fqFY0uPwTwlHtbK0SHfjdDs5UH6A3cW7T4dHeV05ABH+EaTHpTM+fjxj48YSY47xcrVCSFB4hdaaP31ykJc2HOfuib349ZXnv6fCsWPHWLZsGeXl5QwZM
oSZM2d2y3kX50N7PNh37qTqk0+o+fwL3BUVGIKCCJo+neA5V2AeN45KFOsrrKyvqGFdeQ05dfUARJtMjHEaGHGsllH5DnqYTFjSYjGPju2yxQg7itaaE9Un2FG0g+2F29lSsOV0cPQI6nH6jmN8/HhizWcvCSNEZ5Og8BKtNb9beoBXN53gulGJPPqtVHzO8xOx0+lk48aNrF+/HqPRyNSpU0lPT78oOrtbo51OajdvpvrTz6hZtQpPTQ2GkJCG0Jg9C/PYsShfX3LsDjZUWNlYaWVDRQ3F9S4A4lwwqsjJyAoX40IsDBgZS0BKxEVzl9GUR3s4XHGYrQVbG/o5indR4agAoG9oXyYmTGRc/DhGRo/E33RxhaK4OElQeJHWmqdWZfPkymym9I/iuZtHYr6AzX/Kysr4/PPPyc7OJjIyktmzZ9O3b99OrLhreOrrqd2wkZovPqdm1Wo8ViuG4GCCpk0jaNZMzBMmYPDzQ2tNts1xOjQ2lVupcLsBiLF7GFWjGRdmYfLgWPrHe38iX3tprTlWdYwN+RtYn7+eHUU7cHlc+Bh8GBE9gvS4dNLj0hkcMRiToXstASMuDRIU3cDb23L59aJ9DIoP5l83j6JH+IWtZpqVlcXnn39ORUUF/fv3Z+bMmURGXhqjaU6HxvLl1Kxejae6GkNgIObJkwm6fAaWKVMwWiwNx2pNVm0dmyusrD9ZzjabnbLGm6xwF6QF+DM+MYz08CCGWAJaXMjwYmBz2thRtIOtBVvZUrCFrIosoGE4blpMGmNix5Aely4zx0WHkaDoJlYfKuKBd3aDhseuHcqVQy9sGKzL5WLLli2sW7cOl8vF6NGjmTJlCoGBl84S2trppHbrNmpWrKBm1SrcpaUoHx8Cx44laPp0LJdNwyc6+n/Ha012cQ3rMgvZWlbDTrMiP7DhjTNAKUaGmBkTYiYtxExacCAhbewX3p2V15Wfnvy3tWAruTW5QMPM8dGxo0mPS2dM7Bh6Bve8aO+qhHdJUHQjJ8tt/ODtXew+WclNY5J4eO4gAi5wvoDVamX16tXs2rULX19fJk2aRHp6+kU3nPZctNuNffdualauombVKpy5DW+O/kOHEnTZZQRNvwzfvn1PvzFqj8ZxtJLjOwrYWlTF7hAje6J9OBygcDees3+gP6NDAhkVYiYt2EzfQD8MF+Eba4G1gK2FW093jBfbigGICYw5HRrpcenSMS7OmwRFN+N0e/jb8ixeWHuMftEWnr5pBClxFz6qqbi4mBUrVpCdnU1wcDBTpkxh+PDhF3WHd2u01jiys7GuXk3NqtXU7dsHgE9iIpZp0wiaNpXAtDSUb8NSIO5aJ/bdxdRmFFFVXMv+cBP7+wWRGWVil8tJlashOkJNRkYEBzIq2Myo4EBGBAcSepHddWityanOYVvhNrYWNITHVx3jSUFJp+84RseOlsl/olUSFN3U+uwSfvLeHqrsTn55xUBuG5d8zpncLTl+/DgrV64kPz+fiIgIpk2bxqBBgzAYLt22a2dREdY1X2Jds4baLVvQDgcGiwXzxIlYpk7BMnkypvBwtNY4T9VSm1GIbVcJus6FCvWjaGQEB3qZ2eWuJ6PaxuHaOr76l9Av0K+xqcrM2FAzvQP8LqrmHI/2kF2RzbbCbWwr2EZGUQZWpxWAPiF9GBM3hvTYdNJi02QvDnGaBEU3VmZ18LMP9rLqUDET+0by+HVDSQi98JnIWmuysrJYtWoVJSUlxMTEMHXqVAYOHHhRvcm1h8dmo3bzZqxffknNl1/iLikFpQgYOrQhNKZMwS8lBVwa+4FSajOKcBypBA2+vUIwj4rBlRLGXoeDHdW1ZFTb2FFVS0XjXUecnw8TQi1MDLMwMSyIRP/utYDhubg9bg6VH2Jr4Va2FW5jZ9FO7C47CsXA8IGMiR3DmLgxjIoZhdmn9ZWPxaVNgqKb01rz9raT/PGTAxiV4v/mDuL6tMR2vcF7PB4yMzP58ssvKS8vJy4ujilTpjBgwIBLPjCgYYJf3f4DWNeuxbp27ekmKlNUFOYpk7FMmox5/Di02xfbriJsGUW4yupQvgYChkQSOCoGv14hoOCo3cHmSivrKxqG5pY7G4Kjp78vE8MsTAgLYkKohRi/i6tvyOl2klmWeXp13N3Fu3F6nBiVkcERgxkTN4YxsWMYHj1cFjj8Bum2QaGUmg08BRiBl7TWjzV73g94HRgFlAE3aq1PtHXO9gaFrbqKdx7+GR63C7fLhcftxuTrh5/ZjH+gGf+gIAKDQwgIDiEwOITAkDDMIaFYIiIJiojE2AHbm+aW2Xjogz1sPV7OpH6R/Pma1AseRvsVt9vN3r17WbduHRUVFcTGxp4OjEu5Sao5V0kJ1vUbsK5bR+3GjXhqasBkInDECCxTJhM4cRLKLwb7zhJse0vQDjfGMD8CR8ZgHhl9ele+r4blbqxsmEG+udJKtcsDQN9AP8aHWpgQZmF8qIUo34srOOpcdewu2c22gm1sL9xOZmkmLt0wh2No1NCGO47YMQyNGoqv8eK6mxLnr1sGhVLKCBwGLgfygO3ATVrrA02OuR8YqrW+Vym1ALhGa31jW+dtb1DU220sf+EZjCYTBpMJg9GIy+GgzlaLo7YWe0019uoq7NYaaPY3UwYDwZFRhMTEERYbR2hMHCExsYRExxISHYNf4Pnfzns8mje35vD4Z4fwaHhwZn8Wju2Jv0/7OqibB0ZUVBSTJ09m8ODB36jAgIaht/bdu7GuW4d17Tochw8DYIqNxTJpIoHjJ2EMHYj9QNX/mqaSgzGPiiEgNRKD//8+DLi1JtNqZ2OFlY0VVrZWWbG6G4JjgNmf8aENoTEu1EKk78XVOW5z2thZvJNtBdvYWriVg2UH0Wj8jf4Mix52OjhSI1NlL45LSHcNinHAI1rrWY0//xJAa/1ok2O+aDxms1LKBBQCUbqNoju76cnjdmOvqcZWVUltVSU1ZSVU5edQmX+CqpISKssqqLPZz3iNf4A/QaHBBEeYCYp1E2gJJDDIQkCgP7QyWcrqcLIxu5STlTZ8DEZ6hAfSKyKQQL/2/cP0aCi32imoqKXO6cZkVAQF+BLk70ugrwk6oFXKXQ9uN2it4GJo0ayvh6oaqKxq+K/b0/B3MAdgCAolIDCWQEIwaV88uKk3OHAZ6nEqBx7loekv6UaRExjOIUs0WZYYss2ROBr33k6wV9K3tpTIeisR9TZCnTYMF8UfqIFHQ52up95Tj0M7cXkallQxYMDf4IufwQ+jTPrrFszKxPe/+1C7XttWUHjzo04CcLLJz3lAemvHaK1dSqkqIAIo7ZIKW2AwGjGHhmGuLyRqw6OQuxkc1Q1PWhq+7G4TVfX+VDn9qXb6UeXjS31vNz597OCnsQG2r07Y2vuFLwwaDIOaPVz/NWoPDobg7r+RXrfiaPz6igJaa3wZxP/+f7kwckz35SCDOOg/hAz/PtSqi3/pFdG99XId5/udcN6L6564FUqpe4B7AJKSkjr3YrVl8OWfIeM/4GeBoTdAaE8I7QH+IYAiAAgAIj11HKn4gGrrWvxRRAWO
ItY8Fo9TYau1UWevO+/LejyaMpsTt7vjPolqral3axxu97kPPuOF4HSArdpIXY0BrcHHV+MfpDH6aBS6tRuli4fWUOdAnfW3UZjwRelz/4LxQDxuprMH2EOdwUS5r5kqnwA65Baum3BpNxfHLeSlL1Crhsb8DubNoMgHejT5ObHxsZaOyWtsegqhoVP7DFrrF4EXoaHpqVOqBTj8BSy+H+wVkHYnTP0lmCNaPLSicjsHDvyOuro8EhNvoWfSd/D3v7g/ztfVOsnaUsj+DaeoKKjF5GekX1o0gybEE9Pr4l2QTwjRNm8GxXagn1KqFw2BsAD4drNjPgZuAzYD1wGr2+qf6DTOOljxMGx7AWKGwG0fQ8zgFg/1eFwcO/4EOTkvEODfg5Ej3yYsdHQXF9yxio5Xk7k2j+wdxbidHqKTg5l2y0D6jorG1/+SuCkVQrTBa//KG/scvg98QcPw2Fe01vuVUr8HMrTWHwMvA28opY4A5TSESdeylsBb10LBHki/D2Y8Aj4t7w/gcJSQuf8BKiu3Eh9/I/36/hqT6eKcwORyusneXkzm2jyKc2ow+RkZODaWwZMSiEoK8nZ5Qogu5NWPg1rrT4FPmz32cJPv64Dru7qu06ry4fX5UJUHN70DA65o9dDKygz2ZX4fl6uGQSl/Iy7umi4stONYKxxkrs1j/4ZT1FmdhMWZmbygPwPSY/ENkLsHIb6J5F9+a8qPw+vzwF4JtyyCnuNaPbSkZCWZ+3+An18cI4a/hsUyoAsL7RjFOdXsXnmSozuK8WhNr6GRDL2sBwn9Q6XvQYhvOAmKllhL4NUrwWlr6I+IH9HqoQUFizh46OcEWQYzbNjL+PqGd2GhX4/Wmpx9Zexakcup7Ep8/I2kTktk6LREgiNl6QYhRAMJiuY8Hlh0D9jK4K7lEDes1UPz8v9LVtb/ERY6lqFDX8BksnRhoe3n8WiO7ihmx+c5lOVbsYT7MeG6vgyaEC/NS0KIs8i7QnMbn4Sjq2Huk22GRFnZerKyfktkxGUMGfIsRqNfFxbZPh6PJnt7ERmfnqCyyEZYbCAzbk+h7+gYjMaLfeKDEKKzSFA0lbsFVv8RBl8Do25v9TCbLYfM/Q9gMfdjyJCnun1IaI/myM5iti87TkWhjYgEC7O/O4Tew6JQF+me0kKIriNB8RVbOXxwV8MM66ueglY6cF2uWvbuuxdQDB36PEZj992vWmvNyQPlbF58lNKTVsLizMz6zhD6jJCAEEKcPwmKr3jcDZPopv6icSmOs2mtOXjoF9TWHmHE8FcJCOjk5UK+huKcajZ9dIT8rEqCIvyZcXsK/cbEtmsHPSHEN5sExVcsUXDze20eUlzyGcXFn9Kn94OEh0/oosIuTHWpnS1LjpG9vYiAIB8m3diPwZMSMJqkD0II0T4SFOfJ6awgK+sRgoKGkJR0j7fLOUt9nYsdn+Wwe1UuBqUYdUVPRs7sKaOYhBBfm7yLnKfD2X/C5aoiZeBrGAzd58+mPZpDWwrZsvgotup6BqTHMvbq3ljCWl5mRAghLlT3ecfrxkrLvqSwcBHJyd8jKCjF2+WcVpJbw7p3sig8Vk1Mr2CuuC+V2F4t968IIUR7SVCcg8fjICvrYQID+9Ir+XveLgcAh93F1iXHyFybh7/Fh8tuTWHg2FgZySSE6BQSFOdwquBD6uryGT7sVQwG786X0FpzbFcJ6989TG11PamTE0if3xu/QB+v1iWEuLRJULTB43Fw4sRzhISMJDx8oldrsVY4WPt2Fif2lhLZw8IV9w0lJjnYqzUJIb4ZJCjacKrgQxyOAlJSHvPaCqpaaw5uKmDjB0fwuDyM/1Zfhk1PxCBLbgghuogERSvOuJsI886ciZryOta8eYiTB8qJ7xfKtFsGEhrdfWeCCyEuTRIUrfDm3YTWmsPbilj3zmE8Hs3kBf0ZMjlBOquFEF4hQdECj8fltbuJulonX76VxdGdxcT2DmHGHSmERMldhBDCeyQoWlBRsRGHo4D+/f6vS+8mTh2pZMXL+7FV1zP26t6MmNlT1mYSQnidBEULCguXYDKFEBk5tUuu5/Fodnx2gu3LjhMUGcC1PxtFdE8Z0SSE6B4kKJpxuWopLllOXOzVXTJvwm6tZ8XL+zl5sIL+Y2KYctMAWZ9JCNGtyDtSMyUly/F47MTGXt3p1yo8XsUXL2Zir3Ey7ZaBDJoQ3+nXFEKICyVB0Uxh0RL8/RMJCRnZqdfZvz6fde8cxhzqx7ceGilNTUKIbkuCogmHo5jy8o0k97wXpTpnQpvb5WH9u4fZv/4USYPCufyuwfibZQkOIUT3JUHRRFHRMsDTac1Otup6Pn9xHwVHqhg5K4n0+X1kVJMQotuToGiisHAxQUGpmM19OvzcBUer+OLFfThsLi6/axD9R8d2+DWEEKIzeCUolFLhwLtAMnACuEFrXdHCcW5gX+OPuVrreZ1Vk82WQ411P/36/aZDz6u1Zu+aPDZ9cARLhD/X/mAYkYlBHXoNIYToTN66o/gFsEpr/ZhS6heNP/+8hePsWuvhXVFQYGBPxqavwNc3vMPO6XS4WfPmIbK3F5E8NJIZt6fIkuBCiIuOt4JiPjC18fvXgC9pOSi6lNncu8POVVVi47PnMyk7ZSV9fm9GzeopazUJIS5K3gqKGK11QeP3hUBMK8f5K6UyABfwmNZ6cUsHKaXuAe4BSEpK6uhaL1jugTKWv7QfgLnfH0bPwRFerkgIIdqv04JCKbUSaKnH9tdNf9Baa6WUbuU0PbXW+Uqp3sBqpdQ+rfXR5gdprV8EXgRIS0tr7VxdInNtHuvezSY8LpAr7h1KSFSAN8sRQoivrdOCQms9o7XnlFJFSqk4rXWBUioOKG7lHPmN/z2mlPoSGAGcFRTdgcej2fhBNntX59EzNYKZdw3G118GlQkhLn7e2ibtY+C2xu9vA5Y0P0ApFaaU8mv8PhKYABzosgovgKvezecv7GPv6jyGXdaDOfcNlZAQQlwyvPVu9hjwnlLqLiAHuAFAKZUG3Ku1vhtIAV5QSnloCLTHtNbdLijq7S4+eW4vp45UMunGfgyd1sPbJQkhRIfySlBorcuA6S08ngHc3fj9JiC1i0u7ILbqepY9u4eyPCuX3ymT6IQQlyZpH2kna4WDJU/uwlpex5z7h9JziIxsEkJcmiQo2qGmvI7FT+zCXlPPVQ8MJ75vqLdLEkKITiNBcYGqSuwseWIXDruLeQ8MJ7ZXiLdLEkKITiVBcQGsFQ4WP7ETp8PN1T8eQVSSrNkkhLj0eWt47EXHYXex7Nk9OGpdzH9AQkII8c0hQXEe3C4Pn7+wj4qCWmZ/d4iEhBDiG0WC4hy01qx58xB5hyqYdstAkgbJ6CYhxDeLBMU5HNxUQNaWQkbP7cXAcXHeLkcIIbqcBEUbygtqWf/OYRIHhjF6TrK3yxFCCK+QoGiFy+lm+Uv7MfkZmXHHINlLQgjxjSVB0YpNHx2lLN/K9NtSMIf4ebscIYTwGgmKFhSdqGbfmjyGTkskOTXS2+UIIYRXSVA0o7Vm04dHCAjyIX1+x22NKoQ
QFysJimZO7C3l1P+3d/cxclVlHMe/P5aKUrG0SEur0JZatdLoFgqxsZomvqQ1JkBC0aLSqhEbef0DX6PQmJAQVDRRUatWiqDGKCCSKBqCUTC2pWWRtmvpi0Va+oJgU7HFwvbxj3uG3t3O3JmFmbnLzO+TNDv33DNznnPPdJ655+7es3kf53xgqteUMDPDiWKQgYHD/OX2rYw95XhmzJ1UdjhmZiOCE0VO//1PsG/PAeacP42eHh8aMzNwonjBoYPPs/rufzBp+olMeasvYJuZVXgSPnnu0AATp53ImfMnI/lvJszMKpwoktFjjmPB0hG98qqZWSk89WRmZoWcKMzMrJAThZmZFXKiMDOzQk4UZmZWyInCzMwKOVGYmVkhJwozMyukiCg7hqaS9CTw2It46muBfzU5nJeDbuy3+9wd3OfhmRwRJ1fb0XGJ4sWS9GBEzC47jnbrxn67z93BfW4eTz2ZmVkhJwozMyvkRHHE8rIDKEk39tt97g7uc5P4GoWZmRXyGYWZmRVyojAzs0JOFICk+ZI2Sdoi6fNlx9MOkrZLekRSn6QHy46nFSStkLRX0vpc2ThJf5C0Of0cW2aMrVCj38sk7Uzj3Sfp/WXG2EySTpV0n6SNkjZIujKVd+xYF/S5JePc9dcoJPUAjwLvBXYAa4BFEbGx1MBaTNJ2YHZEdOwfJEl6F/AMcEtEzExlNwBPR8T16UvB2Ij4XJlxNluNfi8DnomIr5UZWytImghMab9pSAAABetJREFUjIh1kk4A1gLnAUvo0LEu6POFtGCcfUYB5wBbImJbRBwCfg6cW3JM1gQR8Sfg6SHF5wIr0+OVZP+5OkqNfnesiNgVEevS4/8A/cDr6OCxLuhzSzhRZAf38dz2Dlp4wEeQAH4vaa2kS8oOpo0mRMSu9Hg3MKHMYNrsMkl/S1NTHTMNkydpCjALWEWXjPWQPkMLxtmJonvNjYgzgQXApWm6oqtENu/aLXOv3wWmAb3ALuDr5YbTfJJeDfwKuCoi9uf3depYV+lzS8bZiQJ2Aqfmtl+fyjpaROxMP/cCd5BNwXWDPWl+tzLPu7fkeNoiIvZExEBEHAZ+QIeNt6RRZB+Yt0XE7am4o8e6Wp9bNc5OFNnF6+mSpkp6BfAh4K6SY2opSaPTBTAkjQbeB6wvflbHuAtYnB4vBn5dYixtU/nATM6ng8ZbkoAfAf0RcWNuV8eOda0+t2qcu/63ngDSr5B9E+gBVkTEdSWH1FKSTic7iwA4FvhpJ/ZZ0s+AeWS3Xt4DXAvcCfwCOI3sdvQXRkRHXfit0e95ZNMRAWwHPpWbv39ZkzQX+DPwCHA4FX+RbM6+I8e6oM+LaME4O1GYmVkhTz2ZmVkhJwozMyvkRGFmZoWcKMzMrJAThZmZFXKiMDOzQk4UZmZWyInCmkrSQO5e+H2S/inp22XHZSOPpIWSVqX3yQZJ15Ydk1V3bNkBWMc5GBG9lQ1JS4DZ5YVjI5GkxcDlwHkRsUPS8cAnSg7LavAZhbWNpDvTbc03VG5tLmmKpL9Luk1Sv6Rfpg+NqvVzzwlJS9N2T1rV6+a0/RFJq9M31e+n/V9N27tzK4B9paj9IbEXxVmtvSmSDubOrG5J7V2Ve83rciuT5esPOguT9GVlKzD2pTpTasWTyten542StK3yWpJOlrRG0kOSHpb0zjrjkl8h74Lc8a3aRlHbQ47la4AbyW6psQMgIg5ExLeG946ydnGisHb6eEScRXaGcYWkk1L5m4CbImIGsB/4dJ36AFs4shDNfNKaIpJmAB8E3pHObAaAD0fEZ9L294BvRERvRFxTp/2hjqpXq71Uf2tqpzciLgZWABenOI8huwHlraluD7A5vUYlrsqH6uVAb9q3tSieIfFeQrbSHQAR8WREnB0Rs4DvNHic6xnURgPlkI3bqojYNox2rEROFNZOV0h6GPgr2a3dp6fyxyPigfT4VmBunfoA/wO2SDoD+Cjwk1T+buAsYI2kvrR9ep24arXfSL2G24uI7cBTkmaR3bH3oYh4Ku1+FfBsjXaV9jcct7K7An8MuGnQC0m9kh4Frgcq3/aLjnNNBW1ULc+ZCfQ10oaNDL5GYW0haR7wHmBORByQ9EfglWn30DtTRp36FT8GPkv2Pt5TaQpYGRFfGEZ4R7U/jHpV21O26lg1PyRby/kUsjOMiknAE0c1GLFf0jXANkmPkS1K00jcVwLLgUNDXq8PeKOkRcBFytY0qHeca6naRkF5xX+pnvhshPIZhbXLGODf6cPozcDbc/tOkzQnPb4IuL9OfQAiYi0wnixhVNwLXCBpPICkcZIm14mtWvuN1htue3eQTZWdDdyTK18IPFD1GdmCO7+JiLcxeOqpVtxjyKZ38okISSdI6kmbz5J9s697nGuo2kZBed5vgYWSJqS4jpP0yQbbtRL4jMLa5XfAUkn9wCayaY6KTWTLsa4ANpIt5zhQUP8FEbEAsoutaXujpC+RrQd+DPAccCnZegS1VGu/oXrpA7Zae7trxHtI0n3AvogYSLHfAIwmu24wiKQ3AFeTfetvJO7xZKs0Xh0Rz0vK1z8DWC6psizoZWQLd9U6zlMlVZLPScA4SQuA/oI2apXnj8FqScuAe9IxG8WRazU2Ank9CitVmqK5OyJmjuT2mxVn+mBcByyMiM0v4XWaEs8w21wCEBE3t6tNGxk89WTWJpLeQvbbWve+lCRRonXpn3UZn1GYmVkhn1GYmVkhJwozMyvkRGFmZoWcKMzMrJAThZmZFXKiMDOzQk4UZmZW6P96cM++Ei5TFgAAAABJRU5ErkJggg==\n" - }, - "metadata": { - "needs_background": "light" - } - } - ], - "source": [ - "plt.plot(Cs, parameters)\n", - "\n", - "plt.xlabel('Параметр регуляризации $C$')\n", - "plt.ylabel('$w$')\n", - "\n", - "plt.savefig(\n", - " os.path.join(figures, 'log_reg_cs_exp.eps'),\n", - " bbox_inches='tight')\n", - "\n", - "plt.show()" - ] + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-04-01 20:21:47.470487: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", + "2024-04-01 
20:21:48.253379: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" + ] } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.3" - }, + ], + "source": [ + "import os\n", + "import json\n", + "import glob\n", + "import torch\n", + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from src.mylib.models.models import BaselineModel, MHAModel\n", + "from src.mylib.train import Trainer\n", + "\n", + "file = os.path.abspath('')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "f2HeCQ89qXkZ" + }, + "source": [ + "# Работа с данными" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Начальные параметры" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "outputs": [], + "source": [ + "batch_size = 64\n", + "\n", + "# Длина окна\n", + "window_length_seconds = 5 \n", + "sample_rate = 64\n", + "window_length = window_length_seconds * sample_rate\n", + "\n", + "# Расстояние между двумя окнами\n", + "hop_length_seconds = 1\n", + "hop_length = sample_rate * hop_length_seconds\n", + "\n", + "# Количество ложные стимулов\n", + "number_of_mismatch = 4" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-04-01T17:21:48.887002Z", + "start_time": "2024-04-01T17:21:48.883345Z" + } + }, + "execution_count": 2 + }, + { + "cell_type": "code", + "outputs": [], + "source": [ + "experiment_folder = os.path.dirname(file)\n", + "\n", + "# Load the config file\n", + "with open(os.path.join(experiment_folder, \"src/mylib/utils/config.json\")) as file_path:\n", + " config = json.load(file_path)\n", + "\n", + "# Path to the dataset, which is already split to train, val, test\n", + "data_folder = os.path.join(config[\"dataset_folder\"], config['derivatives_folder'], config[\"split_folder\"])\n", + "\n", + "# Пути к данным тренировочным, валидационным и тестовым данным\n", + "train_files = [x for x in glob.glob(os.path.join(data_folder, \"train_-_*\")) if\n", + " os.path.basename(x).split(\"_-_\")[-1].split(\".\")[0] in [\"eeg\", \"envelope\"]]\n", + "val_files = [x for x in glob.glob(os.path.join(data_folder, \"val_-_*\")) if\n", + " os.path.basename(x).split(\"_-_\")[-1].split(\".\")[0] in [\"eeg\", \"envelope\"]]\n", + "test_files = [x for x in glob.glob(os.path.join(data_folder, \"test_-_*\")) if\n", + " os.path.basename(x).split(\"_-_\")[-1].split(\".\")[0] in [\"eeg\", \"envelope\"]]" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-04-01T17:21:48.897759Z", + "start_time": "2024-04-01T17:21:48.887963Z" + } + }, + "execution_count": 3 + }, + { + "cell_type": "markdown", + "metadata": { + "id": "19nb_usNqXkc" + }, + "source": [ + "## Обучение " + ] + }, + { + "cell_type": "markdown", + "source": [ + "Базовое решение" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "source": [ + "# model = BaselineModel()\n", + "args = {\"window_length\" : window_length, \"hop_length\" : hop_length, \"number_of_mismatch\" : number_of_mismatch, \"batch_size\" : batch_size, \n", + " \"max_files\" : 100}\n", + "for model in [BaselineModel, MHAModel]:\n", + " model = model()\n", + " print(f\"Model: {model.__class__.__name__}\")\n", + " 
\n", + " trainer = Trainer(\n", + " model, train_files, val_files, test_files, args, torch.optim.Adam(model.parameters(), lr=0.001), torch.nn.CrossEntropyLoss(), \n", + " )\n", + "\n", + " trainer.train_model(epochs=5, run_name=f\"{model.__class__.__name__}\")\n", + " print(trainer.test())\n", + " print(\"-----\" * 5)" + ], + "metadata": { + "id": "ZMK7mqNQZPXJ", + "outputId": "a95524d6-db85-4a34-9c36-fa2befec2f34", "colab": { - "name": "main.ipynb", - "provenance": [] + "base_uri": "https://localhost:8080/" + }, + "ExecuteTime": { + "end_time": "2024-04-01T17:30:37.468006Z", + "start_time": "2024-04-01T17:21:48.898678Z" } + }, + "execution_count": 4, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "EPOCH 1:\n", + " batch 100 loss: 0.1482216489744413\n", + " batch 200 loss: 0.0002496870611862123\n", + " batch 300 loss: 0.0004293969416104437\n", + " batch 400 loss: 0.00014352307940049714\n", + " batch 500 loss: 0.0003824394412247301\n", + " batch 600 loss: 0.0005175532017259399\n", + " batch 700 loss: 0.00033389098716623877\n", + " batch 800 loss: 8.32907846506714e-05\n", + " batch 900 loss: 2.4057848494294377e-05\n", + "LOSS train 2.4057848494294377e-05 valid 0.0\n", + "EPOCH 2:\n", + " batch 100 loss: 2.239500692658325e-05\n", + " batch 200 loss: 1.717424765956821e-07\n", + " batch 300 loss: 2.382965994911501e-06\n", + " batch 400 loss: 2.0097729375834205e-07\n", + " batch 500 loss: 5.319957272149622e-06\n", + " batch 600 loss: 4.880248161498457e-07\n", + " batch 700 loss: 7.180011834861944e-08\n", + " batch 800 loss: 1.2190427505487378e-06\n", + " batch 900 loss: 2.3094077732821461e-07\n", + "LOSS train 2.3094077732821461e-07 valid 0.0\n", + "EPOCH 3:\n", + " batch 100 loss: 1.9006388822617737e-06\n", + " batch 200 loss: 1.8029883719350436e-08\n", + " batch 300 loss: 5.650982788552028e-07\n", + " batch 400 loss: 3.155075988914291e-08\n", + " batch 500 loss: 2.0160868007224052e-06\n", + " batch 600 loss: 1.7061551261576823e-07\n", + " batch 700 loss: 1.929664222188876e-08\n", + " batch 800 loss: 5.126564428525882e-07\n", + " batch 900 loss: 6.992110229475657e-08\n", + "LOSS train 6.992110229475657e-08 valid 0.0\n", + "EPOCH 4:\n", + " batch 100 loss: 8.588485644622779e-07\n", + " batch 200 loss: 5.774140259262595e-09\n", + " batch 300 loss: 2.494339387482114e-07\n", + " batch 400 loss: 8.158223745446947e-09\n", + " batch 500 loss: 1.0297044354956597e-06\n", + " batch 600 loss: 7.403204108413774e-08\n", + " batch 700 loss: 6.612346510337375e-09\n", + " batch 800 loss: 2.495401167823541e-07\n", + " batch 900 loss: 2.607663617482103e-08\n", + "LOSS train 2.607663617482103e-08 valid 0.0\n", + "EPOCH 5:\n", + " batch 100 loss: 4.253824216959856e-07\n", + " batch 200 loss: 2.067528156457499e-09\n", + " batch 300 loss: 1.1764754141552203e-07\n", + " batch 400 loss: 2.4027991685215965e-09\n", + " batch 500 loss: 5.495035293279216e-07\n", + " batch 600 loss: 3.333958602524944e-08\n", + " batch 700 loss: 2.384180106673739e-09\n", + " batch 800 loss: 1.2554179193102755e-07\n", + " batch 900 loss: 1.0691511533877929e-08\n", + "LOSS train 1.0691511533877929e-08 valid 0.0\n", + " precision recall f1-score support\n", + "\n", + " 0 1.00 0.99 0.99 1551\n", + " 1 0.99 1.00 1.00 1524\n", + " 2 0.99 1.00 0.99 1501\n", + " 3 0.99 0.99 0.99 1475\n", + " 4 0.99 1.00 0.99 1507\n", + "\n", + " accuracy 0.99 7558\n", + " macro avg 0.99 0.99 0.99 7558\n", + "weighted avg 0.99 0.99 0.99 7558\n" + ] + } + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + 
"language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" }, - "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file + "colab": { + "name": "main.ipynb", + "provenance": [] + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/code/model.py b/code/model.py index 27f8870..2e47637 100644 --- a/code/model.py +++ b/code/model.py @@ -15,12 +15,14 @@ def build(self, input_shape): super(TransformerBlock, self).build(input_shape) def call(self, inputs): - attn_output = self.att(inputs, inputs) - attn_output = self.dropout1(attn_output) - out1 = self.layernorm1(inputs + attn_output) - ffn_output = self.ffn(out1) - ffn_output = self.dropout2(ffn_output) - out = self.layernorm2(out1 + ffn_output) + attn_output = self.att(inputs, inputs) + attn_output = self.dropout1(attn_output) + out1 = self.layernorm1(inputs + attn_output) + print(out1.shape) + ffn_output = self.ffn(out1) + print(ffn_output.shape) + ffn_output = self.dropout2(ffn_output) + out = self.layernorm2(out1 + ffn_output) return out @@ -105,7 +107,7 @@ def Model( transformer_block = TransformerBlock(embed_dim=eeg_input_dimension, num_heads=2, ff_dim=32) eeg_proj_1 = transformer_block(eeg) - + print(eeg_proj_1.shape) # Construct dilation layers for layer_index in range(layers): diff --git a/src/mylib/models/eeg_encoders.py b/src/mylib/models/eeg_encoders.py new file mode 100644 index 0000000..3741007 --- /dev/null +++ b/src/mylib/models/eeg_encoders.py @@ -0,0 +1,47 @@ +import torch.nn as nn + + +class BaselineEEGEncoder(nn.Module): + """Encoder for EEG""" + + def __init__(self, in_channels=8, dilation_filters=16, kernel_size=3, layers=3): + super(BaselineEEGEncoder, self).__init__() + + self.eeg_convos = nn.Sequential() + + for layer_index in range(layers): + self.eeg_convos.add_module(f"conv1d_lay{layer_index}", + nn.Conv1d( + in_channels=dilation_filters * (layer_index != 0) + ( + layer_index == 0) * in_channels, + out_channels=dilation_filters, + kernel_size=kernel_size, + dilation=kernel_size ** layer_index, + bias=True)) + self.eeg_convos.add_module(f"relu_lay{layer_index}", nn.ReLU()) + + def forward(self, eeg): + return self.eeg_convos(eeg) + + +class MultiheadAttentionEEGEncoder(nn.Module): + """EEG Encoder using transformer""" + + def __init__(self, embed_dim, ff_dim): + super(MultiheadAttentionEEGEncoder, self).__init__() + + self.mha_attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=2) + self.ffn = nn.Sequential(nn.Linear(embed_dim, ff_dim), nn.ReLU(), nn.Linear(ff_dim, embed_dim)) + self.layer_norm1 = nn.LayerNorm(embed_dim, eps=1e-6) + self.layer_norm2 = nn.LayerNorm(embed_dim, eps=1e-6) + self.dropout1 = nn.Dropout(p=0.5) + self.dropout2 = nn.Dropout(p=0.5) + + def forward(self, x): + attn_output, _ = self.mha_attention(x, x, x) + attn_output = self.dropout1(attn_output) + out1 = self.layer_norm1(attn_output + x) + ffn_output = self.ffn(out1) + ffn_output = self.dropout2(ffn_output) + out = self.layer_norm2(out1 + ffn_output) + return out diff --git a/src/mylib/models/models.py b/src/mylib/models/models.py new file mode 100644 index 0000000..d90a4ef --- /dev/null +++ b/src/mylib/models/models.py @@ -0,0 +1,88 @@ +import torch +import torch.nn as nn +from src.mylib.models.eeg_encoders import BaselineEEGEncoder, MultiheadAttentionEEGEncoder +from 
src.mylib.models.stimulus_encoders import BaselineStimulusEncoder + + +class BaselineModel(nn.Module): + """Baseline model""" + + def __init__(self, + layers=3, + kernel_size=3, + spatial_filters=8, + dilation_filters=16): + super(BaselineModel, self).__init__() + + # EEG spatial transformation + self.spatial_transformation = nn.Conv1d( + in_channels=64, + out_channels=spatial_filters, + kernel_size=1, + bias=True + ) + + args = {"dilation_filters": dilation_filters, "kernel_size": kernel_size, "layers": layers} + + # EEG encoder + self.eeg_encoder = BaselineEEGEncoder(in_channels=spatial_filters, **args) + + # Stimulus encoder + self.stimulus_encoder = BaselineStimulusEncoder(**args) + + self.fc = nn.Linear(in_features=dilation_filters * dilation_filters, + out_features=1, + bias=True) + + def forward(self, eeg, stimuli): + eeg = self.spatial_transformation(eeg) + eeg = self.eeg_encoder(eeg) + + # shared weights for stimuli + for i in range(len(stimuli)): + stimuli[i] = self.stimulus_encoder(stimuli[i]) + + cosine_sim = [] + for stimulus in stimuli: + cosine_sim.append(eeg @ stimulus.transpose(-1, -2)) + sim_projections = [self.fc(torch.flatten(sim, start_dim=1)) for sim in cosine_sim] + return torch.cat(sim_projections, dim=1) + + +class MHAModel(nn.Module): + """Model with transformer block as spatial transformation""" + + def __init__(self, + layers=3, + kernel_size=3, + dilation_filters=16): + super(MHAModel, self).__init__() + + # EEG spatial transformation + self.spatial_transformation = MultiheadAttentionEEGEncoder(embed_dim=64, ff_dim=32) + + args = {"dilation_filters": dilation_filters, "kernel_size": kernel_size, "layers": layers} + + # EEG encoder + self.eeg_encoder = BaselineEEGEncoder(in_channels=64, **args) + + # Stimulus encoder + self.stimulus_encoder = BaselineStimulusEncoder(**args) + + self.fc = nn.Linear(in_features=dilation_filters * dilation_filters, + out_features=1, + bias=True) + + def forward(self, eeg, stimuli): + eeg = self.spatial_transformation(eeg.transpose(1, 2)) + eeg = self.eeg_encoder(eeg.transpose(1, 2)) + + # shared weights for stimuli + for i in range(len(stimuli)): + stimuli[i] = self.stimulus_encoder(stimuli[i]) + + cosine_sim = [] + for stimulus in stimuli: + cosine_sim.append(eeg @ stimulus.transpose(-1, -2)) + sim_projections = [self.fc(torch.flatten(sim, start_dim=1)) for sim in cosine_sim] + return torch.cat(sim_projections, dim=1) diff --git a/src/mylib/models/stimulus_encoders.py b/src/mylib/models/stimulus_encoders.py new file mode 100644 index 0000000..0786887 --- /dev/null +++ b/src/mylib/models/stimulus_encoders.py @@ -0,0 +1,22 @@ +import torch.nn as nn + + +class BaselineStimulusEncoder(nn.Module): + """Stimulus encoder from baseline solution""" + + def __init__(self, dilation_filters=16, kernel_size=3, layers=3): + super(BaselineStimulusEncoder, self).__init__() + + self.env_convos = nn.Sequential() + for layer_index in range(layers): + self.env_convos.add_module(f"conv1d_lay{layer_index}", + nn.Conv1d( + in_channels=dilation_filters * (layer_index != 0) + (layer_index == 0), + out_channels=dilation_filters, + kernel_size=kernel_size, + dilation=kernel_size ** layer_index, + bias=True)) + self.env_convos.add_module(f"relu_lay{layer_index}", nn.ReLU()) + + def forward(self, stimulus): + return self.env_convos(stimulus) diff --git a/src/mylib/train.py b/src/mylib/train.py index 15f6729..912d79e 100755 --- a/src/mylib/train.py +++ b/src/mylib/train.py @@ -1,132 +1,150 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -''' -The 
:mod:`mylib.train` contains classes: +import os -- :class:`mylib.train.Trainer` +import torch +from torch.utils.tensorboard import SummaryWriter +import torch.nn as nn -The :mod:`mylib.train` contains functions: - -- :func:`mylib.train.cv_parameters` -''' -from __future__ import print_function - -__docformat__ = 'restructuredtext' - -import numpy -from scipy.special import expit -from sklearn.linear_model import LogisticRegression -from sklearn.model_selection import train_test_split +from src.mylib.utils.data import TaskDataset from sklearn.metrics import classification_report -class SyntheticBernuliDataset(object): - r'''Base class for synthetic dataset.''' - def __init__(self, n=10, m=100, seed=42): - r'''Constructor method - :param n: the number of feature - :type n: int - :param m: the number of object - :type m: int - :param seed: seed for random state. - :type seed: int - ''' - rs = numpy.random.RandomState(seed) +class Trainer(object): + r"""Base class for all trainer.""" - self.w = rs.randn(n) # Генерим вектор параметров из нормального распределения - self.X = rs.randn(m, n) # Генерим вектора признаков из нормального распределения + def __init__(self, model, train_files, val_files, test_files, args, optimizer, loss_fn): + r"""Constructor method - self.y = rs.binomial(1, expit(self.X@self.w)) # Гипотеза порождения данных - целевая переменная из схемы Бернули + :param train_files: path to train files + :type train_files: list + :param val_files: path to val files + :type val_files: list -class Trainer(object): - r'''Base class for all trainer.''' - def __init__(self, model, X, Y, seed=42): - r'''Constructor method - - :param model: The class with fit and predict methods. - :type model: object - - :param X: The array of shape - `num_elements` :math:`\times` `num_feature`. - :type X: numpy.array - :param Y: The array of shape - `num_elements` :math:`\times` `num_answers`. - :type Y: numpy.array - - :param seed: Seed for random state. - :type seed: int - ''' - self.model = model - self.seed = seed - ( - self.X_train, - self.X_val, - self.Y_train, - self.Y_val - ) = train_test_split(X, Y, random_state=self.seed) - - def train(self): - r''' Train model - ''' - self.model.fit(self.X_train, self.Y_train) - - def eval(self, output_dict=False): - r'''Evaluate model for initial validadtion dataset. - ''' - return classification_report( - self.Y_val, - self.model.predict( - self.X_val), output_dict=output_dict) - - def test(self, X, Y, output_dict=False): - r"""Evaluate model for given dataset. - - :param X: The array of shape - `num_elements` :math:`\times` `num_feature`. - :type X: numpy.array - :param Y: The array of shape - `num_elements` :math:`\times` `num_answers`. - :type Y: numpy.array + :param test_files: path to test files + :type test_files: list """ - return classification_report( - Y, self.model.predict(X), output_dict=output_dict) - - -def cv_parameters(X, Y, seed=42, minimal=0.1, maximum=25, count=100): - r'''Function for the experiment with different regularisation parameters - and return accuracy and weidth for LogisticRegression for each parameter. - - :param X: The array of shape - `num_elements` :math:`\times` `num_feature`. - :type X: numpy.array - :param Y: The array of shape - `num_elements` :math:`\times` `num_answers`. - :type Y: numpy.array - - :param seed: Seed for random state. - :type seed: int - :param minimal: Minimum value for the Cs linspace. - :type minimal: int - :param maximum: Maximum value for the Cs linspace. 
- :type maximum: int - :param count: Number of the Cs points. - :type count: int - ''' - - Cs = numpy.linspace(minimal, maximum, count) - parameters = [] - accuracy = [] - for C in Cs: - trainer = Trainer( - LogisticRegression(penalty='l1', solver='saga', C=1/C), - X, Y, - ) - - trainer.train() - - accuracy.append(trainer.eval(output_dict=True)['accuracy']) - - parameters.extend(trainer.model.coef_) - - return Cs, accuracy, parameters + self.model = model + self.args = args + self.optimizer = optimizer + self.loss_fn = loss_fn + self.test_files = test_files + self.initialize_dataloaders(train_files, val_files, test_files) + + def initialize_dataloaders(self, train_files, val_files, test_files): + r"""Initialize dataloaders""" + + conf = {"window_length": self.args["window_length"], "hop_length": self.args["hop_length"], + "number_of_mismatch": self.args["number_of_mismatch"], "max_files": self.args["max_files"]} + self.train_dataloader = torch.utils.data.DataLoader(TaskDataset(train_files, **conf), + batch_size=self.args["batch_size"]) + self.val_dataloader = torch.utils.data.DataLoader(TaskDataset(val_files, **conf), + batch_size=self.args["batch_size"]) + self.test_dataloader = torch.utils.data.DataLoader(TaskDataset(test_files, **conf), + batch_size=1) + + def train_one_epoch(self, epoch_index, writer): + r"""Train one epoch""" + + running_loss = 0 + last_loss = 0 + + for i, data in enumerate(self.train_dataloader): + inputs, labels = data + + self.optimizer.zero_grad() + outputs = self.model(inputs[0], inputs[1:]) + + # TODO: CLASSIFICATION METRIC DURING TRAINING + # probs = (torch.nn.functional.softmax(outputs.data, dim=1) >= 0.5) + # _, predicted = torch.max(probs.data, 1) + + loss = self.loss_fn(outputs, labels) + loss.backward() + + self.optimizer.step() + + running_loss += loss.item() + if i % 100 == 99: + last_loss = running_loss / 100 + print(' batch {} loss: {}'.format(i + 1, last_loss)) + x = epoch_index * len(self.train_dataloader) + i + 1 + writer.add_scalar('Loss/train', last_loss, x) + running_loss = 0 + + return last_loss + + def train_model(self, epochs, run_name): + r""" Train models""" + + writer = SummaryWriter(f"runs/{run_name}_{self.model.__class__.__name__}") + + best_vloss = 1_000_000 + if not os.path.isdir("saved_models"): + os.makedirs("saved_models") + + for epoch in range(epochs): + print(f"EPOCH {epoch + 1}:") + self.model.train() + avg_loss = self.train_one_epoch(epoch + 1, writer) + + running_vloss = 0.0 + self.model.eval() + with torch.no_grad(): + for i, vdata in enumerate(self.val_dataloader): + vinputs, vlabels = vdata + voutputs = self.model(vinputs[0], vinputs[1:]) + vloss = self.loss_fn(voutputs, vlabels) + running_vloss += vloss.item() + + avg_vloss = running_vloss / (i + 1) + print("LOSS train {} valid {}".format(avg_loss, avg_vloss)) + + writer.add_scalars("Training vs. 
Validation Loss", + {"Training": avg_loss, "Validation": avg_vloss}, + epoch + 1) + writer.flush() + + if avg_vloss < best_vloss: + best_vloss = avg_vloss + model_path = f"saved_models/{self.model.__class__.__name__}_{epoch}" + torch.save(self.model.state_dict(), model_path) + + def eval(self): + r"""Evaluate model for initial validation dataset.""" + pass + + def test(self): + r"""Evaluate model for given dataset""" + + total = 0 + self.model.eval() + y_pred = [] + y_true = [] + subjects = list(set([os.path.basename(x).split("_-_")[1] for x in self.test_files])) + loss_fn = nn.functional.cross_entropy + with torch.no_grad(): + for sub in subjects: + sub_test_files = [f for f in self.test_files if sub in os.path.basename(f)] + test_dataloader = torch.utils.data.DataLoader(TaskDataset(sub_test_files, self.args["window_length"], self.args["hop_length"])) + loss = 0 + correct = 0 + for inputs, label in test_dataloader: + outputs = self.model(inputs[0], inputs[1:]) + + loss += loss_fn(outputs, label).item() + probs = (torch.nn.functional.softmax(outputs.data, dim=1) >= 0.5) + _, predicted = torch.max() + + for data in self.test_dataloader: + inputs, labels = data + + outputs = self.model(inputs[0], inputs[1:]) + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + + y_pred.extend(predicted.tolist()) + y_true.extend(labels.tolist()) + + correct += (predicted == labels).sum().item() + + return classification_report(y_true, y_pred) diff --git a/code/config.json b/src/mylib/utils/config.json similarity index 67% rename from code/config.json rename to src/mylib/utils/config.json index e29f701..dbf1b54 100644 --- a/code/config.json +++ b/src/mylib/utils/config.json @@ -1,8 +1,8 @@ { - "dataset_folder": "--absolute path to dataset folder--", + "dataset_folder": "/home/bukkacha/Desktop/EEGDataset", "derivatives_folder": "derivatives", "preprocessed_eeg_folder": "preprocessed_eeg", "preprocessed_stimuli_folder": "preprocessed_stimuli", "split_folder": "split_data", - "test_folder": "test_set" + "stimuli": "stimuli" } diff --git a/src/mylib/utils/data.py b/src/mylib/utils/data.py new file mode 100644 index 0000000..573f4d8 --- /dev/null +++ b/src/mylib/utils/data.py @@ -0,0 +1,71 @@ +import torch +import numpy as np +import itertools +import os +from torch.utils.data import Dataset + + +class TaskDataset(Dataset): + """Generate data for the Match/Mismatch task.""" + + def __init__(self, files, window_length, hop_length, number_of_mismatch, max_files=100): + self.labels = [] + assert number_of_mismatch != 0 + self.window_length = window_length + self.hop_length = hop_length + self.number_of_mismatch = number_of_mismatch + self.files = files + self.max_files = max_files + self.group_recordings() + self.frame_recordings() + self.create_imposter_segments() + self.create_labels_randomize_positions() + + def group_recordings(self): + new_files = [] + grouped = itertools.groupby(sorted(self.files), lambda x: "_-_".join(os.path.basename(x).split("_-_")[:3])) + + for recording_name, feature_paths in grouped: + sub_recordings = sorted(feature_paths, key=lambda x: "0" if x == "eeg" else x) + eeg, envelope = np.load(sub_recordings[0]), np.load(sub_recordings[1]) # eeg [L, C], env [L, 1] + new_files += [[torch.tensor(eeg.T).float(), torch.tensor(envelope.T).float()]] + + if self.max_files is not None and len(new_files) == self.max_files: + break + + self.files = new_files + + def frame_recordings(self): + new_files = [] + for i in range(len(self.files)): + self.files[i][0] = self.files[i][0].unfold( 
+ 1, self.window_length, self.hop_length).transpose(0, 1) # [num_of_frames, C, window_length] + self.files[i][1] = self.files[i][1].unfold( + 1, self.window_length, self.hop_length).transpose(0, 1) # [num_of_frames, C, window_length] + eegs = list(torch.tensor_split(self.files[i][0], self.files[i][0].shape[0], dim=0)) + envs = list(torch.tensor_split(self.files[i][1], self.files[i][1].shape[0], dim=0)) + for eeg, env in zip(eegs, envs): + new_files.append([eeg.squeeze(), env.squeeze(dim=0)]) + self.files = new_files + + def create_imposter_segments(self): + for i in range(len(self.files)): + for _ in range(self.number_of_mismatch): + t = self.files[i][-1].view(-1) + t = t[torch.randperm(t.shape[-1])].view(self.files[i][-1].shape) + self.files[i].append(t) + + def create_labels_randomize_positions(self): + for i in range(len(self.files)): + idx_permutation = torch.randperm(self.number_of_mismatch + 1) + 1 + permuted = [] + for idx in idx_permutation: + permuted.append(self.files[i][idx]) + self.files[i][1:] = permuted + self.labels.append(torch.argmax((idx_permutation == 1).long())) + + def __len__(self): + return len(self.files) + + def __getitem__(self, idx): + return self.files[idx], self.labels[idx]
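
Note (editor's addition, not part of the patch): the sketch below illustrates two pieces of the new code — how TaskDataset.frame_recordings uses Tensor.unfold to cut each recording into 5-second windows with a 1-second hop, and how BaselineModel.forward turns the EEG and stimulus encoder outputs into one logit per candidate envelope. The random tensors are stand-ins for real preprocessed data; only the shapes and hyperparameters (64 Hz sample rate, 5 s window, 1 s hop, 16 dilation filters, kernel 3 over three dilated layers) are taken from the files above.

    # Minimal sketch of the windowing and scoring used in this patch.
    import torch

    sample_rate = 64
    window_length = 5 * sample_rate   # 320 samples per window
    hop_length = 1 * sample_rate      # 64-sample stride between windows

    eeg = torch.randn(64, 10 * sample_rate)  # [channels, samples], 10 s of fake EEG
    env = torch.randn(1, 10 * sample_rate)   # [1, samples], matching fake envelope

    # unfold(dim, size, step) slides a window along dim=1, giving
    # [channels, num_frames, window_length]; transpose to [num_frames, channels, window_length]
    eeg_frames = eeg.unfold(1, window_length, hop_length).transpose(0, 1)
    env_frames = env.unfold(1, window_length, hop_length).transpose(0, 1)
    print(eeg_frames.shape)  # torch.Size([6, 64, 320]) for this 10 s example

    # Scoring one frame against one candidate envelope, as in BaselineModel.forward:
    # both encoders end with 16 dilation filters, so eeg @ stimulus^T is a 16x16
    # similarity matrix whose flattened 256 values feed a linear head with one output.
    dilation_filters = 16
    t_out = 294  # 320-sample window after three unpadded dilated convs (dilations 1, 3, 9)
    eeg_enc = torch.randn(1, dilation_filters, t_out)   # stand-in encoder output [B, F, T']
    stim_enc = torch.randn(1, dilation_filters, t_out)
    sim = eeg_enc @ stim_enc.transpose(-1, -2)          # [B, 16, 16]
    head = torch.nn.Linear(dilation_filters * dilation_filters, 1)
    logit = head(sim.flatten(start_dim=1))
    print(logit.shape)  # torch.Size([1, 1]); one logit per candidate, concatenated for CrossEntropyLoss

In the actual model, this scoring is repeated for the matched envelope plus the number_of_mismatch shuffled imposters, and the concatenated logits are trained with cross-entropy against the index of the true envelope.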