Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

initial attempt to implementing hard majority voting #126

Merged
merged 9 commits into from
Oct 6, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions torchensemble/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,11 @@ def evaluate(self, test_loader, return_loss=False):

for _, elem in enumerate(test_loader):
data, target = split_data_target(elem, self.device)

gardberg marked this conversation as resolved.
Show resolved Hide resolved
output = self.forward(*data)

_, predicted = torch.max(output.data, 1)

correct += (predicted == target).sum().item()
total += target.size(0)
loss += self._criterion(output, target)
Expand Down
38 changes: 34 additions & 4 deletions torchensemble/voting.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,39 @@ def _parallel_fit_per_epoch(
"""Implementation on the VotingClassifier.""", "model"
)
class VotingClassifier(BaseClassifier):
def __init__(self, voting_strategy="soft", **kwargs):
super(VotingClassifier, self).__init__(**kwargs)

implemented_strategies = {"soft", "hard"}
if voting_strategy not in implemented_strategies:
msg = (
"Voting strategy {} is not implemented, "
"please choose from {}."
)
raise ValueError(
msg.format(voting_strategy, implemented_strategies)
)

self.voting_strategy = voting_strategy

@torchensemble_model_doc(
"""Implementation on the data forwarding in VotingClassifier.""",
"classifier_forward",
)
def forward(self, *x):
# Average over class distributions from all base estimators.

outputs = [
F.softmax(estimator(*x), dim=1) for estimator in self.estimators_
]
proba = op.average(outputs)

if self.voting_strategy == "soft":
proba = op.average(outputs)

elif self.voting_strategy == "hard":
gardberg marked this conversation as resolved.
Show resolved Hide resolved
votes = torch.stack(outputs).argmax(dim=2).mode(dim=0)[0]
proba = torch.zeros_like(outputs[0])
proba.scatter_(1, votes.view(-1, 1), 1)

return proba

Expand Down Expand Up @@ -167,12 +190,19 @@ def fit(
# Utils
best_acc = 0.0

# Internal helper function on pesudo forward
# Internal helper function on pseudo forward
def _forward(estimators, *x):
outputs = [
F.softmax(estimator(*x), dim=1) for estimator in estimators
]
proba = op.average(outputs)

if self.voting_strategy == "soft":
proba = op.average(outputs)

elif self.voting_strategy == "hard":
votes = torch.stack(outputs).argmax(dim=2).mode(dim=0)[0]
proba = torch.zeros_like(outputs[0])
proba.scatter_(1, votes.view(-1, 1), 1)

return proba

Expand Down Expand Up @@ -409,7 +439,7 @@ def fit(
# Utils
best_loss = float("inf")

# Internal helper function on pesudo forward
# Internal helper function on pseudo forward
def _forward(estimators, *x):
outputs = [estimator(*x) for estimator in estimators]
pred = op.average(outputs)
Expand Down