feat: add neural forest ensemble (#110)
* feat: add NeuralForestClassifier

* feat: add NeuralForestClassifier

* update code

* update code

* update code

* Update requirements.txt

* add doc

* Update README.rst

* update code
xuyxu authored Apr 5, 2022
1 parent f532498 commit 9d26a9d
Showing 13 changed files with 607 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -18,6 +18,7 @@ Changelog
Ver 0.1.*
---------

* |Feature| |API| Add :class:`NeuralForestClassifier` and :class:`NeuralForestRegressor` | `@xuyxu <https://github.com/xuyxu>`__
* |Fix| Relax check on input dataloader | `@xuyxu <https://github.com/xuyxu>`__
* |Feature| |API| Support arbitrary training criteria for all ensembles except Gradient Boosting | `@by256 <https://github.com/by256>`__ and `@xuyxu <https://github.com/xuyxu>`__
* |Fix| Fix missing functionality of ``save_model`` for :meth:`fit` of Soft Gradient Boosting | `@xuyxu <https://github.com/xuyxu>`__
2 changes: 2 additions & 0 deletions README.rst
@@ -87,6 +87,8 @@ Supported Ensemble
+------------------------------+------------+---------------------------+
| Voting [1]_ | Parallel | voting.py |
+------------------------------+------------+---------------------------+
| Neural Forest | Parallel | voting.py |
+------------------------------+------------+---------------------------+
| Bagging [2]_ | Parallel | bagging.py |
+------------------------------+------------+---------------------------+
| Gradient Boosting [3]_ | Sequential | gradient_boosting.py |
23 changes: 23 additions & 0 deletions docs/parameters.rst
@@ -92,6 +92,29 @@ GradientBoostingRegressor
.. autoclass:: torchensemble.gradient_boosting.GradientBoostingRegressor
    :members:

Neural Tree Ensemble
--------------------

Neural tree ensembles are extensions of voting and gradient boosting that use
the neural tree as the base estimator. Neural trees are differentiable trees
that use logistic regression at internal nodes to split samples into child
nodes with different probabilities. Model details are available in
`Distilling a neural network into a soft decision tree
<https://arxiv.org/abs/1711.09784>`_.
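
To make the routing mechanism concrete, here is a minimal sketch of how a
depth-1 soft decision tree produces a prediction (illustration only, not code
from this commit; all names and tensor sizes are made up)::

    import torch

    x = torch.randn(8)                      # one flattened sample (toy size)
    w, b = torch.randn(8), torch.zeros(1)   # internal-node weights and bias
    leaf_left = torch.randn(10)             # learnable output of the left leaf
    leaf_right = torch.randn(10)            # learnable output of the right leaf

    p_left = torch.sigmoid(x @ w + b)       # soft routing probability
    y_pred = p_left * leaf_left + (1 - p_left) * leaf_right  # differentiable prediction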


NeuralForestClassifier
**********************

.. autoclass:: torchensemble.voting.NeuralForestClassifier
    :members:

NeuralForestRegressor
**********************

.. autoclass:: torchensemble.voting.NeuralForestRegressor
    :members:

Snapshot Ensemble
-----------------

3 changes: 2 additions & 1 deletion docs/requirements.txt
@@ -4,4 +4,5 @@ sphinxemoji==0.1.8
sphinx-copybutton
m2r2==0.2.7
mistune==0.8.4
Jinja2<3.1
Jinja2<3.1
Numpy
75 changes: 75 additions & 0 deletions examples/classification_mnist_tree_ensemble.py
@@ -0,0 +1,75 @@
import time
import torch
from torchvision import datasets, transforms

from torchensemble import NeuralForestClassifier
from torchensemble.utils.logging import set_logger


if __name__ == "__main__":

    # Hyper-parameters
    n_estimators = 5
    depth = 5
    lamda = 1e-3
    lr = 1e-3
    weight_decay = 5e-4
    epochs = 50

    # Utils
    cuda = False
    n_jobs = 1
    batch_size = 128
    data_dir = "../../Dataset/mnist"  # MODIFY THIS IF YOU WANT

    # Load data
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            data_dir,
            train=True,
            download=True,
            transform=transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307,), (0.3081,)),
                ]
            ),
        ),
        batch_size=batch_size,
        shuffle=True,
    )

    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            data_dir,
            train=False,
            download=True,
            transform=transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307,), (0.3081,)),
                ]
            ),
        ),
        batch_size=batch_size,
        shuffle=True,
    )

    logger = set_logger(
        "classification_mnist_tree_ensemble", use_tb_logger=False
    )

    model = NeuralForestClassifier(
        n_estimators=n_estimators,
        depth=depth,
        lamda=lamda,
        cuda=cuda,
        n_jobs=n_jobs,
    )

    model.set_optimizer("Adam", lr=lr, weight_decay=weight_decay)

    tic = time.time()
    model.fit(train_loader, epochs=epochs, test_loader=test_loader)
    toc = time.time()
    training_time = toc - tic
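
A natural follow-up to the script above (not part of this commit) is to report
the elapsed time and the test accuracy; `evaluate` is assumed to behave as it
does for the other torchensemble classifiers and return the accuracy:

    # Hypothetical follow-up; `evaluate` is assumed to return test accuracy.
    testing_acc = model.evaluate(test_loader)
    print(
        "Training time: {:.3f} s | Testing accuracy: {:.3f} %".format(
            training_time, testing_acc
        )
    )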
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -23,4 +23,4 @@ exclude = '''
| dist
| docs
)/
'''
'''
4 changes: 4 additions & 0 deletions torchensemble/__init__.py
@@ -2,6 +2,8 @@
from .fusion import FusionRegressor
from .voting import VotingClassifier
from .voting import VotingRegressor
from .voting import NeuralForestClassifier
from .voting import NeuralForestRegressor
from .bagging import BaggingClassifier
from .bagging import BaggingRegressor
from .gradient_boosting import GradientBoostingClassifier
@@ -21,6 +23,8 @@
    "FusionRegressor",
    "VotingClassifier",
    "VotingRegressor",
    "NeuralForestClassifier",
    "NeuralForestRegressor",
    "BaggingClassifier",
    "BaggingRegressor",
    "GradientBoostingClassifier",
173 changes: 173 additions & 0 deletions torchensemble/_base.py
@@ -29,6 +29,7 @@ def get_doc(item):
    __doc = {
        "model": const.__model_doc,
        "seq_model": const.__seq_model_doc,
        "tree_ensmeble_model": const.__tree_ensemble_doc,
        "fit": const.__fit_doc,
        "predict": const.__predict_doc,
        "set_optimizer": const.__set_optimizer_doc,
@@ -199,6 +200,50 @@ def predict(self, *x):
        return pred


class BaseTreeEnsemble(BaseModule):
    def __init__(
        self,
        n_estimators,
        depth=5,
        lamda=1e-3,
        cuda=False,
        n_jobs=None,
    ):
        super(BaseModule, self).__init__()
        self.base_estimator_ = BaseTree
        self.n_estimators = n_estimators
        self.depth = depth
        self.lamda = lamda

        self.device = torch.device("cuda" if cuda else "cpu")
        self.n_jobs = n_jobs
        self.logger = logging.getLogger()
        self.tb_logger = get_tb_logger()

        self.estimators_ = nn.ModuleList()
        self.use_scheduler_ = False

    def _decidce_n_inputs(self, train_loader):
        """Decide the input dimension according to the `train_loader`."""
        for _, elem in enumerate(train_loader):
            data = elem[0]
            n_samples = data.size(0)
            data = data.view(n_samples, -1)
            return data.size(1)

    def _make_estimator(self):
        """Make and configure a soft decision tree."""
        estimator = BaseTree(
            input_dim=self.n_inputs,
            output_dim=self.n_outputs,
            depth=self.depth,
            lamda=self.lamda,
            cuda=self.device == torch.device("cuda"),
        )

        return estimator.to(self.device)
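
    # For orientation only -- a rough sketch (not code from this commit) of how
    # a derived ensemble could use the two helpers above when fitting; the
    # method name `_setup_estimators` and the `n_outputs` argument are assumptions.
    def _setup_estimators(self, train_loader, n_outputs):
        self.n_inputs = self._decidce_n_inputs(train_loader)
        self.n_outputs = n_outputs
        self.estimators_ = nn.ModuleList(
            [self._make_estimator() for _ in range(self.n_estimators)]
        )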


class BaseClassifier(BaseModule):
"""Base class for all ensemble classifiers.
@@ -285,3 +330,131 @@ def evaluate(self, test_loader):
                loss += self._criterion(output, target)

        return float(loss) / len(test_loader)


class BaseTree(nn.Module):
    """Fast implementation of soft decision tree in PyTorch, copied from:
    `https://github.com/xuyxu/Soft-Decision-Tree/blob/master/SDT.py`
    """

    def __init__(self, input_dim, output_dim, depth=5, lamda=1e-3, cuda=False):
        super(BaseTree, self).__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim

        self.depth = depth
        self.lamda = lamda
        self.device = torch.device("cuda" if cuda else "cpu")

        self._validate_parameters()

        self.internal_node_num_ = 2 ** self.depth - 1
        self.leaf_node_num_ = 2 ** self.depth

        # Different penalty coefficients for nodes in different layers
        self.penalty_list = [
            self.lamda * (2 ** (-depth)) for depth in range(0, self.depth)
        ]

        # Initialize internal nodes and leaf nodes, the input dimension on
        # internal nodes is added by 1, serving as the bias.
        self.inner_nodes = nn.Sequential(
            nn.Linear(self.input_dim + 1, self.internal_node_num_, bias=False),
            nn.Sigmoid(),
        )

        self.leaf_nodes = nn.Linear(
            self.leaf_node_num_, self.output_dim, bias=False
        )

    def forward(self, X, is_training_data=False):
        _mu, _penalty = self._forward(X)
        y_pred = self.leaf_nodes(_mu)

        # When `X` is the training data, the model also returns the penalty
        # to compute the training loss.
        if is_training_data:
            return y_pred, _penalty
        else:
            return y_pred

    def _forward(self, X):
        """Implementation on the data forwarding process."""

        batch_size = X.size()[0]
        X = self._data_augment(X)

        path_prob = self.inner_nodes(X)
        path_prob = torch.unsqueeze(path_prob, dim=2)
        path_prob = torch.cat((path_prob, 1 - path_prob), dim=2)

        _mu = X.data.new(batch_size, 1, 1).fill_(1.0)
        _penalty = torch.tensor(0.0).to(self.device)

        # Iterate through internal nodes in each layer to compute the final path
        # probabilities and the regularization term.
        begin_idx = 0
        end_idx = 1

        for layer_idx in range(0, self.depth):
            _path_prob = path_prob[:, begin_idx:end_idx, :]

            # Extract internal nodes in the current layer to compute the
            # regularization term
            _penalty = _penalty + self._cal_penalty(layer_idx, _mu, _path_prob)
            _mu = _mu.view(batch_size, -1, 1).repeat(1, 1, 2)

            _mu = _mu * _path_prob  # update path probabilities

            begin_idx = end_idx
            end_idx = begin_idx + 2 ** (layer_idx + 1)

        mu = _mu.view(batch_size, self.leaf_node_num_)

        return mu, _penalty

    def _cal_penalty(self, layer_idx, _mu, _path_prob):
        """Compute the regularization term for internal nodes"""

        penalty = torch.tensor(0.0).to(self.device)

        batch_size = _mu.size()[0]
        _mu = _mu.view(batch_size, 2 ** layer_idx)
        _path_prob = _path_prob.view(batch_size, 2 ** (layer_idx + 1))

        for node in range(0, 2 ** (layer_idx + 1)):
            alpha = torch.sum(
                _path_prob[:, node] * _mu[:, node // 2], dim=0
            ) / torch.sum(_mu[:, node // 2], dim=0)

            coeff = self.penalty_list[layer_idx]

            penalty -= 0.5 * coeff * (torch.log(alpha) + torch.log(1 - alpha))

        return penalty

    def _data_augment(self, X):
        """Add a constant input `1` onto the front of each sample."""
        batch_size = X.size()[0]
        X = X.view(batch_size, -1)
        bias = torch.ones(batch_size, 1).to(self.device)
        X = torch.cat((bias, X), 1)

        return X

    def _validate_parameters(self):

        if not self.depth > 0:
            msg = (
                "The tree depth should be strictly positive, but got {} "
                "instead."
            )
            raise ValueError(msg.format(self.depth))

        if not self.lamda >= 0:
            msg = (
                "The coefficient of the regularization term should not be"
                " negative, but got {} instead."
            )
            raise ValueError(msg.format(self.lamda))
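
To show how the two values returned by ``forward(X, is_training_data=True)``
are meant to be consumed, here is a minimal training-step sketch (an
illustration under assumptions, not code from this commit); ``tree`` is an
instantiated ``BaseTree`` and ``train_loader`` yields ``(data, target)``
batches:

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(tree.parameters(), lr=1e-3)

    for data, target in train_loader:
        output, penalty = tree(data, is_training_data=True)
        # Task loss plus the layer-wise regularization term computed above.
        loss = criterion(output, target) + penalty
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()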
30 changes: 30 additions & 0 deletions torchensemble/_constants.py
@@ -58,6 +58,36 @@
"""


__tree_ensemble_doc = """
    Parameters
    ----------
    n_estimators : int
        The number of neural trees in the ensemble.
    depth : int, default=5
        The depth of the neural tree. A tree with depth ``d`` has :math:`2^d`
        leaf nodes and :math:`2^d - 1` internal nodes.
    lamda : float, default=1e-3
        The coefficient of the regularization term when training neural
        trees, proposed in the paper: `Distilling a neural network into a
        soft decision tree <https://arxiv.org/abs/1711.09784>`_.
    cuda : bool, default=True

        - If ``True``, use GPU to train and evaluate the ensemble.
        - If ``False``, use CPU to train and evaluate the ensemble.
    n_jobs : int, default=None
        The number of workers for training the ensemble. This input
        argument is used for parallel ensemble methods such as
        :mod:`voting` and :mod:`bagging`. Setting it to an integer larger
        than ``1`` enables ``n_jobs`` base estimators to be trained
        simultaneously.

    Attributes
    ----------
    estimators_ : torch.nn.ModuleList
        An internal container that stores all fitted base estimators.
"""


__set_optimizer_doc = """
Parameters
----------