-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathxforest.go
169 lines (163 loc) · 3.79 KB
/
xforest.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
package randomforest
import (
"math"
"math/rand"
"sync"
)
// TrainX trains the forest as an ensemble of extremely randomized trees.
//
// trees is the number of trees to grow; each tree is built in its own
// goroutine and stored in forest.Trees. Options left at zero get defaults:
// MFeatures becomes floor(sqrt(Features)) and LeafSize becomes NSize/20
// clamped to [1, 50]. MFeatures is capped at Features so the per-node
// feature draw never asks for more features than exist. After training,
// FeatureImportance holds each feature's importance averaged over all trees
// (all zeros when trees == 0).
//
// forest.Data.X must contain at least one sample: its first row determines
// the number of features.
func (forest *Forest) TrainX(trees int) {
	forest.NSize = len(forest.Data.X)
	forest.Features = len(forest.Data.X[0])
	forest.NTrees = trees
	forest.Trees = make([]Tree, forest.NTrees)

	// Class labels are assumed to be 0..Classes-1.
	forest.Classes = 0
	for _, c := range forest.Data.Class {
		if c >= forest.Classes {
			forest.Classes = c + 1
		}
	}

	if forest.MFeatures == 0 {
		forest.MFeatures = int(math.Sqrt(float64(forest.Features)))
	}
	if forest.MFeatures > forest.Features {
		forest.MFeatures = forest.Features
	}
	if forest.LeafSize == 0 {
		forest.LeafSize = forest.NSize / 20
		if forest.LeafSize <= 0 {
			forest.LeafSize = 1
		} else if forest.LeafSize > 50 {
			forest.LeafSize = 50
		}
	}

	var wg sync.WaitGroup
	wg.Add(trees)
	for t := 0; t < trees; t++ {
		go forest.newXTree(t, &wg)
	}
	wg.Wait()

	imp := make([]float64, forest.Features)
	for t := 0; t < trees; t++ {
		z := forest.Trees[t].importance(forest)
		for f := range imp {
			imp[f] += z[f]
		}
	}
	// Without trees there is nothing to average; avoid 0/0 = NaN.
	if trees > 0 {
		for f := range imp {
			imp[f] /= float64(trees)
		}
	}
	forest.FeatureImportance = imp
}
// newXTree grows the index-th extremely randomized tree of the forest,
// stores it in forest.Trees[index] and calls wg.Done when finished.
//
// The tree is trained on NSize rows drawn uniformly with replacement from
// forest.Data. Rows that were never drawn are used for validation:
// tree.Validation is the mean vote the tree gives to each such row's true
// class. When every row was drawn (always the case for NSize == 1) there is
// nothing to validate on, and Validation is left at 0 rather than NaN.
func (forest *Forest) newXTree(index int, wg *sync.WaitGroup) {
	defer wg.Done()

	// Bootstrap sample.
	inBag := make([]bool, forest.NSize)
	x := make([][]float64, forest.NSize)
	results := make([]int, forest.NSize)
	for i := range x {
		k := rand.Intn(forest.NSize)
		x[i] = forest.Data.X[k]
		results[i] = forest.Data.Class[k]
		inBag[k] = true
	}

	root := Branch{}
	root.xbuild(forest, x, results, 1)
	tree := Tree{Root: root}

	// Validate on the rows left out of the bootstrap sample.
	count := 0
	e := 0.0
	for i, drawn := range inBag {
		if drawn {
			continue
		}
		count++
		e += root.vote(forest.Data.X[i])[forest.Data.Class[i]]
	}
	if count > 0 {
		tree.Validation = e / float64(count)
	}

	mux.Lock()
	forest.Trees[index] = tree
	mux.Unlock()
}
// xbuild recursively grows an extremely randomized tree below branch from
// the samples x and their class labels class (x[i] is labelled class[i]).
// depth is the depth of branch; newXTree starts the root at depth 1.
//
// The node becomes a leaf in three cases: it holds at most forest.LeafSize
// samples, it is pure (Gini impurity 0), or no candidate feature can
// separate its samples. A leaf's LeafValue holds the class proportions.
//
// Otherwise features are visited in random order. Constant features are
// skipped and the next random feature is drawn instead. Up to
// forest.MFeatures non-constant features are each cut at the midpoint of
// their observed range. The cut with the lowest size-weighted Gini impurity
// is kept; samples with x[a] > Value go to Branch1 and the rest to Branch0.
//
// NOTE(review): the reference Extra-Trees algorithm draws the cut-point
// uniformly in [min, max]; this implementation uses the midpoint.
func (branch *Branch) xbuild(forest *Forest, x [][]float64, class []int, depth int) {
	classCount := make([]int, forest.Classes)
	for _, r := range class {
		classCount[r]++
	}
	branch.Gini = gini(classCount)
	branch.Size = len(class)
	branch.Depth = depth

	// makeLeaf turns the node into a leaf voting with the class proportions.
	makeLeaf := func() {
		branch.IsLeaf = true
		branch.LeafValue = make([]float64, forest.Classes)
		if branch.Size == 0 {
			return
		}
		for i, r := range classCount {
			branch.LeafValue[i] = float64(r) / float64(branch.Size)
		}
	}
	if len(x) <= forest.LeafSize || branch.Gini == 0 {
		makeLeaf()
		return
	}

	// Search for the best extremely random split.
	found := false
	bestAttr, bestSplit, bestGini, bestRight := 0, 0.0, 0.0, 0
	evaluated := 0
	for _, a := range rand.Perm(forest.Features) {
		if evaluated >= forest.MFeatures {
			break
		}
		lo, hi := x[0][a], x[0][a]
		for _, row := range x[1:] {
			if row[a] < lo {
				lo = row[a]
			}
			if row[a] > hi {
				hi = row[a]
			}
		}
		// A constant (or NaN) feature cannot split this node. It does not
		// count against MFeatures, so another random feature is tried.
		if !(lo < hi) {
			continue
		}
		evaluated++

		// Halve before adding so that lo+hi cannot overflow to ±Inf.
		split := lo/2 + hi/2
		left := make([]int, forest.Classes)
		copy(left, classCount)
		right := make([]int, forest.Classes)
		nRight := 0
		for i, row := range x {
			if row[a] > split {
				right[class[i]]++
				left[class[i]]--
				nRight++
			}
		}
		// Floating-point rounding can leave every sample on one side. Such a
		// cut would recurse on the same data, so reject it.
		if nRight == 0 || nRight == branch.Size {
			continue
		}

		w := (gini(right)*float64(nRight) + gini(left)*float64(branch.Size-nRight)) / float64(branch.Size)
		if !found || w < bestGini {
			found = true
			bestAttr, bestSplit, bestGini, bestRight = a, split, w, nRight
		}
	}
	// No feature separates these samples (e.g. identical rows with mixed
	// labels), so splitting further is impossible.
	if !found {
		makeLeaf()
		return
	}

	branch.GiniGain = branch.Gini - bestGini
	branch.Attribute = bestAttr
	branch.Value = bestSplit

	x0 := make([][]float64, 0, branch.Size-bestRight)
	c0 := make([]int, 0, branch.Size-bestRight)
	x1 := make([][]float64, 0, bestRight)
	c1 := make([]int, 0, bestRight)
	for i, row := range x {
		if row[bestAttr] > bestSplit {
			x1 = append(x1, row)
			c1 = append(c1, class[i])
		} else {
			x0 = append(x0, row)
			c0 = append(c0, class[i])
		}
	}

	// Both children are strictly smaller than this node, so recursion ends.
	branch.Branch0 = &Branch{}
	branch.Branch1 = &Branch{}
	branch.Branch0.xbuild(forest, x0, c0, depth+1)
	branch.Branch1.xbuild(forest, x1, c1, depth+1)
}