From 9c43abdfcc0c6a92cf272c582a761cf9593a9355 Mon Sep 17 00:00:00 2001 From: T4ras123 Date: Fri, 13 Sep 2024 21:21:33 +0400 Subject: [PATCH] Polished code, added f-strings, and added a new dataset, improved the readability of the code. --- decision tree classification.ipynb | 114 +++++++++++++--------- iris.csv | 151 +++++++++++++++++++++++++++++ 2 files changed, 217 insertions(+), 48 deletions(-) create mode 100644 iris.csv diff --git a/decision tree classification.ipynb b/decision tree classification.ipynb index 3ffdabe..ff34977 100644 --- a/decision tree classification.ipynb +++ b/decision tree classification.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -26,7 +26,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "metadata": { "scrolled": false }, @@ -66,7 +66,7 @@ " 3.5\n", " 1.4\n", " 0.2\n", - " 0\n", + " Setosa\n", " \n", " \n", " 1\n", @@ -74,7 +74,7 @@ " 3.0\n", " 1.4\n", " 0.2\n", - " 0\n", + " Setosa\n", " \n", " \n", " 2\n", @@ -82,7 +82,7 @@ " 3.2\n", " 1.3\n", " 0.2\n", - " 0\n", + " Setosa\n", " \n", " \n", " 3\n", @@ -90,7 +90,7 @@ " 3.1\n", " 1.5\n", " 0.2\n", - " 0\n", + " Setosa\n", " \n", " \n", " 4\n", @@ -98,7 +98,7 @@ " 3.6\n", " 1.4\n", " 0.2\n", - " 0\n", + " Setosa\n", " \n", " \n", " 5\n", @@ -106,7 +106,7 @@ " 3.9\n", " 1.7\n", " 0.4\n", - " 0\n", + " Setosa\n", " \n", " \n", " 6\n", @@ -114,7 +114,7 @@ " 3.4\n", " 1.4\n", " 0.3\n", - " 0\n", + " Setosa\n", " \n", " \n", " 7\n", @@ -122,7 +122,7 @@ " 3.4\n", " 1.5\n", " 0.2\n", - " 0\n", + " Setosa\n", " \n", " \n", " 8\n", @@ -130,7 +130,7 @@ " 2.9\n", " 1.4\n", " 0.2\n", - " 0\n", + " Setosa\n", " \n", " \n", " 9\n", @@ -138,27 +138,27 @@ " 3.1\n", " 1.5\n", " 0.1\n", - " 0\n", + " Setosa\n", " \n", " \n", "\n", "" ], "text/plain": [ - " sepal_length sepal_width petal_length petal_width type\n", - "0 5.1 3.5 1.4 0.2 0\n", - "1 4.9 3.0 1.4 0.2 0\n", - "2 4.7 3.2 1.3 0.2 0\n", - "3 4.6 3.1 1.5 0.2 0\n", - "4 5.0 3.6 1.4 0.2 0\n", - "5 5.4 3.9 1.7 0.4 0\n", - "6 4.6 3.4 1.4 0.3 0\n", - "7 5.0 3.4 1.5 0.2 0\n", - "8 4.4 2.9 1.4 0.2 0\n", - "9 4.9 3.1 1.5 0.1 0" + " sepal_length sepal_width petal_length petal_width type\n", + "0 5.1 3.5 1.4 0.2 Setosa\n", + "1 4.9 3.0 1.4 0.2 Setosa\n", + "2 4.7 3.2 1.3 0.2 Setosa\n", + "3 4.6 3.1 1.5 0.2 Setosa\n", + "4 5.0 3.6 1.4 0.2 Setosa\n", + "5 5.4 3.9 1.7 0.4 Setosa\n", + "6 4.6 3.4 1.4 0.3 Setosa\n", + "7 5.0 3.4 1.5 0.2 Setosa\n", + "8 4.4 2.9 1.4 0.2 Setosa\n", + "9 4.9 3.1 1.5 0.1 Setosa" ] }, - "execution_count": 2, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -178,7 +178,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -206,7 +206,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -221,6 +221,7 @@ " self.min_samples_split = min_samples_split\n", " self.max_depth = max_depth\n", " \n", + " \n", " def build_tree(self, dataset, curr_depth=0):\n", " ''' recursive function to build the tree ''' \n", " \n", @@ -229,14 +230,17 @@ " \n", " # split until stopping conditions are met\n", " if num_samples>=self.min_samples_split and curr_depth<=self.max_depth:\n", + " \n", " # find the best split\n", - " best_split = self.get_best_split(dataset, num_samples, num_features)\n", + " best_split = self.get_best_split(dataset, num_features)\n", + " \n", " # check if information gain is positive\n", " if best_split[\"info_gain\"]>0:\n", " # recur left\n", " left_subtree = self.build_tree(best_split[\"dataset_left\"], curr_depth+1)\n", " # recur right\n", " right_subtree = self.build_tree(best_split[\"dataset_right\"], curr_depth+1)\n", + " \n", " # return decision node\n", " return Node(best_split[\"feature_index\"], best_split[\"threshold\"], \n", " left_subtree, right_subtree, best_split[\"info_gain\"])\n", @@ -246,7 +250,8 @@ " # return leaf node\n", " return Node(value=leaf_value)\n", " \n", - " def get_best_split(self, dataset, num_samples, num_features):\n", + " \n", + " def get_best_split(self, dataset, num_features):\n", " ''' function to find the best split '''\n", " \n", " # dictionary to store the best split\n", @@ -257,16 +262,19 @@ " for feature_index in range(num_features):\n", " feature_values = dataset[:, feature_index]\n", " possible_thresholds = np.unique(feature_values)\n", + " \n", " # loop over all the feature values present in the data\n", " for threshold in possible_thresholds:\n", " # get current split\n", " dataset_left, dataset_right = self.split(dataset, feature_index, threshold)\n", + " \n", " # check if childs are not null\n", " if len(dataset_left)>0 and len(dataset_right)>0:\n", " y, left_y, right_y = dataset[:, -1], dataset_left[:, -1], dataset_right[:, -1]\n", " # compute information gain\n", " curr_info_gain = self.information_gain(y, left_y, right_y, \"gini\")\n", - " # update the best split if needed\n", + "\n", + " # update the best split if needed \n", " if curr_info_gain>max_info_gain:\n", " best_split[\"feature_index\"] = feature_index\n", " best_split[\"threshold\"] = threshold\n", @@ -278,6 +286,7 @@ " # return best split\n", " return best_split\n", " \n", + " \n", " def split(self, dataset, feature_index, threshold):\n", " ''' function to split the data '''\n", " \n", @@ -285,17 +294,20 @@ " dataset_right = np.array([row for row in dataset if row[feature_index]>threshold])\n", " return dataset_left, dataset_right\n", " \n", + " \n", " def information_gain(self, parent, l_child, r_child, mode=\"entropy\"):\n", " ''' function to compute information gain '''\n", " \n", " weight_l = len(l_child) / len(parent)\n", " weight_r = len(r_child) / len(parent)\n", + " \n", " if mode==\"gini\":\n", " gain = self.gini_index(parent) - (weight_l*self.gini_index(l_child) + weight_r*self.gini_index(r_child))\n", " else:\n", " gain = self.entropy(parent) - (weight_l*self.entropy(l_child) + weight_r*self.entropy(r_child))\n", " return gain\n", " \n", + " \n", " def entropy(self, y):\n", " ''' function to compute entropy '''\n", " \n", @@ -306,6 +318,7 @@ " entropy += -p_cls * np.log2(p_cls)\n", " return entropy\n", " \n", + " \n", " def gini_index(self, y):\n", " ''' function to compute gini index '''\n", " \n", @@ -316,12 +329,14 @@ " gini += p_cls**2\n", " return 1 - gini\n", " \n", + " \n", " def calculate_leaf_value(self, Y):\n", " ''' function to compute leaf node '''\n", " \n", " Y = list(Y)\n", " return max(Y, key=Y.count)\n", " \n", + " \n", " def print_tree(self, tree=None, indent=\" \"):\n", " ''' function to print the tree '''\n", " \n", @@ -332,17 +347,19 @@ " print(tree.value)\n", "\n", " else:\n", - " print(\"X_\"+str(tree.feature_index), \"<=\", tree.threshold, \"?\", tree.info_gain)\n", + " print(f'X_{str(tree.feature_index)} <= {tree.threshold} ? {tree.info_gain}')\n", " print(\"%sleft:\" % (indent), end=\"\")\n", " self.print_tree(tree.left, indent + indent)\n", " print(\"%sright:\" % (indent), end=\"\")\n", " self.print_tree(tree.right, indent + indent)\n", + " \n", " \n", " def fit(self, X, Y):\n", " ''' function to train the tree '''\n", " \n", " dataset = np.concatenate((X, Y), axis=1)\n", " self.root = self.build_tree(dataset)\n", + " \n", " \n", " def predict(self, X):\n", " ''' function to predict new dataset '''\n", @@ -350,12 +367,13 @@ " preditions = [self.make_prediction(x, self.root) for x in X]\n", " return preditions\n", " \n", + " \n", " def make_prediction(self, x, tree):\n", " ''' function to predict a single data point '''\n", " \n", - " if tree.value!=None: return tree.value\n", + " if tree.value is not None: return tree.value\n", " feature_val = x[tree.feature_index]\n", - " if feature_val<=tree.threshold:\n", + " if feature_val <= tree.threshold:\n", " return self.make_prediction(x, tree.left)\n", " else:\n", " return self.make_prediction(x, tree.right)" @@ -370,7 +388,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -389,24 +407,24 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "X_2 <= 1.9 ? 0.33741385372714494\n", - " left:0.0\n", - " right:X_3 <= 1.5 ? 0.427106638180289\n", - " left:X_2 <= 4.9 ? 0.05124653739612173\n", - " left:1.0\n", - " right:2.0\n", - " right:X_2 <= 5.0 ? 0.019631171921475288\n", - " left:X_1 <= 2.8 ? 0.20833333333333334\n", - " left:2.0\n", - " right:1.0\n", - " right:2.0\n" + "X_2 <= 1.9 ? 0.33741385372714494\n", + " left:Setosa\n", + " right:X_3 <= 1.5 ? 0.427106638180289\n", + " left:X_2 <= 4.9 ? 0.05124653739612173\n", + " left:Versicolor\n", + " right:Virginica\n", + " right:X_2 <= 5.0 ? 0.019631171921475288\n", + " left:X_1 <= 2.8 ? 0.20833333333333334\n", + " left:Virginica\n", + " right:Versicolor\n", + " right:Virginica\n" ] } ], @@ -425,7 +443,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -434,7 +452,7 @@ "0.9333333333333333" ] }, - "execution_count": 7, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -462,7 +480,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.5" + "version": "3.12.3" } }, "nbformat": 4, diff --git a/iris.csv b/iris.csv new file mode 100644 index 0000000..1b9d029 --- /dev/null +++ b/iris.csv @@ -0,0 +1,151 @@ +"sepal.length","sepal.width","petal.length","petal.width","variety" +5.1,3.5,1.4,.2,"Setosa" +4.9,3,1.4,.2,"Setosa" +4.7,3.2,1.3,.2,"Setosa" +4.6,3.1,1.5,.2,"Setosa" +5,3.6,1.4,.2,"Setosa" +5.4,3.9,1.7,.4,"Setosa" +4.6,3.4,1.4,.3,"Setosa" +5,3.4,1.5,.2,"Setosa" +4.4,2.9,1.4,.2,"Setosa" +4.9,3.1,1.5,.1,"Setosa" +5.4,3.7,1.5,.2,"Setosa" +4.8,3.4,1.6,.2,"Setosa" +4.8,3,1.4,.1,"Setosa" +4.3,3,1.1,.1,"Setosa" +5.8,4,1.2,.2,"Setosa" +5.7,4.4,1.5,.4,"Setosa" +5.4,3.9,1.3,.4,"Setosa" +5.1,3.5,1.4,.3,"Setosa" +5.7,3.8,1.7,.3,"Setosa" +5.1,3.8,1.5,.3,"Setosa" +5.4,3.4,1.7,.2,"Setosa" +5.1,3.7,1.5,.4,"Setosa" +4.6,3.6,1,.2,"Setosa" +5.1,3.3,1.7,.5,"Setosa" +4.8,3.4,1.9,.2,"Setosa" +5,3,1.6,.2,"Setosa" +5,3.4,1.6,.4,"Setosa" +5.2,3.5,1.5,.2,"Setosa" +5.2,3.4,1.4,.2,"Setosa" +4.7,3.2,1.6,.2,"Setosa" +4.8,3.1,1.6,.2,"Setosa" +5.4,3.4,1.5,.4,"Setosa" +5.2,4.1,1.5,.1,"Setosa" +5.5,4.2,1.4,.2,"Setosa" +4.9,3.1,1.5,.2,"Setosa" +5,3.2,1.2,.2,"Setosa" +5.5,3.5,1.3,.2,"Setosa" +4.9,3.6,1.4,.1,"Setosa" +4.4,3,1.3,.2,"Setosa" +5.1,3.4,1.5,.2,"Setosa" +5,3.5,1.3,.3,"Setosa" +4.5,2.3,1.3,.3,"Setosa" +4.4,3.2,1.3,.2,"Setosa" +5,3.5,1.6,.6,"Setosa" +5.1,3.8,1.9,.4,"Setosa" +4.8,3,1.4,.3,"Setosa" +5.1,3.8,1.6,.2,"Setosa" +4.6,3.2,1.4,.2,"Setosa" +5.3,3.7,1.5,.2,"Setosa" +5,3.3,1.4,.2,"Setosa" +7,3.2,4.7,1.4,"Versicolor" +6.4,3.2,4.5,1.5,"Versicolor" +6.9,3.1,4.9,1.5,"Versicolor" +5.5,2.3,4,1.3,"Versicolor" +6.5,2.8,4.6,1.5,"Versicolor" +5.7,2.8,4.5,1.3,"Versicolor" +6.3,3.3,4.7,1.6,"Versicolor" +4.9,2.4,3.3,1,"Versicolor" +6.6,2.9,4.6,1.3,"Versicolor" +5.2,2.7,3.9,1.4,"Versicolor" +5,2,3.5,1,"Versicolor" +5.9,3,4.2,1.5,"Versicolor" +6,2.2,4,1,"Versicolor" +6.1,2.9,4.7,1.4,"Versicolor" +5.6,2.9,3.6,1.3,"Versicolor" +6.7,3.1,4.4,1.4,"Versicolor" +5.6,3,4.5,1.5,"Versicolor" +5.8,2.7,4.1,1,"Versicolor" +6.2,2.2,4.5,1.5,"Versicolor" +5.6,2.5,3.9,1.1,"Versicolor" +5.9,3.2,4.8,1.8,"Versicolor" +6.1,2.8,4,1.3,"Versicolor" +6.3,2.5,4.9,1.5,"Versicolor" +6.1,2.8,4.7,1.2,"Versicolor" +6.4,2.9,4.3,1.3,"Versicolor" +6.6,3,4.4,1.4,"Versicolor" +6.8,2.8,4.8,1.4,"Versicolor" +6.7,3,5,1.7,"Versicolor" +6,2.9,4.5,1.5,"Versicolor" +5.7,2.6,3.5,1,"Versicolor" +5.5,2.4,3.8,1.1,"Versicolor" +5.5,2.4,3.7,1,"Versicolor" +5.8,2.7,3.9,1.2,"Versicolor" +6,2.7,5.1,1.6,"Versicolor" +5.4,3,4.5,1.5,"Versicolor" +6,3.4,4.5,1.6,"Versicolor" +6.7,3.1,4.7,1.5,"Versicolor" +6.3,2.3,4.4,1.3,"Versicolor" +5.6,3,4.1,1.3,"Versicolor" +5.5,2.5,4,1.3,"Versicolor" +5.5,2.6,4.4,1.2,"Versicolor" +6.1,3,4.6,1.4,"Versicolor" +5.8,2.6,4,1.2,"Versicolor" +5,2.3,3.3,1,"Versicolor" +5.6,2.7,4.2,1.3,"Versicolor" +5.7,3,4.2,1.2,"Versicolor" +5.7,2.9,4.2,1.3,"Versicolor" +6.2,2.9,4.3,1.3,"Versicolor" +5.1,2.5,3,1.1,"Versicolor" +5.7,2.8,4.1,1.3,"Versicolor" +6.3,3.3,6,2.5,"Virginica" +5.8,2.7,5.1,1.9,"Virginica" +7.1,3,5.9,2.1,"Virginica" +6.3,2.9,5.6,1.8,"Virginica" +6.5,3,5.8,2.2,"Virginica" +7.6,3,6.6,2.1,"Virginica" +4.9,2.5,4.5,1.7,"Virginica" +7.3,2.9,6.3,1.8,"Virginica" +6.7,2.5,5.8,1.8,"Virginica" +7.2,3.6,6.1,2.5,"Virginica" +6.5,3.2,5.1,2,"Virginica" +6.4,2.7,5.3,1.9,"Virginica" +6.8,3,5.5,2.1,"Virginica" +5.7,2.5,5,2,"Virginica" +5.8,2.8,5.1,2.4,"Virginica" +6.4,3.2,5.3,2.3,"Virginica" +6.5,3,5.5,1.8,"Virginica" +7.7,3.8,6.7,2.2,"Virginica" +7.7,2.6,6.9,2.3,"Virginica" +6,2.2,5,1.5,"Virginica" +6.9,3.2,5.7,2.3,"Virginica" +5.6,2.8,4.9,2,"Virginica" +7.7,2.8,6.7,2,"Virginica" +6.3,2.7,4.9,1.8,"Virginica" +6.7,3.3,5.7,2.1,"Virginica" +7.2,3.2,6,1.8,"Virginica" +6.2,2.8,4.8,1.8,"Virginica" +6.1,3,4.9,1.8,"Virginica" +6.4,2.8,5.6,2.1,"Virginica" +7.2,3,5.8,1.6,"Virginica" +7.4,2.8,6.1,1.9,"Virginica" +7.9,3.8,6.4,2,"Virginica" +6.4,2.8,5.6,2.2,"Virginica" +6.3,2.8,5.1,1.5,"Virginica" +6.1,2.6,5.6,1.4,"Virginica" +7.7,3,6.1,2.3,"Virginica" +6.3,3.4,5.6,2.4,"Virginica" +6.4,3.1,5.5,1.8,"Virginica" +6,3,4.8,1.8,"Virginica" +6.9,3.1,5.4,2.1,"Virginica" +6.7,3.1,5.6,2.4,"Virginica" +6.9,3.1,5.1,2.3,"Virginica" +5.8,2.7,5.1,1.9,"Virginica" +6.8,3.2,5.9,2.3,"Virginica" +6.7,3.3,5.7,2.5,"Virginica" +6.7,3,5.2,2.3,"Virginica" +6.3,2.5,5,1.9,"Virginica" +6.5,3,5.2,2,"Virginica" +6.2,3.4,5.4,2.3,"Virginica" +5.9,3,5.1,1.8,"Virginica" \ No newline at end of file