Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

initial attempt at implementing hard majority voting #126

Merged
merged 9 commits into from
Oct 6, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions torchensemble/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,16 @@ def evaluate(self, test_loader, return_loss=False):

for _, elem in enumerate(test_loader):
data, target = split_data_target(elem, self.device)

gardberg marked this conversation as resolved.
Show resolved Hide resolved
# Get averaged probabilities over all base estimators
# for each sample in batch.
# Output shape: (batch_size, n_classes)
output = self.forward(*data)

# get val, idx for the class w max log-prob for each sample
# predicted shape: (batch_size)
_, predicted = torch.max(output.data, 1)

correct += (predicted == target).sum().item()
total += target.size(0)
loss += self._criterion(output, target)
Expand Down Expand Up @@ -349,8 +357,8 @@ def __init__(self, input_dim, output_dim, depth=5, lamda=1e-3, cuda=False):

self._validate_parameters()

self.internal_node_num_ = 2 ** self.depth - 1
self.leaf_node_num_ = 2 ** self.depth
self.internal_node_num_ = 2**self.depth - 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Several lines here do not meet the coding style requirements. Please consider to execute the command black --skip-string-normalization --config pyproject.toml torchensemble/ in the root directory of the project, which will automatically format the code

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Weird, I ran the same black command and it automatically did the changes mentioned above. I reverted them manually now. Maybe it has something to do with my version of black.

self.leaf_node_num_ = 2**self.depth

# Different penalty coefficients for nodes in different layers
self.penalty_list = [
Expand Down Expand Up @@ -420,7 +428,7 @@ def _cal_penalty(self, layer_idx, _mu, _path_prob):
penalty = torch.tensor(0.0).to(self.device)

batch_size = _mu.size()[0]
_mu = _mu.view(batch_size, 2 ** layer_idx)
_mu = _mu.view(batch_size, 2**layer_idx)
_path_prob = _path_prob.view(batch_size, 2 ** (layer_idx + 1))

for node in range(0, 2 ** (layer_idx + 1)):
Expand Down
51 changes: 47 additions & 4 deletions torchensemble/voting.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,17 +92,52 @@ def _parallel_fit_per_epoch(
"""Implementation on the VotingClassifier.""", "model"
)
class VotingClassifier(BaseClassifier):
def __init__(self, voting_strategy="soft", **kwargs):
    """Voting classifier with a configurable voting strategy.

    Parameters
    ----------
    voting_strategy : str, default="soft"
        How predictions from the base estimators are combined:

        - ``"soft"``: average the predicted class probabilities;
        - ``"hard"``: majority vote over each estimator's argmax class.
    **kwargs
        Forwarded unchanged to ``BaseClassifier.__init__``.

    Raises
    ------
    ValueError
        If ``voting_strategy`` is not one of the implemented strategies.
    """
    super(VotingClassifier, self).__init__(**kwargs)

    # Kept as an instance attribute (rather than a module constant) so
    # existing code reading it keeps working.
    self._implemented_voting_strategies = ["soft", "hard"]

    if voting_strategy not in self._implemented_voting_strategies:
        msg = (
            "Voting strategy {} is not implemented, "
            "please choose from {}."
        )
        raise ValueError(
            msg.format(
                voting_strategy, self._implemented_voting_strategies
            )
        )

    self.voting_strategy = voting_strategy

@torchensemble_model_doc(
"""Implementation on the data forwarding in VotingClassifier.""",
"classifier_forward",
)
def forward(self, *x):
    """Combine the base estimators' predictions into class probabilities.

    Parameters
    ----------
    *x
        Input batch, forwarded unchanged to every base estimator.

    Returns
    -------
    proba : torch.Tensor of shape (batch_size, n_classes)
        With ``"soft"`` voting, the averaged class distribution; with
        ``"hard"`` voting, a one-hot distribution putting probability 1
        on the per-sample majority class.
    """
    # Per-estimator class distributions, each (batch_size, n_classes).
    outputs = [
        F.softmax(estimator(*x), dim=1) for estimator in self.estimators_
    ]

    if self.voting_strategy == "soft":
        # Average the class distributions over all base estimators.
        proba = op.average(outputs)
    else:
        # Hard majority voting ("hard" is the only other strategy
        # accepted by __init__).  Each estimator votes for its argmax
        # class; votes shape: (batch_size,).
        votes = torch.stack(outputs).argmax(dim=2).mode(dim=0)[0]
        # One-hot encode the winning class so the output shape matches
        # the "soft" strategy and stays compatible with predict/evaluate.
        proba = torch.zeros_like(outputs[0])
        proba.scatter_(1, votes.view(-1, 1), 1)

    return proba

@torchensemble_model_doc(
Expand Down Expand Up @@ -167,12 +202,20 @@ def fit(
# Utils
best_acc = 0.0

# Internal helper function on pesudo forward
# Internal helper function on pseudo forward
def _forward(estimators, *x):
    """Pseudo forward pass over ``estimators``, mirroring ``forward``.

    Used during fitting to evaluate the current snapshot of base
    estimators; closes over ``self`` for the voting strategy.
    """
    outputs = [
        F.softmax(estimator(*x), dim=1) for estimator in estimators
    ]

    if self.voting_strategy == "soft":
        # Average the class distributions over all base estimators.
        proba = op.average(outputs)
    else:
        # Hard majority voting, identical to VotingClassifier.forward
        # ("hard" is the only other strategy accepted by __init__).
        votes = torch.stack(outputs).argmax(dim=2).mode(dim=0)[0]
        proba = torch.zeros_like(outputs[0])
        proba.scatter_(1, votes.view(-1, 1), 1)

    return proba

Expand Down Expand Up @@ -409,7 +452,7 @@ def fit(
# Utils
best_loss = float("inf")

# Internal helper function on pesudo forward
# Internal helper function on pseudo forward
def _forward(estimators, *x):
outputs = [estimator(*x) for estimator in estimators]
pred = op.average(outputs)
Expand Down