
Handling data slicing at DataloaderIter.__next__ level #17

Open
Wala-Touati opened this issue Mar 29, 2024 · 3 comments
Labels
dataloaders/dataset, help wanted, question

Comments

@Wala-Touati

The current DataloaderIter class assumes that the Dataset's __getitem__ method can handle slicing and return batches of samples. This assumption causes compatibility issues with datasets whose __getitem__ returns a single sample, which is the usual case. To resolve this, I think it's better to handle slicing within DataloaderIter's __next__ method and keep __getitem__'s concern limited to returning a single sample.
The line causing the error is:

batch = self.dataset[self.current_index:end]
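For illustration, here's a minimal single-sample dataset (a hypothetical example, not from the repository) that breaks as soon as __getitem__ receives a slice instead of an integer:

    class SquaresDataset:
        """Hypothetical dataset whose __getitem__ only handles integer indices."""
        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            # Per-sample processing: with a slice, self.data[idx] is a
            # list, and list ** 2 raises a TypeError.
            return self.data[idx] ** 2

    ds = SquaresDataset(list(range(10)))
    ds[2]    # 4 -- a single sample works fine
    ds[0:4]  # TypeError: unsupported operand type(s) for ** or pow(): 'list' and 'int'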

The proposed solution is to modify __next__ to call __getitem__ multiple times to build a batch, maybe something like this:

    def __next__(self):
        if self.total_yielded < self.total_yield:
            batch = []
            # Collect up to batch_size samples, one __getitem__ call each.
            for _ in range(self.batch_size):
                if self.total_yielded >= self.total_yield:
                    break  # dataset exhausted mid-batch: return a partial batch
                batch.append(self.dataset[self.current_index])
                self.current_index += 1
                self.total_yielded += 1
            return batch
        raise StopIteration
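
With that change, a dataset like the SquaresDataset sketch above iterates cleanly, since __getitem__ only ever sees integer indices. Assuming the surrounding Dataloader takes a dataset and a batch_size (an assumption about this repository's API, not confirmed here):

    loader = Dataloader(SquaresDataset(list(range(10))), batch_size=4)
    for batch in loader:
        print(batch)
    # [0, 1, 4, 9]
    # [16, 25, 36, 49]
    # [64, 81]  <- partial final batch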
@zeddo123 (Member)

> DataloaderIter class assumes that the Dataset's __getitem__ method can handle slicing and return batches of samples. This assumption causes compatibility issues with datasets that return a single sample which is usually the case.

Yep, I was able to reproduce the issue on my side, and it looks like the dataloader expects a __getitem__ method that enables slicing. That said, I think it's better to shift the responsibility for slicing to the dataset class, since:

  1. Most Python objects that implement __getitem__ already have slicing capabilities
  2. It's not that difficult to implement slicing on the dataset, since it has access to the underlying data

Here's an example:

def __getitem__(self, val):
    if isinstance(val, slice):
        # Expand the slice into a list of individual samples.
        return [self.data[i] for i in range(*val.indices(len(self.data)))]
    return self.data[val]
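
Dropped into a concrete dataset class, the pattern might look like this (ListDataset is illustrative, not from the repository):

    class ListDataset:
        """Illustrative dataset supporting both integer and slice indexing."""
        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __getitem__(self, val):
            if isinstance(val, slice):
                # Expand the slice into a list of individual samples.
                return [self.data[i] for i in range(*val.indices(len(self.data)))]
            return self.data[val]

    ds = ListDataset(list(range(10)))
    ds[3]    # 3
    ds[2:6]  # [2, 3, 4, 5]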

Although this shifts the burden onto users to define their dataset well, this approach looks more concise to me. What do you think?

zeddo123 added the help wanted, question, and dataloaders/dataset labels on Mar 29, 2024
@Wala-Touati (Author)

Thank you for looking into this issue and providing a solution. I've actually been relying on the same approach, and it works well.

However, I initially thought this was just a temporary fix and not the way to go, as it didn't feel like the conventional approach used in deep learning frameworks such as PyTorch, where users typically find the __getitem__ definition more intuitive since it returns a single sample, and then rely on other mechanisms such as the collate_fn function to handle slicing and batching.
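
For comparison, the PyTorch-style split of responsibilities looks roughly like this: __getitem__ returns one sample, the loader gathers batch_size of them, and a collate_fn merges them into a batch (a simplified sketch, not PyTorch's actual implementation):

    def default_collate(samples):
        # Merge a list of (feature, label) pairs into (features, labels);
        # real frameworks would stack tensors here instead of building lists.
        features, labels = zip(*samples)
        return list(features), list(labels)

    dataset = [(x, x % 2) for x in range(10)]  # stand-in for a Dataset of (feature, label) pairs
    batch_size = 4
    samples = [dataset[i] for i in range(batch_size)]  # one __getitem__ call per sample
    print(default_collate(samples))  # ([0, 1, 2, 3], [0, 1, 0, 1])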

That being said, I now see the benefits of implementing slicing within the __getitem__ method, such as making use of the slicing capabilities of most Python objects and keeping the dataset class more self-contained. I think this is a minor issue, and both approaches work just fine.

@zeddo123 (Member)

Alright, I'll keep this issue open for now.
