diff --git a/CHANGELOG.rst b/CHANGELOG.rst index c2bcc7d..51f055b 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -18,6 +18,7 @@ Changelog Ver 0.1.* --------- +* |Feature| Add internal :meth:`unsqueeze` operation in :meth:`forward` of all classifiers | `@xuyxu `__ * |Feature| |API| Add ``voting_strategy`` parameter for :class:`VotingClassifer`, :class:`NeuralForestClassifier`, and :class:`SnapshotEnsembleClassifier` | `@LukasGardberg `__ * |Fix| Fix the sampling issue in :class:`BaggingClassifier` and :class:`BaggingRegressor` | `@SunHaozhe `__ * |Feature| |API| Add :class:`NeuralForestClassifier` and :class:`NeuralForestRegressor` | `@xuyxu `__ diff --git a/torchensemble/_base.py b/torchensemble/_base.py index c94b3cb..b101311 100644 --- a/torchensemble/_base.py +++ b/torchensemble/_base.py @@ -29,7 +29,7 @@ def get_doc(item): __doc = { "model": const.__model_doc, "seq_model": const.__seq_model_doc, - "tree_ensmeble_model": const.__tree_ensemble_doc, + "tree_ensemble_model": const.__tree_ensemble_doc, "fit": const.__fit_doc, "predict": const.__predict_doc, "set_optimizer": const.__set_optimizer_doc, diff --git a/torchensemble/adversarial_training.py b/torchensemble/adversarial_training.py index 8f526b6..9e0d318 100644 --- a/torchensemble/adversarial_training.py +++ b/torchensemble/adversarial_training.py @@ -226,7 +226,8 @@ class AdversarialTrainingClassifier(_BaseAdversarialTraining, BaseClassifier): def forward(self, *x): # Take the average over class distributions from all base estimators. outputs = [ - F.softmax(estimator(*x), dim=1) for estimator in self.estimators_ + F.softmax(op.unsqueeze_tensor(estimator(*x)), dim=1) + for estimator in self.estimators_ ] proba = op.average(outputs) diff --git a/torchensemble/bagging.py b/torchensemble/bagging.py index a7ad91d..623b7c3 100644 --- a/torchensemble/bagging.py +++ b/torchensemble/bagging.py @@ -94,7 +94,8 @@ class BaggingClassifier(BaseClassifier): def forward(self, *x): # Average over class distributions from all base estimators. outputs = [ - F.softmax(estimator(*x), dim=1) for estimator in self.estimators_ + F.softmax(op.unsqueeze_tensor(estimator(*x)), dim=1) + for estimator in self.estimators_ ] proba = op.average(outputs) diff --git a/torchensemble/fast_geometric.py b/torchensemble/fast_geometric.py index fa6e8d5..bea4588 100644 --- a/torchensemble/fast_geometric.py +++ b/torchensemble/fast_geometric.py @@ -176,7 +176,7 @@ class FastGeometricClassifier(_BaseFastGeometric, BaseClassifier): "classifier_forward", ) def forward(self, *x): - proba = self._forward(*x) + proba = op.unsqueeze_tensor(self._forward(*x)) return F.softmax(proba, dim=1) diff --git a/torchensemble/fusion.py b/torchensemble/fusion.py index 08ff04b..f3c5126 100644 --- a/torchensemble/fusion.py +++ b/torchensemble/fusion.py @@ -39,7 +39,7 @@ def _forward(self, *x): "classifier_forward", ) def forward(self, *x): - output = self._forward(*x) + output = op.unsqueeze_tensor(self._forward(*x)) proba = F.softmax(output, dim=1) return proba diff --git a/torchensemble/gradient_boosting.py b/torchensemble/gradient_boosting.py index 7b967b9..55fb95a 100644 --- a/torchensemble/gradient_boosting.py +++ b/torchensemble/gradient_boosting.py @@ -420,7 +420,10 @@ def fit( "classifier_forward", ) def forward(self, *x): - output = [estimator(*x) for estimator in self.estimators_] + output = [ + op.unsqueeze_tensor(estimator(*x)) + for estimator in self.estimators_ + ] output = op.sum_with_multiplicative(output, self.shrinkage_rate) proba = F.softmax(output, dim=1) diff --git a/torchensemble/snapshot_ensemble.py b/torchensemble/snapshot_ensemble.py index 37f12d8..bbcc4a2 100644 --- a/torchensemble/snapshot_ensemble.py +++ b/torchensemble/snapshot_ensemble.py @@ -212,6 +212,16 @@ class SnapshotEnsembleClassifier(_BaseSnapshotEnsemble, BaseClassifier): def __init__(self, voting_strategy="soft", **kwargs): super().__init__(**kwargs) + implemented_strategies = {"soft", "hard"} + if voting_strategy not in implemented_strategies: + msg = ( + "Voting strategy {} is not implemented, " + "please choose from {}." + ) + raise ValueError( + msg.format(voting_strategy, implemented_strategies) + ) + self.voting_strategy = voting_strategy @torchensemble_model_doc( @@ -221,13 +231,13 @@ def __init__(self, voting_strategy="soft", **kwargs): def forward(self, *x): outputs = [ - F.softmax(estimator(*x), dim=1) for estimator in self.estimators_ + F.softmax(op.unsqueeze_tensor(estimator(*x)), dim=1) + for estimator in self.estimators_ ] if self.voting_strategy == "soft": proba = op.average(outputs) - - elif self.voting_strategy == "hard": + else: proba = op.majority_vote(outputs) return proba diff --git a/torchensemble/soft_gradient_boosting.py b/torchensemble/soft_gradient_boosting.py index 3f89a9e..8ffe5d9 100644 --- a/torchensemble/soft_gradient_boosting.py +++ b/torchensemble/soft_gradient_boosting.py @@ -406,7 +406,10 @@ def fit( "classifier_forward", ) def forward(self, *x): - output = [estimator(*x) for estimator in self.estimators_] + output = [ + op.unsqueeze_tensor(estimator(*x)) + for estimator in self.estimators_ + ] output = op.sum_with_multiplicative(output, self.shrinkage_rate) proba = F.softmax(output, dim=1) diff --git a/torchensemble/tests/test_all_models.py b/torchensemble/tests/test_all_models.py index 00d1b78..86e9c0d 100644 --- a/torchensemble/tests/test_all_models.py +++ b/torchensemble/tests/test_all_models.py @@ -49,7 +49,7 @@ def __init__(self): self.linear2 = nn.Linear(2, 2) def forward(self, X): - X = X.view(X.size()[0], -1) + X = X.view(X.size(0), -1) output = self.linear1(X) output = self.linear2(output) return output diff --git a/torchensemble/utils/operator.py b/torchensemble/utils/operator.py index e5c6738..80a6512 100644 --- a/torchensemble/utils/operator.py +++ b/torchensemble/utils/operator.py @@ -13,6 +13,7 @@ "pseudo_residual_classification", "pseudo_residual_regression", "majority_vote", + "unsqueeze_tensor", ] @@ -73,3 +74,11 @@ def majority_vote(outputs: List[torch.Tensor]) -> torch.Tensor: majority_one_hots = proba.scatter_(1, votes.view(-1, 1), 1) return majority_one_hots + + +def unsqueeze_tensor(tensor: torch.Tensor, dim=1) -> torch.Tensor: + """Reshape 1-D tensor to 2-D for downstream operations.""" + if tensor.ndim == 1: + tensor = torch.unsqueeze(tensor, dim) + + return tensor diff --git a/torchensemble/voting.py b/torchensemble/voting.py index 54ed023..33da787 100644 --- a/torchensemble/voting.py +++ b/torchensemble/voting.py @@ -114,13 +114,13 @@ def __init__(self, voting_strategy="soft", **kwargs): def forward(self, *x): outputs = [ - F.softmax(estimator(*x), dim=1) for estimator in self.estimators_ + F.softmax(op.unsqueeze_tensor(estimator(*x)), dim=1) + for estimator in self.estimators_ ] if self.voting_strategy == "soft": proba = op.average(outputs) - - elif self.voting_strategy == "hard": + else: proba = op.majority_vote(outputs) return proba @@ -309,7 +309,7 @@ def predict(self, *x): @torchensemble_model_doc( - """Implementation on the NeuralForestClassifier.""", "tree_ensmeble_model" + """Implementation on the NeuralForestClassifier.""", "tree_ensemble_model" ) class NeuralForestClassifier(BaseTreeEnsemble, VotingClassifier): def __init__(self, voting_strategy="soft", **kwargs): @@ -324,7 +324,8 @@ def __init__(self, voting_strategy="soft", **kwargs): def forward(self, *x): # Average over class distributions from all base estimators. outputs = [ - F.softmax(estimator(*x), dim=1) for estimator in self.estimators_ + F.softmax(op.unsqueeze_tensor(estimator(*x)), dim=1) + for estimator in self.estimators_ ] proba = op.average(outputs) @@ -561,7 +562,7 @@ def predict(self, *x): @torchensemble_model_doc( - """Implementation on the NeuralForestRegressor.""", "tree_ensmeble_model" + """Implementation on the NeuralForestRegressor.""", "tree_ensemble_model" ) class NeuralForestRegressor(BaseTreeEnsemble, VotingRegressor): @torchensemble_model_doc(