From ae9cb6e4db6f81fd18148c2cc67d72b903d81a46 Mon Sep 17 00:00:00 2001
From: Marius Arvinte <5852612+mariusarvinte@users.noreply.github.com>
Date: Tue, 8 Oct 2024 07:01:26 -0700
Subject: [PATCH] Handle negative values for `dim` input in
 `pad_across_processes` (#3114)

* Handle negative values for dim

* Add tests for negative dimension
---
 src/accelerate/utils/operations.py | 5 ++++-
 tests/test_utils.py                | 9 +++++++++
 2 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/src/accelerate/utils/operations.py b/src/accelerate/utils/operations.py
index 162009e76b6..5f737344b56 100644
--- a/src/accelerate/utils/operations.py
+++ b/src/accelerate/utils/operations.py
@@ -652,8 +652,11 @@ def _pad_across_processes(tensor, dim=0, pad_index=0, pad_first=False):
             CannotPadNestedTensorWarning,
         )
         return tensor
-    if dim >= len(tensor.shape):
+    if dim >= len(tensor.shape) or dim < -len(tensor.shape):
         return tensor
+    # Convert negative dimensions to non-negative
+    if dim < 0:
+        dim += len(tensor.shape)
 
     # Gather all sizes
     size = torch.tensor(tensor.shape, device=tensor.device)[None]
diff --git a/tests/test_utils.py b/tests/test_utils.py
index ed4481ed92c..cabdb55a1a6 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -304,6 +304,15 @@ def test_pad_across_processes(self):
         nt2 = pad_across_processes(nt)
         assert nt is nt2
 
+        # Basic functionality
+        tensor = torch.randn(4, 3, 100)
+        padded_tensor = pad_across_processes(tensor, dim=-1)
+        assert padded_tensor.shape[-1] == 100
+
+        # dim = -4 is out of bounds
+        padded_tensor = pad_across_processes(tensor, dim=-4)
+        assert padded_tensor is tensor
+
     def test_slice_and_concatenate(self):
         # First base case: 2 processes, batch size of 1
         num_processes = 2
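
For context, below is a minimal standalone sketch of the negative-dimension handling this patch adds to `_pad_across_processes`. The `normalize_dim` helper name is hypothetical and not part of the patch; it simply mirrors the two added checks: an out-of-range `dim` leaves the tensor untouched, and an in-range negative `dim` is converted to its non-negative equivalent.

import torch

def normalize_dim(tensor, dim):
    # Hypothetical helper mirroring the patched logic:
    # reject dims outside [-ndim, ndim), then map a negative dim
    # to the equivalent non-negative index.
    ndim = len(tensor.shape)
    if dim >= ndim or dim < -ndim:
        return None  # caller returns the tensor unchanged in this case
    if dim < 0:
        dim += ndim
    return dim

t = torch.randn(4, 3, 100)
assert normalize_dim(t, -1) == 2     # last dimension
assert normalize_dim(t, -4) is None  # out of bounds, tensor left as-is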