Initial attempt at implementing hard majority voting #126

Merged · 9 commits · Oct 6, 2022
1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -18,6 +18,7 @@ Changelog
Ver 0.1.*
---------

* |Feature| |API| Add ``voting_strategy`` parameter for :class:`VotingClassifier`, :class:`NeuralForestClassifier`, and :class:`SnapshotEnsembleClassifier` | `@LukasGardberg <https://github.com/LukasGardberg>`__
* |Fix| Fix the sampling issue in :class:`BaggingClassifier` and :class:`BaggingRegressor` | `@SunHaozhe <https://github.com/SunHaozhe>`__
* |Feature| |API| Add :class:`NeuralForestClassifier` and :class:`NeuralForestRegressor` | `@xuyxu <https://github.com/xuyxu>`__
* |Fix| Relax check on input dataloader | `@xuyxu <https://github.com/xuyxu>`__
5 changes: 4 additions & 1 deletion torchensemble/_base.py
@@ -203,7 +203,7 @@ def predict(self, *x):
class BaseTreeEnsemble(BaseModule):
def __init__(
self,
n_estimators,
Member

Is there any reason for adding the default value for n_estimators?

Contributor Author

Hey, no, I accidentally left it in there while trying some other stuff, hope it's ok!

n_estimators=10,
depth=5,
lamda=1e-3,
cuda=False,
@@ -280,8 +280,11 @@ def evaluate(self, test_loader, return_loss=False):

for _, elem in enumerate(test_loader):
data, target = split_data_target(elem, self.device)

output = self.forward(*data)

_, predicted = torch.max(output.data, 1)

correct += (predicted == target).sum().item()
total += target.size(0)
loss += self._criterion(output, target)
18 changes: 16 additions & 2 deletions torchensemble/snapshot_ensemble.py
@@ -209,14 +209,28 @@ def set_scheduler(self, scheduler_name, **kwargs):
"""Implementation on the SnapshotEnsembleClassifier.""", "seq_model"
)
class SnapshotEnsembleClassifier(_BaseSnapshotEnsemble, BaseClassifier):
def __init__(self, voting_strategy="soft", **kwargs):
super().__init__(**kwargs)

self.voting_strategy = voting_strategy

@torchensemble_model_doc(
"""Implementation on the data forwarding in SnapshotEnsembleClassifier.""", # noqa: E501
"classifier_forward",
)
def forward(self, *x):
proba = self._forward(*x)

return F.softmax(proba, dim=1)
outputs = [
F.softmax(estimator(*x), dim=1) for estimator in self.estimators_
]

if self.voting_strategy == "soft":
proba = op.average(outputs)

elif self.voting_strategy == "hard":
proba = op.majority_vote(outputs)

return proba

@torchensemble_model_doc(
"""Set the attributes on optimizer for SnapshotEnsembleClassifier.""",
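As a reading aid (not part of this diff), here is a minimal sketch of how the two strategies differ on the same set of softmax outputs, using the `op.average` and `op.majority_vote` helpers referenced above; the import path is assumed from the rest of the codebase.

```python
import torch
from torchensemble.utils import operator as op

# Two base estimators, two samples, two classes (already softmax-normalized).
outputs = [
    torch.tensor([[0.6, 0.4], [0.3, 0.7]]),
    torch.tensor([[0.9, 0.1], [0.4, 0.6]]),
]

# Soft voting: element-wise mean of the probability tensors.
print(op.average(outputs))        # tensor([[0.7500, 0.2500], [0.3500, 0.6500]])

# Hard voting: each estimator votes with its argmax; the winning class becomes a one-hot row.
print(op.majority_vote(outputs))  # tensor([[1., 0.], [0., 1.]])
```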
11 changes: 11 additions & 0 deletions torchensemble/tests/test_operator.py
@@ -43,3 +43,14 @@ def test_residual_regression_invalid_shape():
label.view(-1, 1), # 4 * 1
)
assert "should be the same as output" in str(excinfo.value)


def test_majority_voting():
outputs = [
torch.FloatTensor(np.array(([0.9, 0.1], [0.2, 0.8]))),
torch.FloatTensor(np.array(([0.7, 0.3], [0.1, 0.9]))),
torch.FloatTensor(np.array(([0.1, 0.9], [0.8, 0.2]))),
]
actual = op.majority_vote(outputs).numpy()
expected = np.array(([1, 0], [0, 1]))
assert_array_almost_equal(actual, expected)
22 changes: 22 additions & 0 deletions torchensemble/utils/operator.py
@@ -3,6 +3,7 @@

import torch
import torch.nn.functional as F
from typing import List


__all__ = [
@@ -11,6 +12,7 @@
"onehot_encoding",
"pseudo_residual_classification",
"pseudo_residual_regression",
"majority_vote",
]


@@ -51,3 +53,23 @@ def pseudo_residual_regression(target, output):
raise ValueError(msg.format(target.size(), output.size()))

return target - output


def majority_vote(outputs: List[torch.Tensor]) -> torch.Tensor:
"""Compute the majority vote for a list of model outputs.
outputs: list of length (n_models)
containing tensors with shape (n_samples, n_classes)
majority_one_hots: (n_samples, n_classes)
"""

if len(outputs[0].shape) != 2:
msg = """The shape of outputs should be a list tensors of
length (n_models) with sizes (n_samples, n_classes).
The first tensor had shape {} """
raise ValueError(msg.format(outputs[0].shape))

votes = torch.stack(outputs).argmax(dim=2).mode(dim=0)[0]
proba = torch.zeros_like(outputs[0])
majority_one_hots = proba.scatter_(1, votes.view(-1, 1), 1)

return majority_one_hots
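
To make the tensor manipulation easier to follow (my annotation, not part of the PR), the same computation can be traced step by step on the tensors used in the new unit test:

```python
import torch

outputs = [
    torch.tensor([[0.9, 0.1], [0.2, 0.8]]),
    torch.tensor([[0.7, 0.3], [0.1, 0.9]]),
    torch.tensor([[0.1, 0.9], [0.8, 0.2]]),
]

stacked = torch.stack(outputs)            # (n_models, n_samples, n_classes) = (3, 2, 2)
per_model_votes = stacked.argmax(dim=2)   # (3, 2): each model's predicted class per sample
votes = per_model_votes.mode(dim=0)[0]    # (2,): most frequent class per sample -> tensor([0, 1])

# Scatter each sample's winning class index back into a one-hot row.
one_hot = torch.zeros_like(outputs[0]).scatter_(1, votes.view(-1, 1), 1)
print(one_hot)  # tensor([[1., 0.], [0., 1.]])
```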
40 changes: 35 additions & 5 deletions torchensemble/voting.py
@@ -92,16 +92,36 @@ def _parallel_fit_per_epoch(
"""Implementation on the VotingClassifier.""", "model"
)
class VotingClassifier(BaseClassifier):
def __init__(self, voting_strategy="soft", **kwargs):
super(VotingClassifier, self).__init__(**kwargs)

implemented_strategies = {"soft", "hard"}
if voting_strategy not in implemented_strategies:
msg = (
"Voting strategy {} is not implemented, "
"please choose from {}."
)
raise ValueError(
msg.format(voting_strategy, implemented_strategies)
)

self.voting_strategy = voting_strategy

@torchensemble_model_doc(
"""Implementation on the data forwarding in VotingClassifier.""",
"classifier_forward",
)
def forward(self, *x):
# Average over class distributions from all base estimators.

outputs = [
F.softmax(estimator(*x), dim=1) for estimator in self.estimators_
]
proba = op.average(outputs)

if self.voting_strategy == "soft":
proba = op.average(outputs)

elif self.voting_strategy == "hard":
proba = op.majority_vote(outputs)

return proba

@@ -167,12 +187,17 @@ def fit(
# Utils
best_acc = 0.0

# Internal helper function on pesudo forward
# Internal helper function on pseudo forward
def _forward(estimators, *x):
outputs = [
F.softmax(estimator(*x), dim=1) for estimator in estimators
]
proba = op.average(outputs)

if self.voting_strategy == "soft":
proba = op.average(outputs)

elif self.voting_strategy == "hard":
proba = op.majority_vote(outputs)

return proba

@@ -276,6 +301,11 @@ def predict(self, *x):
"""Implementation on the NeuralForestClassifier.""", "tree_ensmeble_model"
)
class NeuralForestClassifier(BaseTreeEnsemble, VotingClassifier):
def __init__(self, voting_strategy="soft", **kwargs):
super().__init__(**kwargs)

self.voting_strategy = voting_strategy

@torchensemble_model_doc(
"""Implementation on the data forwarding in NeuralForestClassifier.""",
"classifier_forward",
@@ -409,7 +439,7 @@ def fit(
# Utils
best_loss = float("inf")

# Internal helper function on pesudo forward
# Internal helper function on pseudo forward
def _forward(estimators, *x):
outputs = [estimator(*x) for estimator in estimators]
pred = op.average(outputs)
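
For completeness, a usage sketch of the new parameter. The `estimator` and `n_estimators` constructor arguments and the toy `MLP` module below are assumptions based on the existing torchensemble API, not something introduced by this diff:

```python
import torch.nn as nn
from torchensemble import VotingClassifier

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 2))

    def forward(self, x):
        return self.net(x)

# Hard majority voting instead of the default soft (probability-averaging) strategy.
model = VotingClassifier(estimator=MLP, n_estimators=5, voting_strategy="hard")

# Unknown strategies are rejected in __init__ rather than at prediction time.
try:
    VotingClassifier(estimator=MLP, n_estimators=5, voting_strategy="median")
except ValueError as err:
    print(err)
```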