feat: support dataloader with multiple input (#76)
* Update CHANGELOG.rst

* update temporary code

* fix error

* fix typo

* add unit tests

* Update _constants.py

* update doc

* update doc

* Update index.rst

* Update test_all_models_multi_input.py

* update unit tests

* Update test_all_models_multi_input.py

* update
xuyxu authored Jun 8, 2021
1 parent 30e2f4e commit 50de90f
Showing 16 changed files with 580 additions and 782 deletions.
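In short, batches yielded by the dataloader may now contain several input tensors, with the target as the last element, and ``forward``/``predict`` accept one argument per input. Below is a minimal sketch of the new usage with the AdversarialTrainingClassifier touched in this diff; the toy base estimator, shapes, and hyperparameters are assumptions for illustration, not part of the commit.

import torch
from torch.utils.data import DataLoader, TensorDataset
from torchensemble import AdversarialTrainingClassifier

class TwoInputMLP(torch.nn.Module):
    """Toy base estimator whose forward takes two input tensors."""

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 16)
        self.fc2 = torch.nn.Linear(5, 16)
        self.out = torch.nn.Linear(32, 2)

    def forward(self, x1, x2):
        h1 = torch.relu(self.fc1(x1))
        h2 = torch.relu(self.fc2(x2))
        return self.out(torch.cat([h1, h2], dim=1))

# Every element of a batch except the last is treated as input;
# the last element is the target.
X1, X2 = torch.rand(100, 10), torch.rand(100, 5)  # kept in [0, 1) for FGSM
y = torch.randint(0, 2, (100,))
loader = DataLoader(TensorDataset(X1, X2, y), batch_size=32)

model = AdversarialTrainingClassifier(
    estimator=TwoInputMLP, n_estimators=3, cuda=False
)
model.set_optimizer("Adam", lr=1e-3)
model.fit(loader, epochs=1, save_model=False)
pred = model.predict(X1, X2)  # ``predict`` also takes multiple inputs now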
2 changes: 1 addition & 1 deletion CHANGELOG.rst
@@ -18,7 +18,7 @@ Changelog
Ver 0.1.*
---------

* |Feature| |API| Add :class:`SoftGradientBoostingClassifier` and :class:`SoftGradientBoostingRegressor` | `@xuyxu <https://github.com/xuyxu>`__
* |Feature| |API| Support using dataloader with multiple input | `@xuyxu <https://github.com/xuyxu>`__
* |Fix| Fix missing functionality of ``use_reduction_sum`` for :meth:`fit` of Gradient Boosting | `@xuyxu <https://github.com/xuyxu>`__
* |Enhancement| Relax :mod:`tensorboard` as a soft dependency | `@xuyxu <https://github.com/xuyxu>`__
* |Enhancement| |API| Simplify the training workflow of :class:`FastGeometricClassifier` and :class:`FastGeometricRegressor` | `@xuyxu <https://github.com/xuyxu>`__
4 changes: 0 additions & 4 deletions torchensemble/__init__.py
@@ -12,8 +12,6 @@
from .adversarial_training import AdversarialTrainingRegressor
from .fast_geometric import FastGeometricClassifier
from .fast_geometric import FastGeometricRegressor
from .soft_gradient_boosting import SoftGradientBoostingClassifier
from .soft_gradient_boosting import SoftGradientBoostingRegressor


__all__ = [
@@ -31,6 +29,4 @@
"AdversarialTrainingRegressor",
"FastGeometricClassifier",
"FastGeometricRegressor",
"SoftGradientBoostingClassifier",
"SoftGradientBoostingRegressor",
]
60 changes: 31 additions & 29 deletions torchensemble/_base.py
@@ -7,6 +7,7 @@
import torch.nn as nn

from . import _constants as const
from .utils.io import split_data_target
from .utils.logging import get_tb_logger


@@ -148,7 +149,7 @@ def set_scheduler(self, scheduler_name, **kwargs):
self.use_scheduler_ = True

@abc.abstractmethod
def forward(self, x):
def forward(self, *x):
"""
Implementation on the data forwarding in the ensemble. Notice
that the input ``x`` should be a data batch instead of a standalone
@@ -170,27 +171,26 @@ def fit(
"""

@torch.no_grad()
def predict(self, X, return_numpy=True):
"""Docstrings decorated by downstream models."""
def predict(self, *x):
"""Docstrings decorated by downstream ensembles."""
self.eval()
pred = None

if isinstance(X, torch.Tensor):
pred = self.forward(X.to(self.device))
elif isinstance(X, np.ndarray):
X = torch.Tensor(X).to(self.device)
pred = self.forward(X)
else:
msg = (
"The type of input X should be one of {{torch.Tensor,"
" np.ndarray}}."
)
raise ValueError(msg)
# Copy data
x_device = []
for data in x:
if isinstance(data, torch.Tensor):
x_device.append(data.to(self.device))
elif isinstance(data, np.ndarray):
x_device.append(torch.Tensor(data).to(self.device))
else:
msg = (
"The type of input X should be one of {{torch.Tensor,"
" np.ndarray}}."
)
raise ValueError(msg)

pred = self.forward(*x_device)
pred = pred.cpu()
if return_numpy:
return pred.numpy()

return pred
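So ``predict`` now takes one argument per model input, each either a tensor or a numpy array, moves each to the model's device, and always returns a CPU tensor. A hypothetical call, reusing ``model`` from the sketch above (shapes assumed):

import numpy as np
pred = model.predict(np.random.rand(8, 10), torch.rand(8, 5))  # mixed types are fine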


@@ -212,7 +212,8 @@ def _decide_n_outputs(self, train_loader):
# Infer `n_outputs` from the dataloader
else:
labels = []
for _, (_, target) in enumerate(train_loader):
for _, elem in enumerate(train_loader):
_, target = split_data_target(elem, self.device)
labels.append(target)
labels = torch.unique(torch.cat(labels))
n_outputs = labels.size(0)
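``split_data_target`` is the helper this commit routes every batch through. The real implementation lives in ``torchensemble/utils/io.py``; a minimal sketch of its assumed behavior:

def split_data_target(element, device):
    """Split a dataloader batch into a list of inputs and one target."""
    if isinstance(element, (list, tuple)) and len(element) >= 2:
        *data, target = element  # everything but the last element is input
        data = [tensor.to(device) for tensor in data]
        return data, target.to(device)
    msg = "A batch should contain at least one input tensor and a target."
    raise ValueError(msg)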
@@ -228,9 +229,9 @@ def evaluate(self, test_loader, return_loss=False):
criterion = nn.CrossEntropyLoss()
loss = 0.0

for _, (data, target) in enumerate(test_loader):
data, target = data.to(self.device), target.to(self.device)
output = self.forward(data)
for _, elem in enumerate(test_loader):
data, target = split_data_target(elem, self.device)
output = self.forward(*data)
_, predicted = torch.max(output.data, 1)
correct += (predicted == target).sum().item()
total += target.size(0)
@@ -258,25 +259,26 @@ def _decide_n_outputs(self, train_loader):
The number of outputs equals the number of target variables for
regressors (e.g., `1` in univariate regression).
"""
for _, (_, target) in enumerate(train_loader):
for _, elem in enumerate(train_loader):
_, target = split_data_target(elem, self.device)
if len(target.size()) == 1:
n_outputs = 1
n_outputs = 1 # univariate regression
else:
n_outputs = target.size(1)
n_outputs = target.size(1) # multivariate regression
break

return n_outputs

@torch.no_grad()
def evaluate(self, test_loader):
"""Docstrings decorated by downstream models."""
"""Docstrings decorated by downstream ensembles."""
self.eval()
mse = 0.0
criterion = nn.MSELoss()

for _, (data, target) in enumerate(test_loader):
data, target = data.to(self.device), target.to(self.device)
output = self.forward(data)
for _, elem in enumerate(test_loader):
data, target = split_data_target(elem, self.device)
output = self.forward(*data)
mse += criterion(output, target)

return float(mse) / len(test_loader)
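As the inline comments above note, a regressor infers ``n_outputs`` from the target's shape. For illustration (shapes assumed):

import torch
t1 = torch.zeros(32)     # len(t1.size()) == 1 -> n_outputs = 1 (univariate)
t2 = torch.zeros(32, 3)  # t2.size(1) == 3     -> n_outputs = 3 (multivariate)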
11 changes: 3 additions & 8 deletions torchensemble/_constants.py
@@ -124,19 +124,14 @@
Parameters
----------
X : {Tensor, ndarray}
A data batch in the form of tensor or Numpy array.
return_numpy : bool, default=True
Whether to convert the predictions into a Numpy array.
X : {tensor, numpy array}
A data batch in the form of tensor or numpy array.
Returns
-------
pred : Array of shape (n_samples, n_outputs)
pred : tensor of shape (n_samples, n_outputs)
For classifiers, ``n_outputs`` is the number of distinct classes. For
regressors, ``n_outputs`` is the number of target variables.
- If ``return_numpy`` is ``False``, the result is a tensor.
- If ``return_numpy`` is ``True``, the result is a Numpy array.
"""


91 changes: 49 additions & 42 deletions torchensemble/adversarial_training.py
@@ -107,23 +107,24 @@ def _parallel_fit_per_epoch(
# Parallelization corrupts the binding between optimizer and scheduler
set_module.update_lr(optimizer, cur_lr)

for batch_idx, (data, target) in enumerate(train_loader):
for batch_idx, elem in enumerate(train_loader):

batch_size = data.size()[0]
data, target = data.to(device), target.to(device)
data.requires_grad = True
data, target = io.split_data_target(elem, device)
batch_size = data[0].size(0)
for tensor in data:
tensor.requires_grad = True

# Get adversarial samples
_output = estimator(data)
_output = estimator(*data)
_loss = criterion(_output, target)
_loss.backward()
data_grad = data.grad.data
data_grad = [tensor.grad.data for tensor in data]
adv_data = _get_fgsm_samples(data, epsilon, data_grad)

# Compute the training loss
optimizer.zero_grad()
org_output = estimator(data)
adv_output = estimator(adv_data)
org_output = estimator(*data)
adv_output = estimator(*adv_data)
loss = criterion(org_output, target) + criterion(adv_output, target)
loss.backward()
optimizer.step()
@@ -156,27 +157,31 @@ def _parallel_fit_per_epoch(
return estimator, optimizer


def _get_fgsm_samples(sample, epsilon, sample_grad):
def _get_fgsm_samples(sample_list, epsilon, sample_grad_list):
"""
Private function used to generate adversarial samples with the fast gradient
sign method (FGSM).
"""

# Check the input range of `sample`
min_value, max_value = torch.min(sample), torch.max(sample)
if not 0 <= min_value < max_value <= 1:
msg = (
"The input range of samples passed to adversarial training"
" should be in the range [0, 1], but got [{:.3f}, {:.3f}]"
" instead."
)
raise ValueError(msg.format(min_value, max_value))
perturbed_sample_list = []
for sample, sample_grad in zip(sample_list, sample_grad_list):
# Check the input range of `sample`
min_value, max_value = torch.min(sample), torch.max(sample)
if not 0 <= min_value < max_value <= 1:
msg = (
"The input range of samples passed to adversarial training"
" should be in the range [0, 1], but got [{:.3f}, {:.3f}]"
" instead."
)
raise ValueError(msg.format(min_value, max_value))

sign_sample_grad = sample_grad.sign()
perturbed_sample = sample + epsilon * sign_sample_grad
perturbed_sample = torch.clamp(perturbed_sample, 0, 1)
sign_sample_grad = sample_grad.sign()
perturbed_sample = sample + epsilon * sign_sample_grad
perturbed_sample = torch.clamp(perturbed_sample, 0, 1)

return perturbed_sample
perturbed_sample_list.append(perturbed_sample)

return perturbed_sample_list
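Each input tensor is perturbed independently with the usual FGSM step, x_adv = clamp(x + epsilon * sign(grad), 0, 1). A tiny worked example:

import torch
x = torch.tensor([0.2, 0.8])
grad = torch.tensor([-1.5, 0.3])
x_adv = torch.clamp(x + 0.1 * grad.sign(), 0, 1)
# sign(grad) = [-1, 1], so x_adv = [0.1, 0.9]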


class _BaseAdversarialTraining(BaseModule):
@@ -218,10 +223,10 @@ class AdversarialTrainingClassifier(_BaseAdversarialTraining, BaseClassifier):
"""Implementation on the data forwarding in AdversarialTrainingClassifier.""", # noqa: E501
"classifier_forward",
)
def forward(self, x):
def forward(self, *x):
# Take the average over class distributions from all base estimators.
outputs = [
F.softmax(estimator(x), dim=1) for estimator in self.estimators_
F.softmax(estimator(*x), dim=1) for estimator in self.estimators_
]
proba = op.average(outputs)
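# ``op.average`` is assumed to take the elementwise mean over the
# per-estimator outputs, roughly torch.mean(torch.stack(outputs), dim=0).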

@@ -282,9 +287,9 @@ def fit(
best_acc = 0.0

# Internal helper function for the pseudo forward pass
def _forward(estimators, data):
def _forward(estimators, *x):
outputs = [
F.softmax(estimator(data), dim=1) for estimator in estimators
F.softmax(estimator(*x), dim=1) for estimator in estimators
]
proba = op.average(outputs)

@@ -336,10 +341,11 @@ def _forward(estimators, data):
with torch.no_grad():
correct = 0
total = 0
for _, (data, target) in enumerate(test_loader):
data = data.to(self.device)
target = target.to(self.device)
output = _forward(estimators, data)
for _, elem in enumerate(test_loader):
data, target = io.split_data_target(
elem, self.device
)
output = _forward(estimators, *data)
_, predicted = torch.max(output.data, 1)
correct += (predicted == target).sum().item()
total += target.size(0)
@@ -384,8 +390,8 @@ def evaluate(self, test_loader, return_loss=False):
return super().evaluate(test_loader, return_loss)

@torchensemble_model_doc(item="predict")
def predict(self, X, return_numpy=True):
return super().predict(X, return_numpy)
def predict(self, *x):
return super().predict(*x)


@torchensemble_model_doc(
@@ -397,9 +403,9 @@ class AdversarialTrainingRegressor(_BaseAdversarialTraining, BaseRegressor):
"""Implementation on the data forwarding in AdversarialTrainingRegressor.""", # noqa: E501
"regressor_forward",
)
def forward(self, x):
def forward(self, *x):
# Take the average over predictions from all base estimators.
outputs = [estimator(x) for estimator in self.estimators_]
outputs = [estimator(*x) for estimator in self.estimators_]
pred = op.average(outputs)

return pred
@@ -459,8 +465,8 @@ def fit(
best_mse = float("inf")

# Internal helper function for the pseudo forward pass
def _forward(estimators, data):
outputs = [estimator(data) for estimator in estimators]
def _forward(estimators, *x):
outputs = [estimator(*x) for estimator in estimators]
pred = op.average(outputs)

return pred
@@ -510,10 +516,11 @@ def _forward(estimators, data):
self.eval()
with torch.no_grad():
mse = 0.0
for _, (data, target) in enumerate(test_loader):
data = data.to(self.device)
target = target.to(self.device)
output = _forward(estimators, data)
for _, elem in enumerate(test_loader):
data, target = io.split_data_target(
elem, self.device
)
output = _forward(estimators, *data)
mse += criterion(output, target)
mse /= len(test_loader)

@@ -553,5 +560,5 @@ def evaluate(self, test_loader):
return super().evaluate(test_loader)

@torchensemble_model_doc(item="predict")
def predict(self, X, return_numpy=True):
return super().predict(X, return_numpy)
def predict(self, *x):
return super().predict(*x)