Skip to content

Commit

Permalink
Merge pull request #34 from BirkhoffG/fix-batch-len
Browse files Browse the repository at this point in the history
Fix batch len
  • Loading branch information
BirkhoffG authored Aug 10, 2024
2 parents 1c0394b + 390b3e3 commit 0bb83ae
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 21 deletions.
3 changes: 2 additions & 1 deletion jax_dataloader/loaders/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,5 @@ def next_key(self):
return subkey

def __len__(self):
return len(self.indices) // self.batch_size + int(not self.drop_last)
complete_batches, remainder = divmod(len(self.indices), self.batch_size)
return complete_batches if self.drop_last else complete_batches + bool(remainder)
6 changes: 3 additions & 3 deletions jax_dataloader/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def get_batch(batch):

# %% ../nbs/tests.ipynb 4
def test_no_shuffle(cls, ds, batch_size: int, feats, labels):
dl = cls(ds, batch_size=batch_size, shuffle=False)
assert len(dl) == len(feats) // batch_size + 1
dl = cls(ds, batch_size=batch_size, shuffle=False, drop_last=False)
assert len(dl) == len(feats) // batch_size + bool(len(feats) % batch_size)
for _ in range(2):
X_list, Y_list = [], []
for batch in dl:
Expand Down Expand Up @@ -48,7 +48,7 @@ def test_no_shuffle_drop_last(cls, ds, batch_size: int, feats, labels):
def test_shuffle(cls, ds, batch_size: int, feats, labels):
dl = cls(ds, batch_size=batch_size, shuffle=True, drop_last=False)
last_X, last_Y = jnp.array([]), jnp.array([])
assert len(dl) == len(feats) // batch_size + 1
assert len(dl) == len(feats) // batch_size + bool(len(feats) % batch_size)
for _ in range(2):
X_list, Y_list = [], []
for batch in dl:
Expand Down
34 changes: 21 additions & 13 deletions nbs/loader.jax.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-02-01 22:18:26.142014: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
"2024-02-01 22:18:26.142138: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
"2024-02-01 22:18:26.151662: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
"2024-02-01 22:18:26.979728: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
]
}
],
"outputs": [],
"source": [
"#| export\n",
"from __future__ import print_function, division, annotations\n",
Expand Down Expand Up @@ -134,7 +123,8 @@
" return subkey\n",
" \n",
" def __len__(self):\n",
" return len(self.indices) // self.batch_size + int(not self.drop_last)"
" complete_batches, remainder = divmod(len(self.indices), self.batch_size)\n",
" return complete_batches if self.drop_last else complete_batches + bool(remainder)"
]
},
{
Expand All @@ -153,13 +143,31 @@
"assert len(dl.indices) == 1280"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"samples = 128\n",
"batch_size = 128\n",
"feats = np.arange(samples).repeat(10).reshape(samples, 10)\n",
"labels = np.arange(samples).reshape(samples, 1)\n",
"ds = ArrayDataset(feats, labels)\n",
"dl = DataLoaderJAX(ds, batch_size=batch_size, shuffle=True, drop_last=True)\n",
"assert len(dl) == 1\n",
"dl = DataLoaderJAX(ds, batch_size=batch_size, shuffle=True, drop_last=False)\n",
"assert len(dl) == 1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"test_dataloader(DataLoaderJAX, samples=10, batch_size=10)\n",
"test_dataloader(DataLoaderJAX, samples=20, batch_size=12)\n",
"test_dataloader(DataLoaderJAX, samples=20, batch_size=10)\n",
"test_dataloader(DataLoaderJAX, samples=11, batch_size=10)\n",
Expand Down
4 changes: 3 additions & 1 deletion nbs/loader.torch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,11 @@
"source": [
"#| hide\n",
"#| torch\n",
"test_dataloader(DataLoaderPytorch, samples=10, batch_size=10)\n",
"test_dataloader(DataLoaderPytorch, samples=20, batch_size=12)\n",
"test_dataloader(DataLoaderPytorch, samples=20, batch_size=10)\n",
"test_dataloader(DataLoaderPytorch, samples=11, batch_size=10)"
"test_dataloader(DataLoaderPytorch, samples=11, batch_size=10)\n",
"test_dataloader(DataLoaderPytorch, samples=40, batch_size=12)"
]
},
{
Expand Down
6 changes: 3 additions & 3 deletions nbs/tests.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@
"source": [
"#| exporti\n",
"def test_no_shuffle(cls, ds, batch_size: int, feats, labels):\n",
" dl = cls(ds, batch_size=batch_size, shuffle=False)\n",
" assert len(dl) == len(feats) // batch_size + 1\n",
" dl = cls(ds, batch_size=batch_size, shuffle=False, drop_last=False)\n",
" assert len(dl) == len(feats) // batch_size + bool(len(feats) % batch_size)\n",
" for _ in range(2):\n",
" X_list, Y_list = [], []\n",
" for batch in dl:\n",
Expand Down Expand Up @@ -102,7 +102,7 @@
"def test_shuffle(cls, ds, batch_size: int, feats, labels):\n",
" dl = cls(ds, batch_size=batch_size, shuffle=True, drop_last=False)\n",
" last_X, last_Y = jnp.array([]), jnp.array([])\n",
" assert len(dl) == len(feats) // batch_size + 1\n",
" assert len(dl) == len(feats) // batch_size + bool(len(feats) % batch_size)\n",
" for _ in range(2):\n",
" X_list, Y_list = [], []\n",
" for batch in dl:\n",
Expand Down

0 comments on commit 0bb83ae

Please sign in to comment.