Skip to content

Commit

Permalink
feat: add hard majority voting strategy for parallel ensembles (#126)
Browse files Browse the repository at this point in the history
* initial attempt at implementing hard majority voting

* fixed linting

* removed comments etc

* moved hard voting to utils, added test

* fixed linting

* fixed linting 2

* add changelog

* switched to value error

* flake8 formatting

Co-authored-by: xuyxu <[email protected]>
  • Loading branch information
gardberg and xuyxu authored Oct 6, 2022
1 parent 3a96d34 commit 061bef7
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Changelog
Ver 0.1.*
---------

* |Feature| |API| Add ``voting_strategy`` parameter for :class:`VotingClassifier`, :class:`NeuralForestClassifier`, and :class:`SnapshotEnsembleClassifier` | `@LukasGardberg <https://github.com/LukasGardberg>`__
* |Fix| Fix the sampling issue in :class:`BaggingClassifier` and :class:`BaggingRegressor` | `@SunHaozhe <https://github.com/SunHaozhe>`__
* |Feature| |API| Add :class:`NeuralForestClassifier` and :class:`NeuralForestRegressor` | `@xuyxu <https://github.com/xuyxu>`__
* |Fix| Relax check on input dataloader | `@xuyxu <https://github.com/xuyxu>`__
Expand Down
5 changes: 4 additions & 1 deletion torchensemble/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def predict(self, *x):
class BaseTreeEnsemble(BaseModule):
def __init__(
self,
n_estimators,
n_estimators=10,
depth=5,
lamda=1e-3,
cuda=False,
Expand Down Expand Up @@ -280,8 +280,11 @@ def evaluate(self, test_loader, return_loss=False):

for _, elem in enumerate(test_loader):
data, target = split_data_target(elem, self.device)

output = self.forward(*data)

_, predicted = torch.max(output.data, 1)

correct += (predicted == target).sum().item()
total += target.size(0)
loss += self._criterion(output, target)
Expand Down
18 changes: 16 additions & 2 deletions torchensemble/snapshot_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,28 @@ def set_scheduler(self, scheduler_name, **kwargs):
"""Implementation on the SnapshotEnsembleClassifier.""", "seq_model"
)
class SnapshotEnsembleClassifier(_BaseSnapshotEnsemble, BaseClassifier):
def __init__(self, voting_strategy="soft", **kwargs):
    # `voting_strategy` selects how estimator outputs are combined in
    # `forward`: "soft" averages the class distributions, "hard" takes a
    # majority vote. All remaining keyword arguments are forwarded
    # untouched to the parent constructor.
    super().__init__(**kwargs)

    # Fail fast on an unsupported strategy. `forward` only handles "soft"
    # and "hard"; any other value would leave its result unassigned and
    # surface much later as an obscure error at prediction time.
    implemented_strategies = {"soft", "hard"}
    if voting_strategy not in implemented_strategies:
        msg = (
            "Voting strategy {} is not implemented, "
            "please choose from {}."
        )
        raise ValueError(
            msg.format(voting_strategy, implemented_strategies)
        )

    self.voting_strategy = voting_strategy

@torchensemble_model_doc(
"""Implementation on the data forwarding in SnapshotEnsembleClassifier.""", # noqa: E501
"classifier_forward",
)
def forward(self, *x):
proba = self._forward(*x)

return F.softmax(proba, dim=1)
outputs = [
F.softmax(estimator(*x), dim=1) for estimator in self.estimators_
]

if self.voting_strategy == "soft":
proba = op.average(outputs)

elif self.voting_strategy == "hard":
proba = op.majority_vote(outputs)

return proba

@torchensemble_model_doc(
"""Set the attributes on optimizer for SnapshotEnsembleClassifier.""",
Expand Down
11 changes: 11 additions & 0 deletions torchensemble/tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,14 @@ def test_residual_regression_invalid_shape():
label.view(-1, 1), # 4 * 1
)
assert "should be the same as output" in str(excinfo.value)


def test_majority_voting():
    """Check that hard voting returns the one-hot majority class."""
    # Three estimators, two samples, two classes. Two of the three
    # estimators favour class 0 on the first sample and class 1 on the
    # second, so those classes must win the vote.
    per_estimator_probs = (
        ([0.9, 0.1], [0.2, 0.8]),
        ([0.7, 0.3], [0.1, 0.9]),
        ([0.1, 0.9], [0.8, 0.2]),
    )
    outputs = [
        torch.FloatTensor(np.array(probs)) for probs in per_estimator_probs
    ]

    voted = op.majority_vote(outputs).numpy()

    assert_array_almost_equal(voted, np.array(([1, 0], [0, 1])))
22 changes: 22 additions & 0 deletions torchensemble/utils/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import torch
import torch.nn.functional as F
from typing import List


__all__ = [
Expand All @@ -11,6 +12,7 @@
"onehot_encoding",
"pseudo_residual_classification",
"pseudo_residual_regression",
"majority_vote",
]


Expand Down Expand Up @@ -51,3 +53,23 @@ def pseudo_residual_regression(target, output):
raise ValueError(msg.format(target.size(), output.size()))

return target - output


def majority_vote(outputs: List[torch.Tensor]) -> torch.Tensor:
    """Compute the hard majority vote for a list of model outputs.

    Every model votes, per sample, for the class with its highest score.
    The class that collects the most votes is returned one-hot encoded.

    Parameters
    ----------
    outputs : list of torch.Tensor
        List of length ``n_models`` containing tensors with shape
        ``(n_samples, n_classes)``.

    Returns
    -------
    majority_one_hots : torch.Tensor
        Tensor with shape ``(n_samples, n_classes)``, created with
        ``torch.zeros_like(outputs[0])``, holding a ``1`` in the column of
        the winning class for each sample and ``0`` elsewhere.

    Raises
    ------
    ValueError
        If ``outputs`` is empty, or if any tensor in it is not
        two-dimensional.
    """
    # An empty list used to fail with a bare IndexError on `outputs[0]`;
    # report it as the invalid argument it is.
    if len(outputs) == 0:
        raise ValueError(
            "The list of outputs is empty, at least one tensor with sizes"
            " (n_samples, n_classes) is required."
        )

    # Check every tensor, not only the first one, so that the offending
    # model is named in the error message.
    for index, output in enumerate(outputs):
        if len(output.shape) != 2:
            msg = (
                "The shape of outputs should be a list of tensors of "
                "length (n_models) with sizes (n_samples, n_classes). "
                "The tensor at index {} had shape {}."
            )
            raise ValueError(msg.format(index, output.shape))

    # (n_models, n_samples, n_classes) -> per-model predicted class with
    # shape (n_models, n_samples) -> most frequent class per sample.
    # NOTE(review): how `torch.mode` resolves ties between classes is not
    # controlled here -- confirm against the PyTorch documentation.
    votes = torch.stack(outputs).argmax(dim=2).mode(dim=0)[0]

    # One-hot encode the winning class of every sample.
    proba = torch.zeros_like(outputs[0])
    majority_one_hots = proba.scatter_(1, votes.view(-1, 1), 1)

    return majority_one_hots
40 changes: 35 additions & 5 deletions torchensemble/voting.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,36 @@ def _parallel_fit_per_epoch(
"""Implementation on the VotingClassifier.""", "model"
)
class VotingClassifier(BaseClassifier):
def __init__(self, voting_strategy="soft", **kwargs):
    # Everything except `voting_strategy` is handed to the parent
    # constructor unchanged.
    super(VotingClassifier, self).__init__(**kwargs)

    # Accept only the strategies this classifier knows how to apply.
    implemented_strategies = {"soft", "hard"}
    if voting_strategy in implemented_strategies:
        self.voting_strategy = voting_strategy
        return

    template = (
        "Voting strategy {} is not implemented, "
        "please choose from {}."
    )
    raise ValueError(
        template.format(voting_strategy, implemented_strategies)
    )

@torchensemble_model_doc(
"""Implementation on the data forwarding in VotingClassifier.""",
"classifier_forward",
)
def forward(self, *x):
# Average over class distributions from all base estimators.

outputs = [
F.softmax(estimator(*x), dim=1) for estimator in self.estimators_
]
proba = op.average(outputs)

if self.voting_strategy == "soft":
proba = op.average(outputs)

elif self.voting_strategy == "hard":
proba = op.majority_vote(outputs)

return proba

Expand Down Expand Up @@ -167,12 +187,17 @@ def fit(
# Utils
best_acc = 0.0

# Internal helper function on pesudo forward
# Internal helper function on pseudo forward
def _forward(estimators, *x):
outputs = [
F.softmax(estimator(*x), dim=1) for estimator in estimators
]
proba = op.average(outputs)

if self.voting_strategy == "soft":
proba = op.average(outputs)

elif self.voting_strategy == "hard":
proba = op.majority_vote(outputs)

return proba

Expand Down Expand Up @@ -287,6 +312,11 @@ def predict(self, *x):
"""Implementation on the NeuralForestClassifier.""", "tree_ensmeble_model"
)
class NeuralForestClassifier(BaseTreeEnsemble, VotingClassifier):
def __init__(self, voting_strategy="soft", **kwargs):
    # `voting_strategy` is either "soft" or "hard". All remaining keyword
    # arguments are forwarded untouched to the parent constructors.
    super().__init__(**kwargs)

    # Validate here rather than relying on a parent constructor.
    # NOTE(review): whether the cooperative `super().__init__` chain ever
    # reaches a constructor that validates the strategy is not visible
    # from this class -- confirm against BaseTreeEnsemble.
    implemented_strategies = {"soft", "hard"}
    if voting_strategy not in implemented_strategies:
        msg = (
            "Voting strategy {} is not implemented, "
            "please choose from {}."
        )
        raise ValueError(
            msg.format(voting_strategy, implemented_strategies)
        )

    self.voting_strategy = voting_strategy

@torchensemble_model_doc(
"""Implementation on the data forwarding in NeuralForestClassifier.""",
"classifier_forward",
Expand Down Expand Up @@ -420,7 +450,7 @@ def fit(
# Utils
best_loss = float("inf")

# Internal helper function on pesudo forward
# Internal helper function on pseudo forward
def _forward(estimators, *x):
outputs = [estimator(*x) for estimator in estimators]
pred = op.average(outputs)
Expand Down

0 comments on commit 061bef7

Please sign in to comment.