From 90b965c3a421aac6265f3e8e00c1356b0661b6d0 Mon Sep 17 00:00:00 2001 From: LukasGardberg Date: Tue, 20 Sep 2022 17:55:09 +0200 Subject: [PATCH 1/9] initial attempt to implementing hard majority voting --- torchensemble/_base.py | 8 +++++++ torchensemble/voting.py | 50 +++++++++++++++++++++++++++++++++++++---- 2 files changed, 54 insertions(+), 4 deletions(-) diff --git a/torchensemble/_base.py b/torchensemble/_base.py index 2581155..7030bea 100644 --- a/torchensemble/_base.py +++ b/torchensemble/_base.py @@ -280,8 +280,16 @@ def evaluate(self, test_loader, return_loss=False): for _, elem in enumerate(test_loader): data, target = split_data_target(elem, self.device) + + # Get averaged probabilities over all base estimators + # for each sample in batch. + # Output shape: (batch_size, n_classes) output = self.forward(*data) + + # get val, idx for the class w max log-prob for each sample + # predicted shape: (batch_size) _, predicted = torch.max(output.data, 1) + correct += (predicted == target).sum().item() total += target.size(0) loss += self._criterion(output, target) diff --git a/torchensemble/voting.py b/torchensemble/voting.py index c0dea3d..5725344 100644 --- a/torchensemble/voting.py +++ b/torchensemble/voting.py @@ -92,17 +92,52 @@ def _parallel_fit_per_epoch( """Implementation on the VotingClassifier.""", "model" ) class VotingClassifier(BaseClassifier): + + def __init__(self, voting_strategy="soft", **kwargs): + super(VotingClassifier, self).__init__(**kwargs) + + self._implemented_voting_strategies = ["soft", "hard"] + + if voting_strategy not in self._implemented_voting_strategies: + msg = ( + "Voting strategy {} is not implemented, " + "please choose from {}." 
+ ) + raise ValueError( + msg.format( + voting_strategy, self._implemented_voting_strategies + ) + ) + + self.voting_strategy = voting_strategy + @torchensemble_model_doc( """Implementation on the data forwarding in VotingClassifier.""", "classifier_forward", ) def forward(self, *x): # Average over class distributions from all base estimators. + + # output: (n_estimators, batch_size, n_classes) outputs = [ F.softmax(estimator(*x), dim=1) for estimator in self.estimators_ ] - proba = op.average(outputs) + if self.voting_strategy == "soft": proba = op.average(outputs) + + elif self.voting_strategy == "hard": + # Do hard majority voting + # votes: (batch_size) + votes = torch.stack(outputs).argmax(dim=2).mode(dim=0)[0] + # Set the probability of the most voted class to 1 + proba = torch.zeros_like(outputs[0]) + proba.scatter_(1, votes.view(-1, 1), 1) + + # TODO: make sure this is compatible with + # predict and evaluate + + # Returns averaged class probabilities for each sample + # proba shape: (batch_size, n_classes) return proba @torchensemble_model_doc( @@ -167,12 +202,19 @@ def fit( # Utils best_acc = 0.0 - # Internal helper function on pesudo forward + # Internal helper function on psuedo forward def _forward(estimators, *x): outputs = [ F.softmax(estimator(*x), dim=1) for estimator in estimators ] - proba = op.average(outputs) + + if self.voting_strategy == "soft": proba = op.average(outputs) + + # Maybe the voting strategy can be packaged as a function? 
+ elif self.voting_strategy == "hard": + votes = torch.stack(outputs).argmax(dim=2).mode(dim=0)[0] + proba = torch.zeros_like(outputs[0]) + proba.scatter_(1, votes.view(-1, 1), 1) return proba @@ -409,7 +451,7 @@ def fit( # Utils best_loss = float("inf") - # Internal helper function on pesudo forward + # Internal helper function on psuedo forward def _forward(estimators, *x): outputs = [estimator(*x) for estimator in estimators] pred = op.average(outputs) From 13715d4bf3b7eea04a2db2bd5cd84a26503cb2a0 Mon Sep 17 00:00:00 2001 From: LukasGardberg Date: Tue, 20 Sep 2022 19:50:22 +0200 Subject: [PATCH 2/9] fixed linting --- torchensemble/_base.py | 6 +++--- torchensemble/voting.py | 9 +++++---- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/torchensemble/_base.py b/torchensemble/_base.py index 7030bea..09f5a48 100644 --- a/torchensemble/_base.py +++ b/torchensemble/_base.py @@ -357,8 +357,8 @@ def __init__(self, input_dim, output_dim, depth=5, lamda=1e-3, cuda=False): self._validate_parameters() - self.internal_node_num_ = 2 ** self.depth - 1 - self.leaf_node_num_ = 2 ** self.depth + self.internal_node_num_ = 2**self.depth - 1 + self.leaf_node_num_ = 2**self.depth # Different penalty coefficients for nodes in different layers self.penalty_list = [ @@ -428,7 +428,7 @@ def _cal_penalty(self, layer_idx, _mu, _path_prob): penalty = torch.tensor(0.0).to(self.device) batch_size = _mu.size()[0] - _mu = _mu.view(batch_size, 2 ** layer_idx) + _mu = _mu.view(batch_size, 2**layer_idx) _path_prob = _path_prob.view(batch_size, 2 ** (layer_idx + 1)) for node in range(0, 2 ** (layer_idx + 1)): diff --git a/torchensemble/voting.py b/torchensemble/voting.py index 5725344..189080a 100644 --- a/torchensemble/voting.py +++ b/torchensemble/voting.py @@ -92,7 +92,6 @@ def _parallel_fit_per_epoch( """Implementation on the VotingClassifier.""", "model" ) class VotingClassifier(BaseClassifier): - def __init__(self, voting_strategy="soft", **kwargs): super(VotingClassifier, 
self).__init__(**kwargs) @@ -123,7 +122,8 @@ def forward(self, *x): F.softmax(estimator(*x), dim=1) for estimator in self.estimators_ ] - if self.voting_strategy == "soft": proba = op.average(outputs) + if self.voting_strategy == "soft": + proba = op.average(outputs) elif self.voting_strategy == "hard": # Do hard majority voting @@ -137,7 +137,7 @@ def forward(self, *x): # predict and evaluate # Returns averaged class probabilities for each sample - # proba shape: (batch_size, n_classes) + # proba shape: (batch_size, n_classes) return proba @torchensemble_model_doc( @@ -208,7 +208,8 @@ def _forward(estimators, *x): F.softmax(estimator(*x), dim=1) for estimator in estimators ] - if self.voting_strategy == "soft": proba = op.average(outputs) + if self.voting_strategy == "soft": + proba = op.average(outputs) # Maybe the voting strategy can be packaged as a function? elif self.voting_strategy == "hard": From dbd661464f16aed2769b7478c0ea40746949dbac Mon Sep 17 00:00:00 2001 From: LukasGardberg Date: Fri, 23 Sep 2022 15:55:31 +0200 Subject: [PATCH 3/9] removed comments etc --- torchensemble/_base.py | 11 +++-------- torchensemble/voting.py | 23 +++++------------------ 2 files changed, 8 insertions(+), 26 deletions(-) diff --git a/torchensemble/_base.py b/torchensemble/_base.py index 09f5a48..418dd8b 100644 --- a/torchensemble/_base.py +++ b/torchensemble/_base.py @@ -281,13 +281,8 @@ def evaluate(self, test_loader, return_loss=False): for _, elem in enumerate(test_loader): data, target = split_data_target(elem, self.device) - # Get averaged probabilities over all base estimators - # for each sample in batch. 
- # Output shape: (batch_size, n_classes) output = self.forward(*data) - # get val, idx for the class w max log-prob for each sample - # predicted shape: (batch_size) _, predicted = torch.max(output.data, 1) correct += (predicted == target).sum().item() @@ -357,8 +352,8 @@ def __init__(self, input_dim, output_dim, depth=5, lamda=1e-3, cuda=False): self._validate_parameters() - self.internal_node_num_ = 2**self.depth - 1 - self.leaf_node_num_ = 2**self.depth + self.internal_node_num_ = 2 ** self.depth - 1 + self.leaf_node_num_ = 2 ** self.depth # Different penalty coefficients for nodes in different layers self.penalty_list = [ @@ -428,7 +423,7 @@ def _cal_penalty(self, layer_idx, _mu, _path_prob): penalty = torch.tensor(0.0).to(self.device) batch_size = _mu.size()[0] - _mu = _mu.view(batch_size, 2**layer_idx) + _mu = _mu.view(batch_size, 2 ** layer_idx) _path_prob = _path_prob.view(batch_size, 2 ** (layer_idx + 1)) for node in range(0, 2 ** (layer_idx + 1)): diff --git a/torchensemble/voting.py b/torchensemble/voting.py index 189080a..883c675 100644 --- a/torchensemble/voting.py +++ b/torchensemble/voting.py @@ -95,17 +95,14 @@ class VotingClassifier(BaseClassifier): def __init__(self, voting_strategy="soft", **kwargs): super(VotingClassifier, self).__init__(**kwargs) - self._implemented_voting_strategies = ["soft", "hard"] - - if voting_strategy not in self._implemented_voting_strategies: + implemented_strategies = {"soft", "hard"} + if voting_strategy not in implemented_strategies: msg = ( "Voting strategy {} is not implemented, " "please choose from {}." ) raise ValueError( - msg.format( - voting_strategy, self._implemented_voting_strategies - ) + msg.format(voting_strategy, implemented_strategies) ) self.voting_strategy = voting_strategy @@ -117,7 +114,6 @@ def __init__(self, voting_strategy="soft", **kwargs): def forward(self, *x): # Average over class distributions from all base estimators. 
- # output: (n_estimators, batch_size, n_classes) outputs = [ F.softmax(estimator(*x), dim=1) for estimator in self.estimators_ ] @@ -126,18 +122,10 @@ def forward(self, *x): proba = op.average(outputs) elif self.voting_strategy == "hard": - # Do hard majority voting - # votes: (batch_size) votes = torch.stack(outputs).argmax(dim=2).mode(dim=0)[0] - # Set the probability of the most voted class to 1 proba = torch.zeros_like(outputs[0]) proba.scatter_(1, votes.view(-1, 1), 1) - # TODO: make sure this is compatible with - # predict and evaluate - - # Returns averaged class probabilities for each sample - # proba shape: (batch_size, n_classes) return proba @torchensemble_model_doc( @@ -202,7 +190,7 @@ def fit( # Utils best_acc = 0.0 - # Internal helper function on psuedo forward + # Internal helper function on pseudo forward def _forward(estimators, *x): outputs = [ F.softmax(estimator(*x), dim=1) for estimator in estimators @@ -211,7 +199,6 @@ def _forward(estimators, *x): if self.voting_strategy == "soft": proba = op.average(outputs) - # Maybe the voting strategy can be packaged as a function? 
elif self.voting_strategy == "hard": votes = torch.stack(outputs).argmax(dim=2).mode(dim=0)[0] proba = torch.zeros_like(outputs[0]) @@ -452,7 +439,7 @@ def fit( # Utils best_loss = float("inf") - # Internal helper function on psuedo forward + # Internal helper function on pseudo forward def _forward(estimators, *x): outputs = [estimator(*x) for estimator in estimators] pred = op.average(outputs) From 2d075c84858d252db975ac7d2ba30fc887c3e86d Mon Sep 17 00:00:00 2001 From: LukasGardberg Date: Wed, 28 Sep 2022 21:14:39 +0200 Subject: [PATCH 4/9] moved hard voting to utils, added test --- torchensemble/_base.py | 2 +- torchensemble/snapshot_ensemble.py | 19 +++++++++++++++++-- torchensemble/tests/test_operator.py | 11 +++++++++++ torchensemble/utils/operator.py | 17 +++++++++++++++++ torchensemble/voting.py | 16 +++++++++------- 5 files changed, 55 insertions(+), 10 deletions(-) diff --git a/torchensemble/_base.py b/torchensemble/_base.py index 418dd8b..c94b3cb 100644 --- a/torchensemble/_base.py +++ b/torchensemble/_base.py @@ -203,7 +203,7 @@ def predict(self, *x): class BaseTreeEnsemble(BaseModule): def __init__( self, - n_estimators, + n_estimators=10, depth=5, lamda=1e-3, cuda=False, diff --git a/torchensemble/snapshot_ensemble.py b/torchensemble/snapshot_ensemble.py index 34bb2f8..7b6a44d 100644 --- a/torchensemble/snapshot_ensemble.py +++ b/torchensemble/snapshot_ensemble.py @@ -209,14 +209,29 @@ def set_scheduler(self, scheduler_name, **kwargs): """Implementation on the SnapshotEnsembleClassifier.""", "seq_model" ) class SnapshotEnsembleClassifier(_BaseSnapshotEnsemble, BaseClassifier): + + def __init__(self, voting_strategy="soft", **kwargs): + super().__init__(**kwargs) + + self.voting_strategy = voting_strategy + @torchensemble_model_doc( """Implementation on the data forwarding in SnapshotEnsembleClassifier.""", # noqa: E501 "classifier_forward", ) def forward(self, *x): - proba = self._forward(*x) - return F.softmax(proba, dim=1) + outputs = [ + 
F.softmax(estimator(*x), dim=1) for estimator in self.estimators_ + ] + + if self.voting_strategy == "soft": + proba = op.average(outputs) + + elif self.voting_strategy == "hard": + proba = op.majority_vote(outputs) + + return proba @torchensemble_model_doc( """Set the attributes on optimizer for SnapshotEnsembleClassifier.""", diff --git a/torchensemble/tests/test_operator.py b/torchensemble/tests/test_operator.py index c2008f6..6fdaa2a 100644 --- a/torchensemble/tests/test_operator.py +++ b/torchensemble/tests/test_operator.py @@ -43,3 +43,14 @@ def test_residual_regression_invalid_shape(): label.view(-1, 1), # 4 * 1 ) assert "should be the same as output" in str(excinfo.value) + + +def test_majority_voting(): + outputs = [ + torch.FloatTensor(np.array(([0.9, 0.1], [0.2, 0.8]))), + torch.FloatTensor(np.array(([0.7, 0.3], [0.1, 0.9]))), + torch.FloatTensor(np.array(([0.1, 0.9], [0.8, 0.2]))), + ] + actual = op.majority_vote(outputs).numpy() + expected = np.array(([1, 0], [0, 1])) + assert_array_almost_equal(actual, expected) diff --git a/torchensemble/utils/operator.py b/torchensemble/utils/operator.py index c305960..1547f71 100644 --- a/torchensemble/utils/operator.py +++ b/torchensemble/utils/operator.py @@ -3,6 +3,7 @@ import torch import torch.nn.functional as F +from typing import List __all__ = [ @@ -11,6 +12,7 @@ "onehot_encoding", "pseudo_residual_classification", "pseudo_residual_regression", + "majority_vote", ] @@ -51,3 +53,18 @@ def pseudo_residual_regression(target, output): raise ValueError(msg.format(target.size(), output.size())) return target - output + + +def majority_vote(outputs: List[torch.Tensor]) -> torch.Tensor: + """Compute the majority vote for a list of model outputs. + outputs: list of length (n_models) of tensors with shape (n_samples, n_classes) + majority_one_hots: (n_samples, n_classes) + """ + + assert len(outputs[0].shape) == 2, "The shape of outputs should be (n_models, n_samples, n_classes)." 
+ + votes = torch.stack(outputs).argmax(dim=2).mode(dim=0)[0] + proba = torch.zeros_like(outputs[0]) + majority_one_hots = proba.scatter_(1, votes.view(-1, 1), 1) + + return majority_one_hots diff --git a/torchensemble/voting.py b/torchensemble/voting.py index 883c675..aef9380 100644 --- a/torchensemble/voting.py +++ b/torchensemble/voting.py @@ -92,6 +92,7 @@ def _parallel_fit_per_epoch( """Implementation on the VotingClassifier.""", "model" ) class VotingClassifier(BaseClassifier): + def __init__(self, voting_strategy="soft", **kwargs): super(VotingClassifier, self).__init__(**kwargs) @@ -112,7 +113,6 @@ def __init__(self, voting_strategy="soft", **kwargs): "classifier_forward", ) def forward(self, *x): - # Average over class distributions from all base estimators. outputs = [ F.softmax(estimator(*x), dim=1) for estimator in self.estimators_ @@ -122,9 +122,7 @@ def forward(self, *x): proba = op.average(outputs) elif self.voting_strategy == "hard": - votes = torch.stack(outputs).argmax(dim=2).mode(dim=0)[0] - proba = torch.zeros_like(outputs[0]) - proba.scatter_(1, votes.view(-1, 1), 1) + proba = op.majority_vote(outputs) return proba @@ -200,9 +198,7 @@ def _forward(estimators, *x): proba = op.average(outputs) elif self.voting_strategy == "hard": - votes = torch.stack(outputs).argmax(dim=2).mode(dim=0)[0] - proba = torch.zeros_like(outputs[0]) - proba.scatter_(1, votes.view(-1, 1), 1) + proba = op.majority_vote(outputs) return proba @@ -306,6 +302,12 @@ def predict(self, *x): """Implementation on the NeuralForestClassifier.""", "tree_ensmeble_model" ) class NeuralForestClassifier(BaseTreeEnsemble, VotingClassifier): + + def __init__(self, voting_strategy="soft", **kwargs): + super().__init__(**kwargs) + + self.voting_strategy = voting_strategy + @torchensemble_model_doc( """Implementation on the data forwarding in NeuralForestClassifier.""", "classifier_forward", From b33076b9a086269e72778d88280f363d977fcb6a Mon Sep 17 00:00:00 2001 From: LukasGardberg Date: Wed, 
28 Sep 2022 21:18:45 +0200 Subject: [PATCH 5/9] fixed linting --- torchensemble/snapshot_ensemble.py | 1 - torchensemble/tests/test_operator.py | 2 +- torchensemble/utils/operator.py | 4 +++- torchensemble/voting.py | 2 -- 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/torchensemble/snapshot_ensemble.py b/torchensemble/snapshot_ensemble.py index 7b6a44d..37f12d8 100644 --- a/torchensemble/snapshot_ensemble.py +++ b/torchensemble/snapshot_ensemble.py @@ -209,7 +209,6 @@ def set_scheduler(self, scheduler_name, **kwargs): """Implementation on the SnapshotEnsembleClassifier.""", "seq_model" ) class SnapshotEnsembleClassifier(_BaseSnapshotEnsemble, BaseClassifier): - def __init__(self, voting_strategy="soft", **kwargs): super().__init__(**kwargs) diff --git a/torchensemble/tests/test_operator.py b/torchensemble/tests/test_operator.py index 6fdaa2a..5656031 100644 --- a/torchensemble/tests/test_operator.py +++ b/torchensemble/tests/test_operator.py @@ -51,6 +51,6 @@ def test_majority_voting(): torch.FloatTensor(np.array(([0.7, 0.3], [0.1, 0.9]))), torch.FloatTensor(np.array(([0.1, 0.9], [0.8, 0.2]))), ] - actual = op.majority_vote(outputs).numpy() + actual = op.majority_vote(outputs).numpy() expected = np.array(([1, 0], [0, 1])) assert_array_almost_equal(actual, expected) diff --git a/torchensemble/utils/operator.py b/torchensemble/utils/operator.py index 1547f71..556cc1f 100644 --- a/torchensemble/utils/operator.py +++ b/torchensemble/utils/operator.py @@ -61,7 +61,9 @@ def majority_vote(outputs: List[torch.Tensor]) -> torch.Tensor: majority_one_hots: (n_samples, n_classes) """ - assert len(outputs[0].shape) == 2, "The shape of outputs should be (n_models, n_samples, n_classes)." + assert ( + len(outputs[0].shape) == 2 + ), "The shape of outputs should be (n_models, n_samples, n_classes)." 
votes = torch.stack(outputs).argmax(dim=2).mode(dim=0)[0] proba = torch.zeros_like(outputs[0]) diff --git a/torchensemble/voting.py b/torchensemble/voting.py index aef9380..33d5255 100644 --- a/torchensemble/voting.py +++ b/torchensemble/voting.py @@ -92,7 +92,6 @@ def _parallel_fit_per_epoch( """Implementation on the VotingClassifier.""", "model" ) class VotingClassifier(BaseClassifier): - def __init__(self, voting_strategy="soft", **kwargs): super(VotingClassifier, self).__init__(**kwargs) @@ -302,7 +301,6 @@ def predict(self, *x): """Implementation on the NeuralForestClassifier.""", "tree_ensmeble_model" ) class NeuralForestClassifier(BaseTreeEnsemble, VotingClassifier): - def __init__(self, voting_strategy="soft", **kwargs): super().__init__(**kwargs) From 3d095b3303c347692389cb8d22f82c973af1f5f6 Mon Sep 17 00:00:00 2001 From: LukasGardberg Date: Wed, 28 Sep 2022 21:23:09 +0200 Subject: [PATCH 6/9] fixed linting 2 --- torchensemble/utils/operator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchensemble/utils/operator.py b/torchensemble/utils/operator.py index 556cc1f..29851c3 100644 --- a/torchensemble/utils/operator.py +++ b/torchensemble/utils/operator.py @@ -57,7 +57,8 @@ def pseudo_residual_regression(target, output): def majority_vote(outputs: List[torch.Tensor]) -> torch.Tensor: """Compute the majority vote for a list of model outputs. 
- outputs: list of length (n_models) of tensors with shape (n_samples, n_classes) + outputs: list of length (n_models) + containing tensors with shape (n_samples, n_classes) majority_one_hots: (n_samples, n_classes) """ From b7757664663b4f487cbe34ff959d90add62fa161 Mon Sep 17 00:00:00 2001 From: xuyxu Date: Sun, 2 Oct 2022 19:28:09 +0800 Subject: [PATCH 7/9] add changelog --- CHANGELOG.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 34da1d0..c2bcc7d 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -18,6 +18,7 @@ Changelog Ver 0.1.* --------- +* |Feature| |API| Add ``voting_strategy`` parameter for :class:`VotingClassifier`, :class:`NeuralForestClassifier`, and :class:`SnapshotEnsembleClassifier` | `@LukasGardberg `__ * |Fix| Fix the sampling issue in :class:`BaggingClassifier` and :class:`BaggingRegressor` | `@SunHaozhe `__ * |Feature| |API| Add :class:`NeuralForestClassifier` and :class:`NeuralForestRegressor` | `@xuyxu `__ * |Fix| Relax check on input dataloader | `@xuyxu `__ From 7dad1d7b614e7dcbdf82ddd54a9ec14025ce9836 Mon Sep 17 00:00:00 2001 From: LukasGardberg Date: Wed, 5 Oct 2022 21:29:51 +0200 Subject: [PATCH 8/9] switched to value error --- torchensemble/utils/operator.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchensemble/utils/operator.py b/torchensemble/utils/operator.py index 29851c3..7e7b702 100644 --- a/torchensemble/utils/operator.py +++ b/torchensemble/utils/operator.py @@ -62,9 +62,11 @@ def majority_vote(outputs: List[torch.Tensor]) -> torch.Tensor: majority_one_hots: (n_samples, n_classes) """ - assert ( - len(outputs[0].shape) == 2 - ), "The shape of outputs should be (n_models, n_samples, n_classes)." + if len(outputs[0].shape) != 2: + msg = """The shape of outputs should be a list tensors of + length (n_models) with sizes (n_samples, n_classes). 
+ The first tensor had shape {} """ + raise ValueError(msg.format(outputs[0].shape)) votes = torch.stack(outputs).argmax(dim=2).mode(dim=0)[0] proba = torch.zeros_like(outputs[0]) From 8b5bf439bec408ca0a5f3b1fdf967dd6a7c26fc0 Mon Sep 17 00:00:00 2001 From: xuyxu Date: Thu, 6 Oct 2022 12:05:53 +0800 Subject: [PATCH 9/9] flake8 formating --- torchensemble/utils/operator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchensemble/utils/operator.py b/torchensemble/utils/operator.py index 7e7b702..e5c6738 100644 --- a/torchensemble/utils/operator.py +++ b/torchensemble/utils/operator.py @@ -63,8 +63,8 @@ def majority_vote(outputs: List[torch.Tensor]) -> torch.Tensor: """ if len(outputs[0].shape) != 2: - msg = """The shape of outputs should be a list tensors of - length (n_models) with sizes (n_samples, n_classes). + msg = """The shape of outputs should be a list tensors of + length (n_models) with sizes (n_samples, n_classes). The first tensor had shape {} """ raise ValueError(msg.format(outputs[0].shape))