-
Notifications
You must be signed in to change notification settings - Fork 95
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Create dataloder.py * add unit tests * update unit test
- Loading branch information
Showing
2 changed files
with
53 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import torch | ||
import pytest | ||
import numpy as np | ||
from torch.utils.data import TensorDataset, DataLoader | ||
|
||
from torchensemble.utils.dataloder import FixedDataLoader | ||
|
||
|
||
# Data | ||
X = torch.Tensor(np.array(([0.1, 0.1], [0.2, 0.2], [0.3, 0.3], [0.4, 0.4]))) | ||
y = torch.LongTensor(np.array(([0, 0, 1, 1]))) | ||
|
||
data = TensorDataset(X, y) | ||
dataloder = DataLoader(data, batch_size=2, shuffle=False) | ||
|
||
|
||
def test_fixed_dataloder(): | ||
fixed_dataloader = FixedDataLoader(dataloder) | ||
for _, (fixed_elem, elem) in enumerate(zip(fixed_dataloader, dataloder)): | ||
# Check same elements | ||
for elem_1, elem_2 in zip(fixed_elem, elem): | ||
assert torch.equal(elem_1, elem_2) | ||
|
||
# Check dataloder length | ||
assert len(fixed_dataloader) == 2 | ||
|
||
|
||
def test_fixed_dataloader_invalid_type(): | ||
with pytest.raises(ValueError) as excinfo: | ||
FixedDataLoader((X, y)) | ||
assert "input used to instantiate FixedDataLoader" in str(excinfo.value) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from torch.utils.data import DataLoader | ||
|
||
|
||
class FixedDataLoader(object): | ||
def __init__(self, dataloader): | ||
# Check input | ||
if not isinstance(dataloader, DataLoader): | ||
msg = ( | ||
"The input used to instantiate FixedDataLoader should be a" | ||
" DataLoader from `torch.utils.data`." | ||
) | ||
raise ValueError(msg) | ||
|
||
self.elem_list = [] | ||
for _, elem in enumerate(dataloader): | ||
self.elem_list.append(elem) | ||
|
||
def __getitem__(self, index): | ||
return self.elem_list[index] | ||
|
||
def __len__(self): | ||
return len(self.elem_list) |