Commit

update readme
BirkhoffG committed Feb 16, 2024
1 parent d605fe7 commit 66437b4
Showing 2 changed files with 23 additions and 21 deletions.
18 changes: 11 additions & 7 deletions README.md
@@ -16,12 +16,15 @@ License](https://img.shields.io/github/license/BirkhoffG/jax-dataloader.svg)
supports

- **4 datasets to download and pre-process data**:

- [jax dataset](https://birkhoffg.github.io/jax-dataloader/dataset/)
- [huggingface datasets](https://github.com/huggingface/datasets)
- [pytorch
Dataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset)
- [tensorflow dataset](www.tensorflow.org/datasets)

- **3 backends to iteratively load batches**:

- [jax
dataloader](https://birkhoffg.github.io/jax-dataloader/core.html#jax-dataloader)
- [pytorch
@@ -36,6 +39,9 @@ import jax_dataloader as jdl
dataloader = jdl.DataLoader(
dataset, # Can be a jdl.Dataset or pytorch or huggingface dataset
backend='jax', # Use 'jax' for loading data (also supports `pytorch`)
batch_size=32, # Batch size
shuffle=True, # Shuffle the dataloader every iteration
drop_last=False, # Drop the last batch or not
)

batch = next(iter(dataloader)) # iterate next batch
@@ -113,6 +119,8 @@ This `arr_ds` can be loaded by *every* backends.
dataloader = jdl.DataLoader(arr_ds, 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(arr_ds, 'pytorch', batch_size=5, shuffle=True)
# Or we can use the tensorflow backend
dataloader = jdl.DataLoader(arr_ds, 'tensorflow', batch_size=5, shuffle=True)
```
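
The `batch_size`, `shuffle`, and `drop_last` arguments shown in this README follow the usual dataloader conventions. As a rough illustration of what they typically mean, here is a minimal NumPy-only sketch. It is not how `jax_dataloader` is implemented; `iter_batches` and its `seed` parameter are hypothetical names introduced only for this example.

```python
import numpy as np

def iter_batches(data, batch_size, shuffle=False, drop_last=False, seed=0):
    """Hypothetical helper (not part of jax_dataloader): yield batches along axis 0."""
    n = len(data)
    indices = np.arange(n)
    if shuffle:
        # Permute the row order once for this pass over the data.
        indices = np.random.default_rng(seed).permutation(n)
    for start in range(0, n, batch_size):
        batch_idx = indices[start:start + batch_size]
        if drop_last and len(batch_idx) < batch_size:
            break  # skip the final batch when it is smaller than batch_size
        yield data[batch_idx]

data = np.arange(12).reshape(12, 1)

# 12 rows with batch_size=5: the last batch holds only 2 rows.
kept = [b.shape[0] for b in iter_batches(data, batch_size=5)]
# With drop_last=True the incomplete final batch is discarded.
dropped = [b.shape[0] for b in iter_batches(data, batch_size=5, drop_last=True)]
# Shuffling changes the order of the rows but not which rows are visited.
shuffled = np.concatenate(list(iter_batches(data, batch_size=5, shuffle=True)))
```

With 12 rows and `batch_size=5`, keeping the last batch gives three batches and dropping it gives two.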

### Using Huggingface Datasets
@@ -138,6 +146,8 @@ Then, we can use `jax_dataloader` to load batches of `hf_ds`.
dataloader = jdl.DataLoader(hf_ds['train'], 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(hf_ds['train'], 'pytorch', batch_size=5, shuffle=True)
# Or we can use the tensorflow backend
dataloader = jdl.DataLoader(hf_ds['train'], 'tensorflow', batch_size=5, shuffle=True)
```

### Using Pytorch Datasets
@@ -169,13 +179,7 @@ We load the MNIST dataset from `torchvision`. The `ToNumpy` object
transforms images to `numpy.array`.

``` python
class ToNumpy(object):
def __call__(self, pic):
return np.array(pic, dtype=float)
```

``` python
pt_ds = MNIST('/tmp/mnist/', download=True, transform=ToNumpy(), train=False)
pt_ds = MNIST('/tmp/mnist/', download=True, transform=lambda x: np.array(x, dtype=float), train=False)
```

This `pt_ds` can **only** be loaded via `"pytorch"` dataloaders.
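
In this hunk the `ToNumpy` callable class is replaced by an inline `lambda` passed as `transform`. The sketch below checks that the two forms produce the same array. It assumes a small `uint8` array as a stand-in for an MNIST image (`fake_image` is a made-up input, not real data) and needs neither `torchvision` nor a download.

```python
import numpy as np

# The class-based transform shown in the hunk above.
class ToNumpy(object):
    def __call__(self, pic):
        return np.array(pic, dtype=float)

# The inline form that appears in the `pt_ds = MNIST(...)` line.
to_numpy = lambda x: np.array(x, dtype=float)

# Stand-in for an image: a tiny uint8 grid (an assumption for illustration).
fake_image = np.arange(6, dtype=np.uint8).reshape(2, 3)

from_class = ToNumpy()(fake_image)
from_lambda = to_numpy(fake_image)
```

Both calls return a `float64` array with the same shape and values. One practical difference to keep in mind: a `lambda` cannot be pickled by the standard `pickle` module, which may matter if the dataset is later handed to worker processes.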
26 changes: 12 additions & 14 deletions nbs/index.ipynb
@@ -32,12 +32,14 @@
"It supports\n",
"\n",
"* **4 datasets to download and pre-process data**: \n",
"\n",
" * [jax dataset](https://birkhoffg.github.io/jax-dataloader/dataset/)\n",
" * [huggingface datasets](https://github.com/huggingface/datasets) \n",
" * [pytorch Dataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset)\n",
" * [tensorflow dataset](www.tensorflow.org/datasets)\n",
"\n",
"* **3 backends to iteratively load batches**: \n",
"\n",
" * [jax dataloader](https://birkhoffg.github.io/jax-dataloader/core.html#jax-dataloader)\n",
" * [pytorch dataloader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) \n",
" * [tensorflow dataset](www.tensorflow.org/datasets)\n",
@@ -51,6 +53,9 @@
"dataloader = jdl.DataLoader(\n",
" dataset, # Can be a jdl.Dataset or pytorch or huggingface dataset\n",
" backend='jax', # Use 'jax' for loading data (also supports `pytorch`)\n",
" batch_size=32, # Batch size \n",
" shuffle=True, # Shuffle the dataloader every iteration\n",
" drop_last=False, # Drop the last batch or not\n",
")\n",
"\n",
"batch = next(iter(dataloader)) # iterate next batch\n",
@@ -226,7 +231,9 @@
"# Create a `DataLoader` from the `ArrayDataset` via jax backend\n",
"dataloader = jdl.DataLoader(arr_ds, 'jax', batch_size=5, shuffle=True)\n",
"# Or we can use the pytorch backend\n",
"dataloader = jdl.DataLoader(arr_ds, 'pytorch', batch_size=5, shuffle=True)"
"dataloader = jdl.DataLoader(arr_ds, 'pytorch', batch_size=5, shuffle=True)\n",
"# Or we can use the tensorflow backend\n",
"dataloader = jdl.DataLoader(arr_ds, 'tensorflow', batch_size=5, shuffle=True)"
]
},
{
@@ -285,7 +292,9 @@
"# Create a `DataLoader` from the `datasets.Dataset` via jax backend\n",
"dataloader = jdl.DataLoader(hf_ds['train'], 'jax', batch_size=5, shuffle=True)\n",
"# Or we can use the pytorch backend\n",
"dataloader = jdl.DataLoader(hf_ds['train'], 'pytorch', batch_size=5, shuffle=True)"
"dataloader = jdl.DataLoader(hf_ds['train'], 'pytorch', batch_size=5, shuffle=True)\n",
"# Or we can use the tensorflow backend\n",
"dataloader = jdl.DataLoader(hf_ds['train'], 'tensorflow', batch_size=5, shuffle=True)"
]
},
{
@@ -331,25 +340,14 @@
"The `ToNumpy` object transforms images to `numpy.array`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class ToNumpy(object):\n",
" def __call__(self, pic):\n",
" return np.array(pic, dtype=float)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| torch\n",
"pt_ds = MNIST('/tmp/mnist/', download=True, transform=ToNumpy(), train=False)"
"pt_ds = MNIST('/tmp/mnist/', download=True, transform=lambda x: np.array(x, dtype=float), train=False)"
]
},
{