Skip to content

Commit

Permalink
adding Sklearn interface (#22)
Browse files (browse the repository at this point in the history)
  • Loading branch information
tabacof authored Aug 21, 2023
1 parent 59181b4 commit 1ac556a
Show file tree
Hide file tree
Showing 10 changed files with 491 additions and 830 deletions.
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,8 @@ module-name = "rustrees.rustrees"
requires = ["maturin>=1.0,<2.0"]
build-backend = "maturin"

[tool.black]
line-length = 88

[tool.ruff]
line-length = 88
145 changes: 88 additions & 57 deletions python/notebooks/econml_vs_rustrees.ipynb

Large diffs are not rendered by default.

188 changes: 88 additions & 100 deletions python/notebooks/sklearn_vs_rustrees.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
"from sklearn.metrics import r2_score, accuracy_score\n",
"from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier\n",
"from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier\n",
"from rustrees.rustrees import DecisionTree, RandomForest\n",
"import rustrees.tree as rt\n",
"import rustrees.decision_tree as rt_dt\n",
"import rustrees.random_forest as rt_rf\n",
"import time\n",
"import numpy as np"
]
Expand Down Expand Up @@ -55,28 +55,29 @@
"metadata": {},
"outputs": [],
"source": [
"def evaluate_dataset(dataset, problem, model=\"dt\", max_depth=5, n_estimators=10, n_repeats=100):\n",
"def evaluate_dataset(dataset, problem, model, max_depth, n_repeats, n_estimators=None):\n",
" df_train = pd.read_csv(f\"../../datasets/{dataset}_train.csv\")\n",
" df_test = pd.read_csv(f\"../../datasets/{dataset}_test.csv\")\n",
" \n",
" df_train_rt = rt.from_pandas(df_train)\n",
" df_test_rt = rt.from_pandas(df_test)\n",
"\n",
" if problem == \"reg\":\n",
" metric_fn = r2_score\n",
" metric = \"r2\"\n",
" if model == \"dt\":\n",
" model_sk = DecisionTreeRegressor(max_depth=max_depth)\n",
" model_rt = rt_dt.DecisionTreeRegressor(max_depth=max_depth)\n",
" elif model == \"rf\":\n",
" model_sk = RandomForestRegressor(n_estimators=n_estimators, max_depth=max_depth, n_jobs=-1)\n",
" model_rt = rt_rf.RandomForestRegressor(n_estimators=n_estimators, max_depth=max_depth)\n",
" elif problem == \"clf\":\n",
" metric_fn = accuracy_score\n",
" metric = \"acc\"\n",
" if model == \"dt\":\n",
" model_sk = DecisionTreeClassifier(max_depth=max_depth)\n",
" model_rt = rt_dt.DecisionTreeClassifier(max_depth=max_depth)\n",
" elif model == \"rf\":\n",
" model_sk = RandomForestClassifier(n_estimators=n_estimators, max_depth=max_depth, n_jobs=-1)\n",
" \n",
" model_rt = rt_rf.RandomForestClassifier(n_estimators=n_estimators, max_depth=max_depth)\n",
"\n",
" start_time = time.time()\n",
" results_sk = []\n",
" for _ in range(n_repeats):\n",
Expand All @@ -89,21 +90,8 @@
" start_time = time.time()\n",
" results_rt = []\n",
" for _ in range(n_repeats):\n",
" if problem == \"reg\" and model == \"dt\":\n",
" model_rt = DecisionTree.train_reg(df_train_rt, max_depth=max_depth)\n",
" elif problem == \"reg\" and model == \"rf\":\n",
" model_rt = RandomForest.train_reg(df_train_rt, n_estimators=n_estimators, max_depth=max_depth)\n",
" elif problem == \"clf\" and model == \"dt\":\n",
" model_rt = DecisionTree.train_clf(df_train_rt, max_depth=max_depth) \n",
" elif problem == \"clf\" and model == \"rf\":\n",
" model_rt = RandomForest.train_clf(df_train_rt, n_estimators=n_estimators, max_depth=max_depth)\n",
"\n",
" if problem == \"reg\":\n",
" pred_rt = model_rt.predict(df_test_rt)\n",
" results_rt.append(metric_fn(df_test.target, pred_rt))\n",
" elif problem == \"clf\":\n",
" pred_rt = model_rt.predict(df_test_rt)\n",
" results_rt.append(metric_fn(df_test.target, np.array(pred_rt) > 0.5))\n",
" model_rt.fit(df_train.drop(\"target\", axis=1), df_train.target)\n",
" results_rt.append(metric_fn(df_test.target, model_rt.predict(df_test.drop(\"target\", axis=1))))\n",
" rt_time = (time.time() - start_time)/n_repeats\n",
" rt_mean = np.mean(results_rt)\n",
" rt_std = np.std(results_rt)\n",
Expand Down Expand Up @@ -161,23 +149,23 @@
" <tr>\n",
" <th>0</th>\n",
" <td>diabetes</td>\n",
" <td>0.313557</td>\n",
" <td>0.307939</td>\n",
" <td>4.120831e-02</td>\n",
" <td>4.077774e-02</td>\n",
" <td>0.002528</td>\n",
" <td>0.001109</td>\n",
" <td>0.315319</td>\n",
" <td>0.270029</td>\n",
" <td>3.251468e-02</td>\n",
" <td>1.780794e-02</td>\n",
" <td>0.002659</td>\n",
" <td>0.003520</td>\n",
" <td>r2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>housing</td>\n",
" <td>0.599732</td>\n",
" <td>0.599732</td>\n",
" <td>1.174950e-16</td>\n",
" <td>1.110223e-16</td>\n",
" <td>0.043352</td>\n",
" <td>0.049112</td>\n",
" <td>0.598390</td>\n",
" <td>1.336886e-16</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.042986</td>\n",
" <td>0.060472</td>\n",
" <td>r2</td>\n",
" </tr>\n",
" <tr>\n",
Expand All @@ -187,30 +175,30 @@
" <td>0.993510</td>\n",
" <td>4.440892e-16</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.058138</td>\n",
" <td>0.147836</td>\n",
" <td>0.056852</td>\n",
" <td>0.360891</td>\n",
" <td>r2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>breast_cancer</td>\n",
" <td>0.929649</td>\n",
" <td>0.929895</td>\n",
" <td>6.988977e-03</td>\n",
" <td>6.601342e-03</td>\n",
" <td>0.004454</td>\n",
" <td>0.002363</td>\n",
" <td>0.928702</td>\n",
" <td>0.929018</td>\n",
" <td>6.747068e-03</td>\n",
" <td>6.746612e-03</td>\n",
" <td>0.004165</td>\n",
" <td>0.006442</td>\n",
" <td>acc</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>titanic</td>\n",
" <td>0.786441</td>\n",
" <td>0.779661</td>\n",
" <td>0.806780</td>\n",
" <td>1.110223e-16</td>\n",
" <td>3.330669e-16</td>\n",
" <td>0.002522</td>\n",
" <td>0.001076</td>\n",
" <td>0.002300</td>\n",
" <td>0.002896</td>\n",
" <td>acc</td>\n",
" </tr>\n",
" </tbody>\n",
Expand All @@ -219,18 +207,18 @@
],
"text/plain": [
" dataset sk_mean rt_mean sk_std rt_std sk_time(s) \\\n",
"0 diabetes 0.313557 0.307939 4.120831e-02 4.077774e-02 0.002528 \n",
"1 housing 0.599732 0.599732 1.174950e-16 1.110223e-16 0.043352 \n",
"2 dgp 0.993509 0.993510 4.440892e-16 0.000000e+00 0.058138 \n",
"3 breast_cancer 0.929649 0.929895 6.988977e-03 6.601342e-03 0.004454 \n",
"4 titanic 0.786441 0.779661 1.110223e-16 3.330669e-16 0.002522 \n",
"0 diabetes 0.315319 0.270029 3.251468e-02 1.780794e-02 0.002659 \n",
"1 housing 0.599732 0.598390 1.336886e-16 0.000000e+00 0.042986 \n",
"2 dgp 0.993509 0.993510 4.440892e-16 0.000000e+00 0.056852 \n",
"3 breast_cancer 0.928702 0.929018 6.747068e-03 6.746612e-03 0.004165 \n",
"4 titanic 0.786441 0.806780 1.110223e-16 3.330669e-16 0.002300 \n",
"\n",
" rt_time(s) metric \n",
"0 0.001109 r2 \n",
"1 0.049112 r2 \n",
"2 0.147836 r2 \n",
"3 0.002363 acc \n",
"4 0.001076 acc "
"0 0.003520 r2 \n",
"1 0.060472 r2 \n",
"2 0.360891 r2 \n",
"3 0.006442 acc \n",
"4 0.002896 acc "
]
},
"execution_count": 4,
Expand All @@ -239,8 +227,8 @@
}
],
"source": [
"results_reg = [evaluate_dataset(d, \"reg\") for d in datasets[\"reg\"]]\n",
"results_clf = [evaluate_dataset(d, \"clf\") for d in datasets[\"clf\"]]\n",
"results_reg = [evaluate_dataset(d, \"reg\", model=\"dt\", max_depth=5, n_repeats=100) for d in datasets[\"reg\"]]\n",
"results_clf = [evaluate_dataset(d, \"clf\", model=\"dt\", max_depth=5, n_repeats=100) for d in datasets[\"clf\"]]\n",
"results = results_reg + results_clf\n",
"\n",
"cols = \"dataset sk_mean rt_mean sk_std rt_std sk_time(s) rt_time(s) metric\".split()\n",
Expand Down Expand Up @@ -289,56 +277,56 @@
" <tr>\n",
" <th>0</th>\n",
" <td>diabetes</td>\n",
" <td>0.430064</td>\n",
" <td>0.425717</td>\n",
" <td>0.025384</td>\n",
" <td>0.041980</td>\n",
" <td>0.049278</td>\n",
" <td>0.005452</td>\n",
" <td>0.437938</td>\n",
" <td>0.432859</td>\n",
" <td>0.009338</td>\n",
" <td>0.005773</td>\n",
" <td>0.114510</td>\n",
" <td>0.010676</td>\n",
" <td>r2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>housing</td>\n",
" <td>0.642372</td>\n",
" <td>0.643270</td>\n",
" <td>0.002420</td>\n",
" <td>0.002528</td>\n",
" <td>0.096647</td>\n",
" <td>0.068624</td>\n",
" <td>0.439645</td>\n",
" <td>0.440555</td>\n",
" <td>0.000613</td>\n",
" <td>0.000857</td>\n",
" <td>0.255593</td>\n",
" <td>0.401618</td>\n",
" <td>r2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>dgp</td>\n",
" <td>0.995718</td>\n",
" <td>0.995603</td>\n",
" <td>0.000076</td>\n",
" <td>0.000179</td>\n",
" <td>0.108993</td>\n",
" <td>0.267544</td>\n",
" <td>0.756377</td>\n",
" <td>0.756061</td>\n",
" <td>0.000342</td>\n",
" <td>0.000276</td>\n",
" <td>0.322776</td>\n",
" <td>2.913919</td>\n",
" <td>r2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>breast_cancer</td>\n",
" <td>0.954386</td>\n",
" <td>0.941404</td>\n",
" <td>0.010048</td>\n",
" <td>0.009806</td>\n",
" <td>0.047959</td>\n",
" <td>0.008780</td>\n",
" <td>0.946667</td>\n",
" <td>0.937193</td>\n",
" <td>0.003438</td>\n",
" <td>0.003663</td>\n",
" <td>0.126519</td>\n",
" <td>0.025618</td>\n",
" <td>acc</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>titanic</td>\n",
" <td>0.793898</td>\n",
" <td>0.794915</td>\n",
" <td>0.012574</td>\n",
" <td>0.008058</td>\n",
" <td>0.048556</td>\n",
" <td>0.005197</td>\n",
" <td>0.763390</td>\n",
" <td>0.772881</td>\n",
" <td>0.004982</td>\n",
" <td>0.000000</td>\n",
" <td>0.140300</td>\n",
" <td>0.011944</td>\n",
" <td>acc</td>\n",
" </tr>\n",
" </tbody>\n",
Expand All @@ -347,18 +335,18 @@
],
"text/plain": [
" dataset sk_mean rt_mean sk_std rt_std sk_time(s) \\\n",
"0 diabetes 0.430064 0.425717 0.025384 0.041980 0.049278 \n",
"1 housing 0.642372 0.643270 0.002420 0.002528 0.096647 \n",
"2 dgp 0.995718 0.995603 0.000076 0.000179 0.108993 \n",
"3 breast_cancer 0.954386 0.941404 0.010048 0.009806 0.047959 \n",
"4 titanic 0.793898 0.794915 0.012574 0.008058 0.048556 \n",
"0 diabetes 0.437938 0.432859 0.009338 0.005773 0.114510 \n",
"1 housing 0.439645 0.440555 0.000613 0.000857 0.255593 \n",
"2 dgp 0.756377 0.756061 0.000342 0.000276 0.322776 \n",
"3 breast_cancer 0.946667 0.937193 0.003438 0.003663 0.126519 \n",
"4 titanic 0.763390 0.772881 0.004982 0.000000 0.140300 \n",
"\n",
" rt_time(s) metric \n",
"0 0.005452 r2 \n",
"1 0.068624 r2 \n",
"2 0.267544 r2 \n",
"3 0.008780 acc \n",
"4 0.005197 acc "
"0 0.010676 r2 \n",
"1 0.401618 r2 \n",
"2 2.913919 r2 \n",
"3 0.025618 acc \n",
"4 0.011944 acc "
]
},
"execution_count": 5,
Expand All @@ -367,8 +355,8 @@
}
],
"source": [
"results_reg = [evaluate_dataset(d, \"reg\", model=\"rf\", n_repeats=10) for d in datasets[\"reg\"]]\n",
"results_clf = [evaluate_dataset(d, \"clf\", model=\"rf\", n_repeats=10) for d in datasets[\"clf\"]]\n",
"results_reg = [evaluate_dataset(d, \"reg\", model=\"rf\", max_depth=2, n_estimators=100, n_repeats=10) for d in datasets[\"reg\"]]\n",
"results_clf = [evaluate_dataset(d, \"clf\", model=\"rf\", max_depth=2, n_estimators=100, n_repeats=10) for d in datasets[\"clf\"]]\n",
"results = results_reg + results_clf\n",
"\n",
"cols = \"dataset sk_mean rt_mean sk_std rt_std sk_time(s) rt_time(s) metric\".split()\n",
Expand All @@ -379,7 +367,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "292e7a1f",
"id": "b795007c",
"metadata": {},
"outputs": [],
"source": []
Expand All @@ -401,7 +389,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.9"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit 1ac556a

Please sign in to comment.