From 061bef71bc2b16fbc8fc56ba51424158c24d1fa7 Mon Sep 17 00:00:00 2001 From: LukasGardberg <52111220+LukasGardberg@users.noreply.github.com> Date: Thu, 6 Oct 2022 06:12:19 +0200 Subject: [PATCH] feat: add hard majority voting strategy for parallel ensembles (#126) * initial attempt to implementing hard majority voting * fixed linting * removed comments etc * moved hard voting to utils, added test * fixed linting * fixed linting 2 * add changelog * switched to value error * flake8 formating Co-authored-by: xuyxu --- CHANGELOG.rst | 1 + torchensemble/_base.py | 5 +++- torchensemble/snapshot_ensemble.py | 18 +++++++++++-- torchensemble/tests/test_operator.py | 11 ++++++++ torchensemble/utils/operator.py | 22 +++++++++++++++ torchensemble/voting.py | 40 ++++++++++++++++++++++++---- 6 files changed, 89 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 34da1d0..c2bcc7d 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -18,6 +18,7 @@ Changelog Ver 0.1.* --------- +* |Feature| |API| Add ``voting_strategy`` parameter for :class:`VotingClassifer`, :class:`NeuralForestClassifier`, and :class:`SnapshotEnsembleClassifier` | `@LukasGardberg `__ * |Fix| Fix the sampling issue in :class:`BaggingClassifier` and :class:`BaggingRegressor` | `@SunHaozhe `__ * |Feature| |API| Add :class:`NeuralForestClassifier` and :class:`NeuralForestRegressor` | `@xuyxu `__ * |Fix| Relax check on input dataloader | `@xuyxu `__ diff --git a/torchensemble/_base.py b/torchensemble/_base.py index 2581155..c94b3cb 100644 --- a/torchensemble/_base.py +++ b/torchensemble/_base.py @@ -203,7 +203,7 @@ def predict(self, *x): class BaseTreeEnsemble(BaseModule): def __init__( self, - n_estimators, + n_estimators=10, depth=5, lamda=1e-3, cuda=False, @@ -280,8 +280,11 @@ def evaluate(self, test_loader, return_loss=False): for _, elem in enumerate(test_loader): data, target = split_data_target(elem, self.device) + output = self.forward(*data) + _, predicted = torch.max(output.data, 1) + correct += (predicted == target).sum().item() total += target.size(0) loss += self._criterion(output, target) diff --git a/torchensemble/snapshot_ensemble.py b/torchensemble/snapshot_ensemble.py index 34bb2f8..37f12d8 100644 --- a/torchensemble/snapshot_ensemble.py +++ b/torchensemble/snapshot_ensemble.py @@ -209,14 +209,28 @@ def set_scheduler(self, scheduler_name, **kwargs): """Implementation on the SnapshotEnsembleClassifier.""", "seq_model" ) class SnapshotEnsembleClassifier(_BaseSnapshotEnsemble, BaseClassifier): + def __init__(self, voting_strategy="soft", **kwargs): + super().__init__(**kwargs) + + self.voting_strategy = voting_strategy + @torchensemble_model_doc( """Implementation on the data forwarding in SnapshotEnsembleClassifier.""", # noqa: E501 "classifier_forward", ) def forward(self, *x): - proba = self._forward(*x) - return F.softmax(proba, dim=1) + outputs = [ + F.softmax(estimator(*x), dim=1) for estimator in self.estimators_ + ] + + if self.voting_strategy == "soft": + proba = op.average(outputs) + + elif self.voting_strategy == "hard": + proba = op.majority_vote(outputs) + + return proba @torchensemble_model_doc( """Set the attributes on optimizer for SnapshotEnsembleClassifier.""", diff --git a/torchensemble/tests/test_operator.py b/torchensemble/tests/test_operator.py index c2008f6..5656031 100644 --- a/torchensemble/tests/test_operator.py +++ b/torchensemble/tests/test_operator.py @@ -43,3 +43,14 @@ def test_residual_regression_invalid_shape(): label.view(-1, 1), # 4 * 1 ) assert "should be the same as output" in str(excinfo.value) + + +def test_majority_voting(): + outputs = [ + torch.FloatTensor(np.array(([0.9, 0.1], [0.2, 0.8]))), + torch.FloatTensor(np.array(([0.7, 0.3], [0.1, 0.9]))), + torch.FloatTensor(np.array(([0.1, 0.9], [0.8, 0.2]))), + ] + actual = op.majority_vote(outputs).numpy() + expected = np.array(([1, 0], [0, 1])) + assert_array_almost_equal(actual, expected) diff --git a/torchensemble/utils/operator.py b/torchensemble/utils/operator.py index c305960..e5c6738 100644 --- a/torchensemble/utils/operator.py +++ b/torchensemble/utils/operator.py @@ -3,6 +3,7 @@ import torch import torch.nn.functional as F +from typing import List __all__ = [ @@ -11,6 +12,7 @@ "onehot_encoding", "pseudo_residual_classification", "pseudo_residual_regression", + "majority_vote", ] @@ -51,3 +53,23 @@ def pseudo_residual_regression(target, output): raise ValueError(msg.format(target.size(), output.size())) return target - output + + +def majority_vote(outputs: List[torch.Tensor]) -> torch.Tensor: + """Compute the majority vote for a list of model outputs. + outputs: list of length (n_models) + containing tensors with shape (n_samples, n_classes) + majority_one_hots: (n_samples, n_classes) + """ + + if len(outputs[0].shape) != 2: + msg = """The shape of outputs should be a list tensors of + length (n_models) with sizes (n_samples, n_classes). + The first tensor had shape {} """ + raise ValueError(msg.format(outputs[0].shape)) + + votes = torch.stack(outputs).argmax(dim=2).mode(dim=0)[0] + proba = torch.zeros_like(outputs[0]) + majority_one_hots = proba.scatter_(1, votes.view(-1, 1), 1) + + return majority_one_hots diff --git a/torchensemble/voting.py b/torchensemble/voting.py index b3dffb9..54ed023 100644 --- a/torchensemble/voting.py +++ b/torchensemble/voting.py @@ -92,16 +92,36 @@ def _parallel_fit_per_epoch( """Implementation on the VotingClassifier.""", "model" ) class VotingClassifier(BaseClassifier): + def __init__(self, voting_strategy="soft", **kwargs): + super(VotingClassifier, self).__init__(**kwargs) + + implemented_strategies = {"soft", "hard"} + if voting_strategy not in implemented_strategies: + msg = ( + "Voting strategy {} is not implemented, " + "please choose from {}." + ) + raise ValueError( + msg.format(voting_strategy, implemented_strategies) + ) + + self.voting_strategy = voting_strategy + @torchensemble_model_doc( """Implementation on the data forwarding in VotingClassifier.""", "classifier_forward", ) def forward(self, *x): - # Average over class distributions from all base estimators. + outputs = [ F.softmax(estimator(*x), dim=1) for estimator in self.estimators_ ] - proba = op.average(outputs) + + if self.voting_strategy == "soft": + proba = op.average(outputs) + + elif self.voting_strategy == "hard": + proba = op.majority_vote(outputs) return proba @@ -167,12 +187,17 @@ def fit( # Utils best_acc = 0.0 - # Internal helper function on pesudo forward + # Internal helper function on pseudo forward def _forward(estimators, *x): outputs = [ F.softmax(estimator(*x), dim=1) for estimator in estimators ] - proba = op.average(outputs) + + if self.voting_strategy == "soft": + proba = op.average(outputs) + + elif self.voting_strategy == "hard": + proba = op.majority_vote(outputs) return proba @@ -287,6 +312,11 @@ def predict(self, *x): """Implementation on the NeuralForestClassifier.""", "tree_ensmeble_model" ) class NeuralForestClassifier(BaseTreeEnsemble, VotingClassifier): + def __init__(self, voting_strategy="soft", **kwargs): + super().__init__(**kwargs) + + self.voting_strategy = voting_strategy + @torchensemble_model_doc( """Implementation on the data forwarding in NeuralForestClassifier.""", "classifier_forward", @@ -420,7 +450,7 @@ def fit( # Utils best_loss = float("inf") - # Internal helper function on pesudo forward + # Internal helper function on pseudo forward def _forward(estimators, *x): outputs = [estimator(*x) for estimator in estimators] pred = op.average(outputs)