diff --git a/src/notebooks/regression_models.ipynb b/src/notebooks/regression_models.ipynb index bc55d77..fc421e5 100644 --- a/src/notebooks/regression_models.ipynb +++ b/src/notebooks/regression_models.ipynb @@ -14,8 +14,8 @@ "id": "3c39794f97968a3f", "metadata": { "ExecuteTime": { - "end_time": "2024-01-15T12:35:06.727683366Z", - "start_time": "2024-01-15T12:34:55.681492412Z" + "end_time": "2024-01-15T16:36:40.581338796Z", + "start_time": "2024-01-15T16:36:35.574694703Z" } }, "outputs": [], @@ -32,8 +32,8 @@ "id": "32f1ea8c815713bb", "metadata": { "ExecuteTime": { - "end_time": "2024-01-15T12:35:08.335982316Z", - "start_time": "2024-01-15T12:35:08.181573198Z" + "end_time": "2024-01-15T16:36:40.598939365Z", + "start_time": "2024-01-15T16:36:40.564765676Z" } }, "outputs": [], @@ -55,8 +55,8 @@ "id": "f806c99cd0f3a8fc", "metadata": { "ExecuteTime": { - "end_time": "2024-01-15T12:35:09.442537206Z", - "start_time": "2024-01-15T12:35:09.153303959Z" + "end_time": "2024-01-15T16:36:41.001371375Z", + "start_time": "2024-01-15T16:36:40.621694256Z" } }, "outputs": [], @@ -73,8 +73,8 @@ "id": "75c601e5b4463afb", "metadata": { "ExecuteTime": { - "end_time": "2024-01-15T12:35:09.918528885Z", - "start_time": "2024-01-15T12:35:09.570058280Z" + "end_time": "2024-01-15T16:36:59.028790304Z", + "start_time": "2024-01-15T16:36:58.769138004Z" } }, "outputs": [], @@ -97,8 +97,8 @@ "id": "2e13065ce6d1c754", "metadata": { "ExecuteTime": { - "end_time": "2024-01-15T12:35:10.825912834Z", - "start_time": "2024-01-15T12:35:10.431680709Z" + "end_time": "2024-01-15T16:37:00.836256191Z", + "start_time": "2024-01-15T16:37:00.499623619Z" } }, "outputs": [], @@ -113,8 +113,8 @@ "id": "e21809bf32473272", "metadata": { "ExecuteTime": { - "end_time": "2024-01-15T12:35:11.894336813Z", - "start_time": "2024-01-15T12:35:10.976153027Z" + "end_time": "2024-01-15T16:37:08.386010297Z", + "start_time": "2024-01-15T16:37:07.678184948Z" } }, "outputs": [ @@ -138,15 +138,15 @@ "id": "30d4ec14f4e62f92", "metadata": { "ExecuteTime": { - "end_time": "2024-01-15T12:35:13.575111156Z", - "start_time": "2024-01-15T12:35:12.943338119Z" + "end_time": "2024-01-15T16:37:12.840613822Z", + "start_time": "2024-01-15T16:37:12.159465710Z" } }, "outputs": [ { "data": { - "text/plain": "\n duration col = 'day_succ'\n event col = 'Status'\n baseline estimation = breslow\n number of observations = 4175\nnumber of events observed = 1962\n partial log-likelihood = -15103.16\n time fit was run = 2024-01-15 12:35:10 UTC\n\n---\n coef exp(coef) se(coef) coef lower 95% coef upper 95% exp(coef) lower 95% exp(coef) upper 95%\ncovariate \nhas_video 0.69 1.99 0.07 0.56 0.82 1.74 2.28\nfacebook_connected -0.02 0.98 0.05 -0.12 0.08 0.89 1.08\ngoal -20.25 0.00 1.54 -23.27 -17.23 0.00 0.00\nfacebook_friends 0.16 1.17 0.02 0.12 0.20 1.12 1.22\n\n cmp to z p -log2(p)\ncovariate \nhas_video 0.00 10.07 <0.005 76.89\nfacebook_connected 0.00 -0.45 0.65 0.62\ngoal 0.00 -13.14 <0.005 128.60\nfacebook_friends 0.00 7.68 <0.005 45.88\n---\nConcordance = 0.66\nPartial AIC = 30214.31\nlog-likelihood ratio test = 462.69 on 4 df\n-log2(p) of ll-ratio test = 325.90", - "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
modellifelines.CoxPHFitter
duration col'day_succ'
event col'Status'
baseline estimationbreslow
number of observations4175
number of events observed1962
partial log-likelihood-15103.16
time fit was run2024-01-15 12:35:10 UTC
\n
\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
coefexp(coef)se(coef)coef lower 95%coef upper 95%exp(coef) lower 95%exp(coef) upper 95%cmp tozp-log2(p)
has_video0.691.990.070.560.821.742.280.0010.07<0.00576.89
facebook_connected-0.020.980.05-0.120.080.891.080.00-0.450.650.62
goal-20.250.001.54-23.27-17.230.000.000.00-13.14<0.005128.60
facebook_friends0.161.170.020.120.201.121.220.007.68<0.00545.88

\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
Concordance0.66
Partial AIC30214.31
log-likelihood ratio test462.69 on 4 df
-log2(p) of ll-ratio test325.90
\n
", + "text/plain": "\n duration col = 'day_succ'\n event col = 'Status'\n baseline estimation = breslow\n number of observations = 4175\nnumber of events observed = 1962\n partial log-likelihood = -15103.16\n time fit was run = 2024-01-15 16:37:07 UTC\n\n---\n coef exp(coef) se(coef) coef lower 95% coef upper 95% exp(coef) lower 95% exp(coef) upper 95%\ncovariate \nhas_video 0.69 1.99 0.07 0.56 0.82 1.74 2.28\nfacebook_connected -0.02 0.98 0.05 -0.12 0.08 0.89 1.08\ngoal -20.25 0.00 1.54 -23.27 -17.23 0.00 0.00\nfacebook_friends 0.16 1.17 0.02 0.12 0.20 1.12 1.22\n\n cmp to z p -log2(p)\ncovariate \nhas_video 0.00 10.07 <0.005 76.89\nfacebook_connected 0.00 -0.45 0.65 0.62\ngoal 0.00 -13.14 <0.005 128.60\nfacebook_friends 0.00 7.68 <0.005 45.88\n---\nConcordance = 0.66\nPartial AIC = 30214.31\nlog-likelihood ratio test = 462.69 on 4 df\n-log2(p) of ll-ratio test = 325.90", + "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
modellifelines.CoxPHFitter
duration col'day_succ'
event col'Status'
baseline estimationbreslow
number of observations4175
number of events observed1962
partial log-likelihood-15103.16
time fit was run2024-01-15 16:37:07 UTC
\n
\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
coefexp(coef)se(coef)coef lower 95%coef upper 95%exp(coef) lower 95%exp(coef) upper 95%cmp tozp-log2(p)
has_video0.691.990.070.560.821.742.280.0010.07<0.00576.89
facebook_connected-0.020.980.05-0.120.080.891.080.00-0.450.650.62
goal-20.250.001.54-23.27-17.230.000.000.00-13.14<0.005128.60
facebook_friends0.161.170.020.120.201.121.220.007.68<0.00545.88

\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
Concordance0.66
Partial AIC30214.31
log-likelihood ratio test462.69 on 4 df
-log2(p) of ll-ratio test325.90
\n
", "text/latex": "\\begin{tabular}{lrrrrrrrrrrr}\n & coef & exp(coef) & se(coef) & coef lower 95% & coef upper 95% & exp(coef) lower 95% & exp(coef) upper 95% & cmp to & z & p & -log2(p) \\\\\ncovariate & & & & & & & & & & & \\\\\nhas_video & 0.69 & 1.99 & 0.07 & 0.56 & 0.82 & 1.74 & 2.28 & 0.00 & 10.07 & 0.00 & 76.89 \\\\\nfacebook_connected & -0.02 & 0.98 & 0.05 & -0.12 & 0.08 & 0.89 & 1.08 & 0.00 & -0.45 & 0.65 & 0.62 \\\\\ngoal & -20.25 & 0.00 & 1.54 & -23.27 & -17.23 & 0.00 & 0.00 & 0.00 & -13.14 & 0.00 & 128.60 \\\\\nfacebook_friends & 0.16 & 1.17 & 0.02 & 0.12 & 0.20 & 1.12 & 1.22 & 0.00 & 7.68 & 0.00 & 45.88 \\\\\n\\end{tabular}\n" }, "metadata": {}, @@ -163,8 +163,8 @@ "id": "7853f17f3e332e49", "metadata": { "ExecuteTime": { - "end_time": "2024-01-15T12:35:16.076522242Z", - "start_time": "2024-01-15T12:35:14.614432587Z" + "end_time": "2024-01-15T16:37:15.051136636Z", + "start_time": "2024-01-15T16:37:12.993437330Z" } }, "outputs": [ @@ -323,180 +323,142 @@ }, { "cell_type": "code", - "execution_count": 27, - "id": "28bd8514a4b9862d", + "outputs": [], + "source": [], "metadata": { + "collapsed": false, "ExecuteTime": { - "end_time": "2024-01-15T13:00:56.292470977Z", - "start_time": "2024-01-15T13:00:56.008965764Z" + "end_time": "2024-01-15T16:39:33.936594190Z", + "start_time": "2024-01-15T16:39:33.592732071Z" } }, - "outputs": [], - "source": [ - "a = cph.predict_expectation(df)" - ] + "id": "eafacaa74f64964c", + "execution_count": 14 }, { "cell_type": "code", - "execution_count": 28, - "id": "5616d83a75506e07", + "execution_count": 15, + "id": "51c7d6eb32e90435", "metadata": { "ExecuteTime": { - "end_time": "2024-01-15T13:00:58.129775071Z", - "start_time": "2024-01-15T13:00:58.063191328Z" + "end_time": "2024-01-15T16:39:34.281672961Z", + "start_time": "2024-01-15T16:39:34.004821245Z" } }, - "outputs": [ - { - "data": { - "text/plain": "(4175, 6)" - }, - "execution_count": 28, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "df.shape" + "import numpy as np\n", + "from lifelines import AalenAdditiveFitter, CoxPHFitter, WeibullFitter\n", + "from lifelines.utils import k_fold_cross_validation\n", + "\n", + "\n", + "#create the three models we'd like to compare.\n", + "aaf_1 = AalenAdditiveFitter(coef_penalizer=0.5)\n", + "aaf_2 = AalenAdditiveFitter(coef_penalizer=10)\n", + "cph = CoxPHFitter()\n" ] }, { "cell_type": "code", - "execution_count": 29, - "id": "763c6bc587a90e52", + "id": "6444223328e05f9", "metadata": { "ExecuteTime": { - "end_time": "2024-01-15T13:00:59.740619735Z", - "start_time": "2024-01-15T13:00:59.621534120Z" + "end_time": "2024-01-15T16:39:43.284553670Z", + "start_time": "2024-01-15T16:39:36.080198026Z" } }, "outputs": [ { - "data": { - "text/plain": "(4175,)" - }, - "execution_count": 29, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "0.6565838815185863\n", + "0.5893522643535443\n", + "0.5863262066574004\n" + ] } ], "source": [ - "a.shape" - ] + "print(np.mean(k_fold_cross_validation(cph, df, duration_col=\"day_succ\", event_col=\"Status\", scoring_method=\"concordance_index\")))\n", + "print(np.mean(k_fold_cross_validation(aaf_1, df, duration_col=\"day_succ\", event_col=\"Status\", scoring_method=\"concordance_index\")))\n", + "print(np.mean(k_fold_cross_validation(aaf_2, df, duration_col=\"day_succ\", event_col=\"Status\", scoring_method=\"concordance_index\")))" + ], + "execution_count": 17 }, { "cell_type": "code", - "outputs": [ - { - "data": { - "text/plain": "0 39.119406\n1 30.782592\n2 32.615499\n3 30.328077\n4 30.635064\ndtype: float64" - }, - "execution_count": 30, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "a.head()" + "# todo: faire un tableau pour dire quelle métrique on utilise (concordance), et comparer les résultats" ], "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-01-15T13:01:04.846641554Z", - "start_time": "2024-01-15T13:01:04.784144727Z" - } + "collapsed": false }, - "id": "30cb55b0b468f8e1", - "execution_count": 30 + "id": "ae3e0a8fe12444b6" }, { - "cell_type": "code", - "execution_count": 20, - "id": "51c7d6eb32e90435", + "cell_type": "markdown", + "source": [ + "# Prédictions" + ], "metadata": { - "ExecuteTime": { - "end_time": "2024-01-15T12:43:00.477950052Z", - "start_time": "2024-01-15T12:43:00.333775445Z" - } + "collapsed": false }, - "outputs": [], - "source": [ - "import numpy as np\n", - "from lifelines import AalenAdditiveFitter, CoxPHFitter, WeibullFitter\n", - "from lifelines.utils import k_fold_cross_validation\n", - "\n", - "\n", - "#create the three models we'd like to compare.\n", - "aaf_1 = AalenAdditiveFitter(coef_penalizer=0.5)\n", - "aaf_2 = AalenAdditiveFitter(coef_penalizer=10)\n", - "cph = CoxPHFitter()\n", - "weibull = WeibullFitter()\n" - ] + "id": "da5a0c4a3a3f25cd" }, { "cell_type": "code", "outputs": [ { "data": { - "text/plain": "Index(['day_succ', 'Status', 'has_video', 'facebook_connected', 'goal',\n 'facebook_friends'],\n dtype='object')" + "text/plain": "290.36207370051335" }, - "execution_count": 21, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "df.columns" + "from sklearn.metrics import mean_squared_error\n", + "\n", + "y_pred = cph.predict_expectation(df)\n", + "mean_squared_error(event_times, y_pred)" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-01-15T12:43:05.895151205Z", - "start_time": "2024-01-15T12:43:05.631593378Z" + "end_time": "2024-01-15T16:41:05.361065301Z", + "start_time": "2024-01-15T16:41:05.132048045Z" } }, - "id": "4ddde558c8048537", - "execution_count": 21 + "id": "a39bc955421d018d", + "execution_count": 19 }, { "cell_type": "code", - "id": "6444223328e05f9", - "metadata": { - "ExecuteTime": { - "end_time": "2024-01-15T12:42:02.696879023Z", - "start_time": "2024-01-15T12:41:55.129415082Z" - } - }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "0.6554911958267244\n", - "0.5836754914997611\n", - "0.5856958185269712\n" - ] + "data": { + "text/plain": "300.9594070099238" + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "\n", - "\n", - "print(np.mean(k_fold_cross_validation(cph, df, duration_col=\"day_succ\", event_col=\"Status\", scoring_method=\"concordance_index\")))\n", - "print(np.mean(k_fold_cross_validation(aaf_1, df, duration_col=\"day_succ\", event_col=\"Status\", scoring_method=\"concordance_index\")))\n", - "print(np.mean(k_fold_cross_validation(aaf_2, df, duration_col=\"day_succ\", event_col=\"Status\", scoring_method=\"concordance_index\")))" - ], - "execution_count": 19 - }, - { - "cell_type": "code", - "outputs": [], - "source": [ - "# todo: faire un tableau pour dire quelle métrique on utilise (concordance), et comparer les résultats" + "y_pred = aaf_2.predict_expectation(df)\n", + "mean_squared_error(event_times, y_pred)" ], "metadata": { - "collapsed": false + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-15T16:41:38.206588260Z", + "start_time": "2024-01-15T16:41:37.688744662Z" + } }, - "id": "ae3e0a8fe12444b6" + "id": "cf62fe9527fd4c52", + "execution_count": 23 }, { "cell_type": "markdown", @@ -571,9 +533,50 @@ "id": "c602de5b305abc46", "execution_count": 25 }, + { + "cell_type": "markdown", + "source": [ + "$$\\text{AIC}(\\text{model}) = -2 \\ln{L} + 2k$$\n", + "\n", + "avec $k$ le nombre de paramètres (degrés de liberté) du modèle\n", + "et $L$ la vraisemblance" + ], + "metadata": { + "collapsed": false + }, + "id": "60405299e721d518" + }, + { + "cell_type": "markdown", + "source": [ + "$$\\text{BIC}(\\text{model}) = -2 \\ln{L} + k \\cdot \\ln N$$\n", + "\n", + "avec $k$ le nombre de paramètres (degrés de liberté) du modèle,\n", + "$L$ la vraisemblance\n", + "et $N$ le nombre d'observations" + ], + "metadata": { + "collapsed": false + }, + "id": "86dd805eaa66f613" + }, { "cell_type": "code", "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WeibullFitter a un AIC de 19887.721204759433\n", + "WeibullFitter a un BIC de 19900.394944034004\n", + "LogNormalFitter a un AIC de 20172.180859246975\n", + "LogNormalFitter a un BIC de 20184.854598521546\n", + "LogLogisticFitter a un AIC de 20021.274197390605\n", + "LogLogisticFitter a un BIC de 20033.947936665176\n", + "ExponentialFitter a un AIC de 19897.507803798864\n", + "ExponentialFitter a un BIC de 19903.84467343615\n" + ] + }, { "data": { "text/plain": "
", @@ -593,17 +596,19 @@ "\n", "for i, model in enumerate([WeibullFitter(), LogNormalFitter(), LogLogisticFitter(), ExponentialFitter()]):\n", " model.fit(event_times, event_observed)\n", - " qq_plot(model, ax=axes[i])\n" + " qq_plot(model, ax=axes[i])\n", + " print(f\"{model.__class__.__name__} a un AIC de {model.AIC_}\")\n", + " print(f\"{model.__class__.__name__} a un BIC de {model.BIC_}\")\n" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-01-15T12:51:40.010529332Z", - "start_time": "2024-01-15T12:51:34.148334671Z" + "end_time": "2024-01-15T16:45:40.545572064Z", + "start_time": "2024-01-15T16:45:34.983161281Z" } }, "id": "d618135ec6556bea", - "execution_count": 26 + "execution_count": 30 }, { "cell_type": "code", @@ -612,7 +617,7 @@ "metadata": { "collapsed": false }, - "id": "9aecf73164aead4e" + "id": "8862a0a4b5185d67" } ], "metadata": {