Add rudimentary NestedTensor.sum(dim) (pytorch#82387)
A first step towards adding dimension-wise reductions to NestedTensor:
- Assumes tensors in the nested tensor as well as the buffer of the nested tensor are contiguous
- Always enforces `keepdim=True`
- Only supports reduction across the last dimension
- No support for an accumulation dtype (the `dtype` argument)
- No autograd support
- CPU only

Next steps would be to lift these restrictions. For now, this basic support is intended for prototyping, to make sure `NestedTensor` can be used as an API for segment reductions.
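
As a minimal usage sketch of the constrained behavior described above (illustrative values, not taken from the PR; assumes CPU tensors, reduction over the last dim, and keepdim=True):

import torch

nt = torch.nested_tensor([torch.arange(3.0), torch.arange(5.0)])
out = nt.sum(-1, keepdim=True)  # reduces the last dim of each constituent
# out is a nested tensor holding tensor([3.]) and tensor([10.])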

Pull Request resolved: pytorch#82387
Approved by: https://github.com/jbschlosser
mikaylagawarecki authored and pytorchmergebot committed Jul 28, 2022
1 parent 2bfae07 commit 89c0123
Showing 3 changed files with 104 additions and 0 deletions.
2 changes: 2 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -4742,6 +4742,8 @@
  structured_delegate: sum.IntList_out
  device_check: NoCheck # TensorIterator
  variants: function, method
  dispatch:
    NestedTensorCPU: NestedTensor_sum_dim_CPU
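    # nested-tensor reduction (CPU only, last dim, keepdim=True); see NestedTensorMath.cpp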

- func: sum.dim_DimnameList(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
  device_check: NoCheck # TensorIterator
72 changes: 72 additions & 0 deletions aten/src/ATen/native/nested/NestedTensorMath.cpp
@@ -665,6 +665,78 @@ Tensor& NestedTensor_mul__Scalar(Tensor& self, const Scalar& other) {
  return NestedTensor_mul__Tensor(self, wrapped_scalar_tensor(other));
}

// Very rudimentary sum_dim for prototyping with torch_scatter.segment_reduce.
Tensor NestedTensor_sum_dim_CPU(
    const Tensor& self,
    OptionalIntArrayRef opt_dims,
    bool keepdim,
    c10::optional<ScalarType> dtype) {
  // Only allow reductions across the last dim
  auto dims = opt_dims.value_or(IntArrayRef{});
  TORCH_CHECK(
      dims.size() == 1,
      "NestedTensor only allows reduction of a single dimension for now."
  );
  auto dim = maybe_wrap_dim(dims[0], self.dim());
  TORCH_CHECK(
      dim == self.dim() - 1,
      "NestedTensor can only be reduced across the last dimension for now ",
      "got dimension ",
      dim,
      " instead.");
  // Always keep the reduced dim for now.
  // This avoids the case where the nested tensors are 1D and keepdim=False
  // would reduce the nested tensors to plain elements
  // (e.g. sum(nt([1, 2, 3], [4, 5]), -1) -> nt(6, 9)).
  TORCH_CHECK(keepdim, "NestedTensor always requires keepdim=True for now.");
  // acc_dtype is not supported for now
  TORCH_CHECK(!dtype, "NestedTensor does not support dtype argument for now.");

  auto nt_input = get_nested_tensor_impl(self);
  TORCH_CHECK(
      nested_tensor_impl_is_contiguous(nt_input),
      "NestedTensor does not support reductions when the input is noncontiguous for now.");
  int64_t ntensors = nt_input->size(0);
  if (ntensors == 0) {
    return self;
  }
  const Tensor& buffer = nt_input->get_buffer();

  auto sizemat = nt_input->get_nested_size_tensor();
  // create output size tensor for keepdim=True
  auto output_sizemat = sizemat.clone();
  output_sizemat.select(1, -1).fill_(1);

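  // e.g. for nested sizes [[2, 3], [4, 5]] the output sizes become [[2, 1], [4, 1]],
  // so num_segments is [2, 4] and segment_lengths is [3, 5]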
  auto num_segments = at::prod(output_sizemat, -1);
  auto segment_lengths = sizemat.select(1, -1);
  const int64_t new_numel = at::sum(num_segments).item<int64_t>();
  auto output_buffer = buffer.new_empty(IntArrayRef(new_numel));

  // This logic assumes for now that
  // (1) all the nested tensors are contiguous
  // (2) the nested tensors are stored contiguously in the buffer
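  // e.g. for nested sizes [[2, 3], [4, 5]] the buffer holds 2*3 + 4*5 = 26 values
  // and the loops below write 2 + 4 = 6 sums into output_buffer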
  AT_DISPATCH_ALL_TYPES_AND2(
    ScalarType::Half, ScalarType::BFloat16, buffer.scalar_type(), "nested_sum_dim_cpu", [&]() {
    auto* output_data = output_buffer.data_ptr<scalar_t>();
    const auto* input_data = buffer.data_ptr<scalar_t>();
    int64_t out_idx = 0, in_idx = 0;
    for (const auto i : c10::irange(ntensors)) {
      int64_t segments = num_segments[i].item<int64_t>();
      int64_t segment_length = segment_lengths[i].item<int64_t>();
      for (auto j = 0; j < segments; j++) {
        scalar_t res = 0;
        for (auto k = 0; k < segment_length; k++) {
          res += input_data[in_idx];
          in_idx += 1;
        }
        output_data[out_idx] = res;
        out_idx += 1;
      }
    }
  });

  return wrap_buffer(output_buffer, output_sizemat);
}

Tensor select_nested(const Tensor& self, int64_t dim, int64_t index) {
  auto self_ptr = get_nested_tensor_impl(self);
  int64_t positive_dim = at::maybe_wrap_dim(dim, self_ptr->dim());
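
For intuition, here is a rough Python rendering of the buffer walk performed by NestedTensor_sum_dim_CPU above. The standalone helper and its name are hypothetical (not part of this commit); it only sketches the loop structure under the same contiguity assumptions:

import math
import torch

def nested_sum_last_dim(buffer, sizes):
    # buffer: flat 1-D tensor holding every constituent tensor back to back
    # sizes: per-tensor shapes, e.g. [[2, 3], [4, 5]]
    out, in_idx = [], 0
    for shape in sizes:
        segments = math.prod(shape[:-1])   # number of rows to reduce for this tensor
        length = shape[-1]                 # reduced (last) dimension
        for _ in range(segments):
            out.append(buffer[in_idx:in_idx + length].sum())
            in_idx += length
    return torch.stack(out)                # flat output buffer, one value per segment
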
30 changes: 30 additions & 0 deletions test/test_nestedtensor.py
@@ -8,6 +8,7 @@
    dtypesIfCUDA,
    instantiate_device_type_tests,
    skipMeta,
    onlyCPU
)
from torch.testing._internal.common_utils import TestCase, IS_FBCODE, run_tests, freeze_rng_state
from torch import nested_tensor
@@ -523,6 +524,35 @@ def test_nested_tensor_mul_in_place(self, device, dtype):
            lambda: vector.mul_(nt1)
        )

    @onlyCPU
    @skipMeta
    @dtypes(torch.float)
    def test_nested_tensor_sum_dim(self, device, dtype):
        params = ((2, (1, 1)), ((4), (4, 4)), (10, (3, 5, 7)))

        def test_sum(nt, dim, keepdim=True):
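            # compare nt.sum over the last dim against unbinding nt and summing each constituent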
            nt2 = nt.clone()
            nt = nt.sum(dim=dim, keepdim=keepdim)
            ub2 = nt2.unbind()
            ub2 = [t.sum(-1, keepdim=keepdim) for t in ub2]
            nt2 = torch.nested_tensor(ub2)
            self.nt_equal(nt, nt2)
            return

        for ntensors, max_sizes in params:
            test_sum(self.random_nt(device, dtype, ntensors, max_sizes), len(max_sizes))

        # Test error inputs
        with self.assertRaisesRegex(RuntimeError, "NestedTensor can only be reduced across the last"):
            torch.nested_tensor([torch.tensor([3, 4, 5]), torch.tensor([1, 2])]).sum(0, keepdim=True)

        with self.assertRaisesRegex(RuntimeError, "NestedTensor only allows reduction of a single"):
            torch.nested_tensor([torch.tensor([[3, 4, 5]]), torch.tensor([[1, 2]])]).sum([0, 1], keepdim=True)

        with self.assertRaisesRegex(RuntimeError, "NestedTensor always requires keepdim=True for now."):
            torch.nested_tensor([torch.tensor([3, 4, 5]), torch.tensor([1, 2])]).sum(-1)


    @dtypes(torch.float, torch.float16)
    @skipMeta
    @torch.inference_mode()
