From 3705c78f66a4928ef2c2b71ff2bdf2abae83ef8e Mon Sep 17 00:00:00 2001 From: BirkhoffG <26811230+BirkhoffG@users.noreply.github.com> Date: Fri, 9 Aug 2024 14:38:23 -0400 Subject: [PATCH 1/2] Fix DataLoaderJAX to handle incomplete batches without dropping them --- jax_dataloader/loaders/jax.py | 3 ++- jax_dataloader/tests.py | 6 +++--- nbs/loader.jax.ipynb | 16 +++------------- nbs/tests.ipynb | 6 +++--- 4 files changed, 11 insertions(+), 20 deletions(-) diff --git a/jax_dataloader/loaders/jax.py b/jax_dataloader/loaders/jax.py index 7f15942..f1a277f 100644 --- a/jax_dataloader/loaders/jax.py +++ b/jax_dataloader/loaders/jax.py @@ -69,4 +69,5 @@ def next_key(self): return subkey def __len__(self): - return len(self.indices) // self.batch_size + int(not self.drop_last) + complete_batches, remainder = divmod(len(self.indices), self.batch_size) + return complete_batches if self.drop_last else complete_batches + bool(remainder) diff --git a/jax_dataloader/tests.py b/jax_dataloader/tests.py index 6687350..8f540aa 100644 --- a/jax_dataloader/tests.py +++ b/jax_dataloader/tests.py @@ -17,8 +17,8 @@ def get_batch(batch): # %% ../nbs/tests.ipynb 4 def test_no_shuffle(cls, ds, batch_size: int, feats, labels): - dl = cls(ds, batch_size=batch_size, shuffle=False) - assert len(dl) == len(feats) // batch_size + 1 + dl = cls(ds, batch_size=batch_size, shuffle=False, drop_last=False) + assert len(dl) == len(feats) // batch_size + bool(len(feats) % batch_size) for _ in range(2): X_list, Y_list = [], [] for batch in dl: @@ -48,7 +48,7 @@ def test_no_shuffle_drop_last(cls, ds, batch_size: int, feats, labels): def test_shuffle(cls, ds, batch_size: int, feats, labels): dl = cls(ds, batch_size=batch_size, shuffle=True, drop_last=False) last_X, last_Y = jnp.array([]), jnp.array([]) - assert len(dl) == len(feats) // batch_size + 1 + assert len(dl) == len(feats) // batch_size + bool(len(feats) % batch_size) for _ in range(2): X_list, Y_list = [], [] for batch in dl: diff --git a/nbs/loader.jax.ipynb b/nbs/loader.jax.ipynb index 3bc2673..877c5d3 100644 --- a/nbs/loader.jax.ipynb +++ b/nbs/loader.jax.ipynb @@ -34,18 +34,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-02-01 22:18:26.142014: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", - "2024-02-01 22:18:26.142138: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", - "2024-02-01 22:18:26.151662: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", - "2024-02-01 22:18:26.979728: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" - ] - } - ], + "outputs": [], "source": [ "#| export\n", "from __future__ import print_function, division, annotations\n", @@ -134,7 +123,8 @@ " return subkey\n", " \n", " def __len__(self):\n", - " return len(self.indices) // self.batch_size + int(not self.drop_last)" + " complete_batches, remainder = divmod(len(self.indices), self.batch_size)\n", + " return complete_batches if self.drop_last else complete_batches + bool(remainder)" ] }, { diff --git a/nbs/tests.ipynb b/nbs/tests.ipynb index fd16efe..164ad07 100644 --- a/nbs/tests.ipynb +++ b/nbs/tests.ipynb @@ -57,8 +57,8 @@ "source": [ "#| exporti\n", "def test_no_shuffle(cls, ds, batch_size: int, feats, labels):\n", - " dl = cls(ds, batch_size=batch_size, shuffle=False)\n", - " assert len(dl) == len(feats) // batch_size + 1\n", + " dl = cls(ds, batch_size=batch_size, shuffle=False, drop_last=False)\n", + " assert len(dl) == len(feats) // batch_size + bool(len(feats) % batch_size)\n", " for _ in range(2):\n", " X_list, Y_list = [], []\n", " for batch in dl:\n", @@ -102,7 +102,7 @@ "def test_shuffle(cls, ds, batch_size: int, feats, labels):\n", " dl = cls(ds, batch_size=batch_size, shuffle=True, drop_last=False)\n", " last_X, last_Y = jnp.array([]), jnp.array([])\n", - " assert len(dl) == len(feats) // batch_size + 1\n", + " assert len(dl) == len(feats) // batch_size + bool(len(feats) % batch_size)\n", " for _ in range(2):\n", " X_list, Y_list = [], []\n", " for batch in dl:\n", From 390b3e3b17f7c92f22413ecca2fa07d1e1756c59 Mon Sep 17 00:00:00 2001 From: BirkhoffG <26811230+BirkhoffG@users.noreply.github.com> Date: Fri, 9 Aug 2024 14:42:21 -0400 Subject: [PATCH 2/2] Add more test cases --- nbs/loader.jax.ipynb | 18 ++++++++++++++++++ nbs/loader.torch.ipynb | 4 +++- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/nbs/loader.jax.ipynb b/nbs/loader.jax.ipynb index 877c5d3..155c777 100644 --- a/nbs/loader.jax.ipynb +++ b/nbs/loader.jax.ipynb @@ -143,6 +143,23 @@ "assert len(dl.indices) == 1280" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "samples = 128\n", + "batch_size = 128\n", + "feats = np.arange(samples).repeat(10).reshape(samples, 10)\n", + "labels = np.arange(samples).reshape(samples, 1)\n", + "ds = ArrayDataset(feats, labels)\n", + "dl = DataLoaderJAX(ds, batch_size=batch_size, shuffle=True, drop_last=True)\n", + "assert len(dl) == 1\n", + "dl = DataLoaderJAX(ds, batch_size=batch_size, shuffle=True, drop_last=False)\n", + "assert len(dl) == 1" + ] + }, { "cell_type": "code", "execution_count": null, @@ -150,6 +167,7 @@ "outputs": [], "source": [ "#| hide\n", + "test_dataloader(DataLoaderJAX, samples=10, batch_size=10)\n", "test_dataloader(DataLoaderJAX, samples=20, batch_size=12)\n", "test_dataloader(DataLoaderJAX, samples=20, batch_size=10)\n", "test_dataloader(DataLoaderJAX, samples=11, batch_size=10)\n", diff --git a/nbs/loader.torch.ipynb b/nbs/loader.torch.ipynb index 370d1e6..7f4d911 100644 --- a/nbs/loader.torch.ipynb +++ b/nbs/loader.torch.ipynb @@ -198,9 +198,11 @@ "source": [ "#| hide\n", "#| torch\n", + "test_dataloader(DataLoaderPytorch, samples=10, batch_size=10)\n", "test_dataloader(DataLoaderPytorch, samples=20, batch_size=12)\n", "test_dataloader(DataLoaderPytorch, samples=20, batch_size=10)\n", - "test_dataloader(DataLoaderPytorch, samples=11, batch_size=10)" + "test_dataloader(DataLoaderPytorch, samples=11, batch_size=10)\n", + "test_dataloader(DataLoaderPytorch, samples=40, batch_size=12)" ] }, {