-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtreeSampler.hpp
175 lines (138 loc) · 4.2 KB
/
treeSampler.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
#ifndef TREE_SAMPLER_HPP
#define TREE_SAMPLER_HPP
#include <cstdlib>
#include <cstddef>
#include <vector>
#include <cassert>
#include <algorithm>
#include <iostream>
#include <cmath>
#include <fstream>
#include <memory>
#include <random>
#include "stepRegression.hpp"
#include "asymmTree.hpp"
template<class pointType>
class TreeSampler
{
public:
typedef typename pointType::realScalarType RealScalarType;
typedef asymmTree<pointType> AsymmTreeType;
typedef std::vector<pointType> PointsArrayType;
typedef StepRegression RegressionType;
typedef Eigen::VectorXd VectorXd;
TreeSampler()
{
mTreeRoot = new AsymmTreeType;
}
~TreeSampler()
{
delete(mTreeRoot);
}
void setup(pointType const& boundMin,
pointType const& boundMax,
size_t const thresholdForBranching,
size_t const treeIndex,
size_t const level
) {
mTreeRoot->setup(boundMin,boundMax,thresholdForBranching,treeIndex,level);
setup();
}
void addPoints(PointsArrayType const& points,bool const makeTree = true)
{
mTreeRoot->addPoints(points,makeTree);
}
void dumpTree(std::ofstream & outFile)
{
mTreeRoot->dumpTree(outFile);
}
template<class RNGType>
pointType walkRandomPoint(RNGType & rng)
{
return randomPoint(rng);
}
void deleteNodes(RealScalarType const weightStar)
{
mTreeRoot->deleteNodes(weightStar);
}
// void setup(AsymmTreeType* arg_tree_root) {
void setup()
{
// Sorted list of active nodes by ascending volume
// mTreeRoot = arg_tree_root;
list_active_nodes(mTreeRoot,mActiveNodes);
std::sort(mActiveNodes.begin(),mActiveNodes.end(),compareVolumes);
// Number of active nodes
size_t nnodes = mActiveNodes.size();
// Interpolation vectors
VectorXd idx = VectorXd(nnodes+1);
VectorXd fcvol = VectorXd(nnodes+1);
// Initial values
idx[0] = 0.;
fcvol[0] = 0.;
// Store cumulative volumes and reference indices
for(size_t ii=0;ii<nnodes;ii++)
{
idx[ii+1] = ii+1;
fcvol[ii+1] = fcvol[ii] + nodeVolume(mActiveNodes[ii]);
}
// Inverse total volume
RealScalarType icvol = 1. / fcvol[nnodes];
// Convert to cumulative fractional volumes
for(size_t ii=0;ii<nnodes;ii++) { fcvol[ii+1] *= icvol; }
// Fit step regressor
mRegressionIndex.setup(fcvol,idx);
}
template<class RNGType>
pointType randomPoint(RNGType & rng)
{
std::uniform_real_distribution<> distUniReal;
size_t random_node_index = mRegressionIndex(distUniReal(rng));
assert(random_node_index >= 0 && random_node_index < mActiveNodes.size());
return mActiveNodes[random_node_index]->walkRandomPoint(rng);
}
private:
// Build a standard vector of all 'active' (leaf) nodes in the tree
void list_active_nodes(AsymmTreeType* arg_tree_root, std::vector<AsymmTreeType*> &arg_node_list)
{
arg_node_list.clear();
recurse_active_nodes(arg_tree_root,arg_node_list);
}
// Recursively add leaf (sub)nodes to a vector of nodes
void recurse_active_nodes(AsymmTreeType* arg_tree_node, std::vector<AsymmTreeType*> &arg_node_list)
{
if(!( arg_tree_node->hasLeftSubTree() || arg_tree_node->hasRightSubTree() ))
{
arg_node_list.push_back(arg_tree_node);
return;
}
if(arg_tree_node->hasLeftSubTree())
{
recurse_active_nodes(arg_tree_node->leftSubTree(),arg_node_list);
}
if(arg_tree_node->hasRightSubTree())
{
recurse_active_nodes(arg_tree_node->rightSubTree(),arg_node_list);
}
}
// Return the volume of the given node
static RealScalarType nodeVolume(AsymmTreeType* arg_node)
{
RealScalarType volume = 1.;
pointType boundMin;
pointType boundMax;
arg_node->getBounds(boundMin,boundMax,arg_node->treeIndex());
assert(boundMin.size() == boundMax.size());
size_t ndim = boundMin.size();
for(size_t ii=0; ii<ndim; ii++) { volume *= (boundMax[ii] - boundMin[ii]); }
return volume;
}
// return true if the volume of the node lhs is greter than that of rhs
static bool compareVolumes(AsymmTreeType* lhs, AsymmTreeType* rhs)
{
return nodeVolume(lhs) > nodeVolume(rhs);
}
AsymmTreeType* mTreeRoot;
std::vector<AsymmTreeType*> mActiveNodes;
RegressionType mRegressionIndex;
};
#endif // TREE_SAMPLER_HPP