-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdecisionTreeWeight.py
41 lines (33 loc) · 1.62 KB
/
decisionTreeWeight.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from collections import Counter
# Ten sample car records; every record holds six categorical feature values.
cars = [
    ['med', 'low', '3', '4', 'med', 'med'],
    ['med', 'vhigh', '4', 'more', 'small', 'high'],
    ['high', 'med', '3', '2', 'med', 'low'],
    ['med', 'low', '4', '4', 'med', 'low'],
    ['med', 'low', '5more', '2', 'big', 'med'],
    ['med', 'med', '2', 'more', 'big', 'high'],
    ['med', 'med', '2', 'more', 'med', 'med'],
    ['vhigh', 'vhigh', '2', '2', 'med', 'low'],
    ['high', 'med', '4', '2', 'big', 'low'],
    ['low', 'low', '2', '4', 'big', 'med'],
]
# Class label for each record above, aligned by index (car_labels[i] <-> cars[i]).
car_labels = [
    'acc', 'acc', 'unacc', 'unacc', 'unacc',
    'vgood', 'acc', 'unacc', 'unacc', 'good',
]
def split(dataset, labels, column):
    """Partition the rows of *dataset* by the value found in *column*.

    Args:
        dataset: sequence of rows; each row must be indexable by *column*.
        labels: labels parallel to *dataset* (``labels[i]`` belongs to
            ``dataset[i]``).
        column: index (or key) selecting the feature to split on.

    Returns:
        ``(data_subsets, label_subsets)``: two lists with one entry per
        distinct value in *column*, ordered by sorted value. Within each
        subset, rows and labels keep their original relative order.

    Raises:
        IndexError: if *labels* is shorter than *dataset*.
        TypeError: if the values in *column* cannot be ordered by ``sorted``.
    """
    # One pass groups rows by feature value instead of rescanning the whole
    # dataset once per distinct value.
    data_groups = {}
    label_groups = {}
    for i, row in enumerate(dataset):
        value = row[column]
        data_groups.setdefault(value, []).append(row)
        # Index (not zip) so a short label list still raises IndexError.
        label_groups.setdefault(value, []).append(labels[i])

    ordered_values = sorted(data_groups)
    data_subsets = [data_groups[value] for value in ordered_values]
    label_subsets = [label_groups[value] for value in ordered_values]
    return data_subsets, label_subsets
def gini(dataset):
    """Return the Gini impurity of a collection of class labels.

    Gini impurity is ``1 - sum(p_k ** 2)``, where ``p_k`` is the fraction of
    items carrying label ``k``. It is 0.0 when every item shares one label
    and grows as the labels become more evenly mixed.

    Args:
        dataset: sized collection of hashable labels. Despite the parameter
            name, callers in this file pass label lists, not feature rows.

    Returns:
        The impurity as a float. An empty collection yields 0.0.
    """
    total = len(dataset)
    # An empty node has nothing to be impure about; without this guard the
    # loop never runs and the function would report 1, which looks like
    # maximal impurity.
    if total == 0:
        return 0.0
    impurity = 1.0
    for count in Counter(dataset).values():
        impurity -= (count / total) ** 2
    return impurity
def information_gain(starting_labels, split_labels):
    """Return the reduction in Gini impurity achieved by a split.

    The result is ``gini(starting_labels)`` minus the size-weighted
    impurity of each subset in *split_labels*. Each subset's weight is
    its share of ``len(starting_labels)``.

    Args:
        starting_labels: labels of the node before splitting.
        split_labels: iterable of label subsets produced by the split
            (for example, the second value returned by ``split``).

    Returns:
        The impurity reduction as a float, or 0.0 when *starting_labels*
        is empty.
    """
    total = len(starting_labels)
    # Nothing to split: avoid dividing by zero when weighting the subsets.
    if total == 0:
        return 0.0
    info_gain = gini(starting_labels)
    for subset in split_labels:
        info_gain -= gini(subset) * (len(subset) / total)
    return info_gain
if __name__ == "__main__":
    # Score each feature column of the sample data by how much splitting on
    # it reduces Gini impurity; one value is printed per column, in order.
    # The column count comes from the data instead of a hard-coded 6.
    num_columns = len(cars[0]) if cars else 0
    for i in range(num_columns):
        _, split_labels = split(cars, car_labels, i)
        print(information_gain(car_labels, split_labels))