Validate Sparse Compressed tensor arguments
Pull Request resolved: pytorch#75946

Approved by: https://github.com/cpuhrsch
pearu authored and pytorchmergebot committed Apr 18, 2022
1 parent 5d0450d commit e9791cd
Showing 7 changed files with 178 additions and 88 deletions.
8 changes: 8 additions & 0 deletions aten/src/ATen/SparseCsrTensorUtils.h
@@ -95,5 +95,13 @@ inline int columnDimension(Layout layout, IntArrayRef size) {
   return size.size() - (isCompressedColumn(layout) ? 2 : 1);
 }
 
+inline int compressedDimension(Layout layout, IntArrayRef size) {
+  return size.size() - (isCompressedRow(layout) ? 2 : 1);
+}
+
+inline int plainDimension(Layout layout, IntArrayRef size) {
+  return size.size() - (isCompressedRow(layout) ? 1 : 2);
+}
+
 } // namespace sparse_csr
 } // namespace at
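For intuition, here is a minimal Python sketch of what the two new helpers compute (illustrative only; the layout strings and function names below are placeholders, not PyTorch API). Compressed-row layouts (CSR/BSR) compress the next-to-last dimension of size, while compressed-column layouts (CSC/BSC) compress the last one, and the "plain" dimension is whichever one remains.

# Sketch of compressedDimension/plainDimension semantics (placeholder names).
def compressed_dimension(layout, size):
    return len(size) - (2 if layout in ("csr", "bsr") else 1)

def plain_dimension(layout, size):
    return len(size) - (1 if layout in ("csr", "bsr") else 2)

size = (2, 3, 4)  # one batch dimension, 3 rows, 4 columns
assert compressed_dimension("csr", size) == 1  # rows are compressed
assert plain_dimension("csr", size) == 2       # columns stay plain
assert compressed_dimension("csc", size) == 2  # columns are compressed
assert plain_dimension("csc", size) == 1       # rows stay plain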
2 changes: 2 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -5376,6 +5376,8 @@
 
 - func: _validate_sparse_csr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size) -> ()
 
+- func: _validate_sparse_compressed_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, Layout layout) -> ()
+
 - func: _sparse_coo_tensor_with_dims(int sparse_dim, int dense_dim, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
   dispatch:
     SparseCPU, SparseCUDA: new_with_dims_sparse
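A minimal Python sketch of exercising the new operator, assuming a PyTorch build that includes this change; it raises a RuntimeError on invalid inputs and returns nothing otherwise:

import torch

crow_indices = torch.tensor([0, 2, 4])
col_indices = torch.tensor([0, 1, 0, 1])
values = torch.tensor([1., 2., 3., 4.])

# Validates the invariants for the given layout; no return value on success.
torch._validate_sparse_compressed_tensor_args(
    crow_indices, col_indices, values, (2, 2), torch.sparse_csr)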
214 changes: 131 additions & 83 deletions aten/src/ATen/native/sparse/SparseCsrTensor.cpp
@@ -17,6 +17,7 @@
 #else
 #include <ATen/ops/_nnz_native.h>
 #include <ATen/ops/_sparse_csr_tensor_unsafe_native.h>
+#include <ATen/ops/_validate_sparse_compressed_tensor_args_native.h>
 #include <ATen/ops/_validate_sparse_csr_tensor_args_native.h>
 #include <ATen/ops/clone_native.h>
 #include <ATen/ops/col_indices_native.h>
@@ -41,122 +42,160 @@ namespace {
 
 } // end anonymous namespace
 
-void _validate_sparse_csr_tensor_args(const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, IntArrayRef size) {
+void _validate_sparse_compressed_tensor_args_worker(const Tensor& compressed_indices, const Tensor& plain_indices, const Tensor& values, const IntArrayRef size, const Layout& layout) {
+
+  // Layout must be Sparse Compressed
+  AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout, "validate_sparse_compressed_tensor_args", [&]{});
+
+  const std::string layout_name = layoutToString(layout, /*upper=*/ true);
+  const std::string compressed_indices_name = compressedIndicesName(layout);
+  const std::string plain_indices_name = plainIndicesName(layout);
+
   // Layout Invariants
   TORCH_CHECK(
-      col_indices.layout() == kStrided && col_indices.is_contiguous(),
-      "expected col_indices to be a strided and contiguous tensor");
+      plain_indices.layout() == kStrided && plain_indices.is_contiguous(),
+      "expected ", plain_indices_name, " to be a strided and contiguous tensor");
 
   TORCH_CHECK(
-      crow_indices.layout() == kStrided && crow_indices.is_contiguous(),
-      "expected crow_indices to be a strided and contiguous tensor");
+      compressed_indices.layout() == kStrided && compressed_indices.is_contiguous(),
+      "expected ", compressed_indices_name ," to be a strided and contiguous tensor");
 
   TORCH_CHECK(
       values.layout() == kStrided && values.is_contiguous(),
       "expected values to be a strided and contiguous tensor");
 
   // Shape and Strides invariants
   TORCH_CHECK(
-      size.size() >= 2,
-      "size of a batched CSR tensor must have length >= 2, but got: ",
-      size.size());
+      size.size() >= 2,
+      "size of a batched ", layout_name, " tensor must have length >= 2, but got: ",
+      size.size());
   TORCH_CHECK(
-      crow_indices.dim() >= 1,
-      "crow_indices must have dim >= 1 but got crow_indices.dim() = ",
-      crow_indices.dim());
+      compressed_indices.dim() >= 1,
+      compressed_indices_name, " must have dim >= 1 but got ", compressed_indices_name, ".dim() = ",
+      compressed_indices.dim());
   TORCH_CHECK(
-      col_indices.dim() >= 1,
-      "col_indices must have dim >= 1 but got col_indices.dim() = ",
-      col_indices.dim());
+      plain_indices.dim() >= 1,
+      plain_indices_name, " must have dim >= 1 but got ", plain_indices_name, ".dim() = ",
+      plain_indices.dim());
   TORCH_CHECK(
-      values.dim() >= 1,
-      "values must have dim >= 1 but got values.dim() = ",
-      values.dim());
+      values.dim() >= 1,
+      "values must have dim >= 1 but got values.dim() = ",
+      values.dim());
 
   TORCH_CHECK(
-      crow_indices.dim() == col_indices.dim(),
-      "Number of dimensions of crow_indices and col_indices must be the same.");
-  TORCH_CHECK(
-      crow_indices.dim() == values.dim(),
-      "Number of dimensions of indices and values must be the same.");
+      compressed_indices.dim() == plain_indices.dim(),
+      "number of dimensions of ", compressed_indices_name, " and ", plain_indices_name, " must be the same.");
+
+  AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(
+      layout, "validate_sparse_compressed_tensor_args",
+      [&] {
+        TORCH_CHECK(
+            compressed_indices.dim() == values.dim(),
+            "number of dimensions of indices and values must be the same.");
+      },
+      [&] {
+        TORCH_CHECK(
+            compressed_indices.dim() + 2 == values.dim(),
+            "number of dimensions of indices must be two less than the number of dimensions of the values.");
+      });
 
   TORCH_CHECK(
-      static_cast<size_t>(crow_indices.dim()) == size.size() - 1,
-      "Number of dimensions of indices must be one less than the number of dimensions of the provided size.");
+      static_cast<size_t>(compressed_indices.dim()) == size.size() - 1,
+      "number of dimensions of indices must be one less than the number of dimensions of the provided size.");
+
+  int block_ndim = AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(layout, "validate_sparse_compressed_tensor_args", [&]{ return 0; }, [&]{ return 2; });
+  IntArrayRef block_size = values.sizes().slice(values.dim() - block_ndim, block_ndim);
+  int64_t numel_per_block = AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(layout, "validate_sparse_compressed_tensor_args",
+      [&]() -> int64_t { return 1; }, [&]() -> int64_t { return block_size[0] * block_size[1]; });
+  int compressed_dim = compressedDimension(layout, size);
+  int plain_dim = plainDimension(layout, size);
 
   // All batch sizes must be the same
   auto batch_size = size.slice(0, size.size() - 2);
-  auto crow_indices_batch_size = crow_indices.sizes().slice(0, crow_indices.dim() - 1);
-  auto col_indices_batch_size = col_indices.sizes().slice(0, col_indices.dim() - 1);
-  auto values_batch_size = values.sizes().slice(0, values.dim() - 1);
+  auto compressed_indices_batch_size = compressed_indices.sizes().slice(0, compressed_indices.dim() - 1);
+  auto plain_indices_batch_size = plain_indices.sizes().slice(0, plain_indices.dim() - 1);
+  auto values_batch_size = values.sizes().slice(0, values.dim() - 1 - block_ndim);
   TORCH_CHECK(
-      batch_size == crow_indices_batch_size &&
-      batch_size == col_indices_batch_size &&
+      batch_size == compressed_indices_batch_size &&
+      batch_size == plain_indices_batch_size &&
       batch_size == values_batch_size,
-      "All batch dimensions of the provided size, indices, and values must be the same.");
+      "all batch dimensions of the provided size (", batch_size, "), indices (",
+      compressed_indices_batch_size,", ", plain_indices_batch_size, "), and values (",
+      values_batch_size,") must be the same.");
 
-  // Note, this check also enforces `crow_indices.size(-1) >= 1`
-  TORCH_CHECK(
-      crow_indices.size(-1) == (size[size.size() - 2] + 1),
-      "crow_indices.size(-1) must be equal to size[-2] + 1 (that is ", size[size.size() - 2] + 1, "), but got: ",
-      crow_indices.size(-1));
-  TORCH_CHECK(
-      col_indices.numel() == values.numel(),
-      "col_indices and values must have the same number of elements, but got col_indices.numel(): ",
-      col_indices.numel(),
-      ", values.numel(): ",
-      values.numel());
-
-  // Indices invariants
-  AT_DISPATCH_INDEX_TYPES(crow_indices.scalar_type(), "csr_construct_check", [&] {
-    Tensor crow_indices_cpu = crow_indices.to(kCPU);
-    auto crow_indices_data_ptr = crow_indices_cpu.data_ptr<index_t>();
-    auto batch_stride = crow_indices_cpu.dim() >= 2 ? crow_indices_cpu.stride(-2) : 0;
-    for (const auto batch_id : c10::irange(batchCount(crow_indices_cpu))) {
-      TORCH_CHECK(
-          crow_indices_data_ptr[batch_id*batch_stride] == 0,
-          "(Batch element ", batch_id, ") ",
-          ": 0th value of crow_indices must be 0, but it is ", crow_indices_data_ptr[batch_id*batch_stride]);
-      TORCH_CHECK(
-          crow_indices_data_ptr[batch_id*batch_stride + crow_indices.size(-1) - 1] == col_indices.size(-1),
-          "(Batch element ", batch_id, ") ",
-          "last value of crow_indices should be equal to the length of col_indices.");
-      for (int i = 1; i <= size[size.size() - 2]; i++) {
-        TORCH_CHECK(
-            crow_indices_data_ptr[batch_id*batch_stride + i - 1] <= crow_indices_data_ptr[batch_id*batch_stride + i],
-            "(Batch element ", batch_id, ") ",
-            "at position i = ", i, ", the condition crow_indices[i - 1] <= crow_indices[i] fails");
-      }
-    }
-    if (col_indices.numel() > 0) {
-      TORCH_CHECK(0 <= col_indices.min().item<index_t>(), "col_indices.min() should be greater or equal to zero");
-      TORCH_CHECK(size[size.size() - 1] > col_indices.max().item<index_t>(), "size[-1] should be greater than col_indices.max()");
-    }
-  });
-
-  // CSR Type Invariants
-  auto crow_indices_type = crow_indices.scalar_type();
-  auto col_indices_type = col_indices.scalar_type();
+  // Note, this check also enforces `compressed_indices.size(-1) >= 1`
+  TORCH_CHECK(
+      compressed_indices.size(-1) == (size[compressed_dim] + 1),
+      compressed_indices_name, ".size(-1) must be equal to size[-", (size.size() - compressed_dim), "] + 1 (that is ",
+      size[compressed_dim] + 1, "), but got: ", compressed_indices.size(-1));
+
+  AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(layout, "validate_sparse_compressed_tensor_args",
+      [&] {
+        TORCH_CHECK(
+            plain_indices.numel() == values.numel(),
+            plain_indices_name, " and values must have the same number of elements, but got ", plain_indices_name, ".numel(): ",
+            plain_indices.numel(), ", values.numel(): ", values.numel());
+      },
+      [&] {
+        TORCH_CHECK(
+            plain_indices.numel() * numel_per_block == values.numel(),
+            "number of ", plain_indices_name, " elements must be the same as the number of blocks in values, but got ",
+            plain_indices_name, ".numel() * numel_per_block: ", plain_indices.numel() * numel_per_block,
+            ", values.numel(): ", values.numel(),", numel_per_block: ", numel_per_block);
+      });
+
+  // Indices invariants
+  AT_DISPATCH_INDEX_TYPES(compressed_indices.scalar_type(), "validate_sparse_compressed_tensor_args",
+      [&] {
+        Tensor compressed_indices_cpu = compressed_indices.to(kCPU);
+        auto compressed_indices_data_ptr = compressed_indices_cpu.data_ptr<index_t>();
+        auto batch_stride = compressed_indices_cpu.dim() >= 2 ? compressed_indices_cpu.stride(-2) : 0;
+
+        for (const auto batch_id : c10::irange(batchCount(compressed_indices_cpu))) {
+          TORCH_CHECK(
+              compressed_indices_data_ptr[batch_id*batch_stride] == 0,
+              "(Batch element ", batch_id, ") ",
+              ": 0th value of ", compressed_indices_name, " must be 0, but it is ", compressed_indices_data_ptr[batch_id*batch_stride]);
+          TORCH_CHECK(
+              compressed_indices_data_ptr[batch_id*batch_stride + compressed_indices.size(-1) - 1] == plain_indices.size(-1),
+              "(Batch element ", batch_id, ") ",
+              "last value of ", compressed_indices_name, " should be equal to the length of ", plain_indices_name, ".");
+          for (int i = 1; i <= size[size.size() - 2]; i++) {
+            TORCH_CHECK(
+                compressed_indices_data_ptr[batch_id*batch_stride + i - 1] <= compressed_indices_data_ptr[batch_id*batch_stride + i],
+                "(Batch element ", batch_id, ") ",
+                "at position i = ", i, ", the condition ", compressed_indices_name, "[i - 1] <= ", compressed_indices_name, "[i] fails");
+          }
+        }
+        if (plain_indices.numel() > 0) {
+          TORCH_CHECK(0 <= plain_indices.min().item<index_t>(), plain_indices_name, ".min() should be greater or equal to zero");
+          TORCH_CHECK(size[plain_dim] > plain_indices.max().item<index_t>(), "size[-", (size.size() - plain_dim),"] should be greater than ", plain_indices_name, ".max()");
+        }
+      });
+
+  // Type Invariants
+  auto compressed_indices_type = compressed_indices.scalar_type();
+  auto plain_indices_type = plain_indices.scalar_type();
   TORCH_CHECK(
-      crow_indices_type == col_indices_type,
-      "both crow_indices and col_indices should have the same type.");
+      compressed_indices_type == plain_indices_type,
+      "both ", compressed_indices_name, " and ", plain_indices_name, " should have the same type.");
   TORCH_CHECK(
-      crow_indices_type == kInt || crow_indices_type == kLong,
-      "crow_indices and col_indices must be an int32 or int64 type, but got: ",
-      crow_indices_type);
+      compressed_indices_type == kInt || compressed_indices_type == kLong,
+      compressed_indices_name, " and ", plain_indices_name, " must be an int32 or int64 type, but got: ",
+      compressed_indices_type);
 
-  // CSR Device Invariants
+  // Device Invariants
   TORCH_CHECK(
-      col_indices.get_device() == crow_indices.get_device(),
-      "crow_indices and col_indices devices (",
-      crow_indices.get_device(),
+      plain_indices.get_device() == compressed_indices.get_device(),
+      compressed_indices_name, " and ", plain_indices_name, " devices (",
+      compressed_indices.get_device(),
       ", ",
-      col_indices.get_device(),
+      plain_indices.get_device(),
       ") must match");
   TORCH_CHECK(
-      crow_indices.get_device() == values.get_device(),
-      "device of crow_indices (",
-      crow_indices.get_device(),
+      compressed_indices.get_device() == values.get_device(),
+      "device of ", compressed_indices_name, " (",
+      compressed_indices.get_device(),
       ") must match device of values (",
       values.get_device(),
       ")");
@@ -165,6 +204,15 @@ void _validate_sparse_csr_tensor_args(const Tensor& crow_indices, const Tensor&
       "device type of values (",
       values.device().type(),
       ") must be CPU or CUDA");
+
 }
 
+void _validate_sparse_compressed_tensor_args(const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, IntArrayRef size, Layout layout) {
+  _validate_sparse_compressed_tensor_args_worker(crow_indices, col_indices, values, size, layout);
+}
+
+void _validate_sparse_csr_tensor_args(const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, IntArrayRef size) {
+  _validate_sparse_compressed_tensor_args_worker(crow_indices, col_indices, values, size, kSparseCsr);
+}
+
 // Construction of CSR tensors.
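For illustration, a Python sketch of one of the validated invariants, assuming a build that includes this change: the 0th value of crow_indices must be 0, so the construction below is expected to raise.

import torch

try:
    torch.sparse_csr_tensor(torch.tensor([1, 2, 4]),  # bad: must start at 0
                            torch.tensor([0, 1, 0, 1]),
                            torch.tensor([1., 2., 3., 4.]),
                            (2, 2))
except RuntimeError as e:
    print(e)  # "... 0th value of crow_indices must be 0 ..."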
11 changes: 6 additions & 5 deletions test/test_sparse_csr.py
@@ -380,7 +380,7 @@ def test_factory_type_invariants_check(self, device):
                                     torch.tensor([1, 2, 3, 4]),
                                     device=device)
 
-        with self.assertRaisesRegex(RuntimeError, r"\"csr_construct_check\" not implemented for 'Short'"):
+        with self.assertRaisesRegex(RuntimeError, r"\"validate_sparse_compressed_tensor_args\" not implemented for 'Short'"):
             torch.sparse_csr_tensor(torch.tensor([0, 2, 4], dtype=torch.int16),
                                     torch.tensor([0, 1, 0, 1], dtype=torch.int16),
                                     torch.tensor([1, 2, 3, 4]),
@@ -445,22 +445,23 @@ def test_factory_shape_invariants_check(self, device):
 
 
         with self.assertRaisesRegex(RuntimeError,
-                                    r"Number of dimensions of crow_indices and col_indices must be the same"):
+                                    r"number of dimensions of crow_indices and col_indices must be the same"):
             torch.sparse_csr_tensor(crow_indices, col_indices.repeat(2, 1), values, size,
                                     device=device)
 
         with self.assertRaisesRegex(RuntimeError,
-                                    r"Number of dimensions of indices and values must be the same"):
+                                    r"number of dimensions of indices and values must be the same"):
             torch.sparse_csr_tensor(crow_indices, col_indices, values.repeat(2, 1), size,
                                     device=device)
 
         with self.assertRaisesRegex(RuntimeError,
-                                    r"Number of dimensions of indices must be one less"):
+                                    r"number of dimensions of indices must be one less"):
             torch.sparse_csr_tensor(crow_indices.repeat(2, 1), col_indices.repeat(2, 1), values.repeat(2, 1), size,
                                     device=device)
 
         with self.assertRaisesRegex(RuntimeError,
-                                    r"All batch dimensions of the provided size, indices, and values must be the same"):
+                                    r"all batch dimensions of the provided size \(\[2\]\), indices \(\[2\], \[3\]\),"
+                                    r" and values \(\[4\]\) must be the same"):
             torch.sparse_csr_tensor(crow_indices.repeat(2, 1), col_indices.repeat(3, 1), values.repeat(4, 1), (2, 2, 10),
                                     device=device)
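The new batch-dimension message now quotes the offending sizes. A sketch of the failing construction mirroring the test above (assuming a build with this change):

import torch

crow_indices = torch.tensor([0, 2, 4]).repeat(2, 1)    # batch dims [2]
col_indices = torch.tensor([0, 1, 0, 1]).repeat(3, 1)  # batch dims [3]
values = torch.tensor([1., 2., 3., 4.]).repeat(4, 1)   # batch dims [4]

# Expected: RuntimeError "all batch dimensions of the provided size ([2]),
# indices ([2], [3]), and values ([4]) must be the same."
torch.sparse_csr_tensor(crow_indices, col_indices, values, (2, 2, 10))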
2 changes: 2 additions & 0 deletions tools/codegen/api/python.py
@@ -1044,6 +1044,8 @@ def arg_parser_unpack_method(t: Type, has_default: bool) -> str:
             return 'toDouble'
         elif t.name == BaseTy.str:
             return 'stringView'
+        elif t.name == BaseTy.Layout:
+            return 'layout'
 
     elif isinstance(t, OptionalType):
         if str(t.elem) == 'Tensor':
28 changes: 28 additions & 0 deletions torch/csrc/utils/tensor_new.cpp
@@ -809,6 +809,34 @@ void _validate_sparse_coo_tensor_args(c10::DispatchKey dispatch_key, at::ScalarT
   at::native::_validate_sparse_coo_tensor_args(indices, values, r.intlist(2));
 }
 
+void _validate_sparse_compressed_tensor_args(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
+  auto options = dispatchKeyToTensorOptions(dispatch_key);
+  enum {
+    ARG_CROW_INDICES = 0,
+    ARG_COL_INDICES,
+    ARG_VALUES,
+    ARG_SIZE,
+    ARG_LAYOUT,
+    ARGS_COUNT
+  };
+
+  const std::string signature = "_validate_sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, IntArrayRef size, Layout layout)";
+  static PythonArgParser parser({signature});
+
+  ParsedArgs<ARGS_COUNT> parsed_args;
+  auto r = parser.parse(args, kwargs, parsed_args);
+  Tensor values = internal_new_from_data(
+      options, scalar_type, c10::nullopt, r.pyobject(ARG_VALUES),
+      /*copy_variables=*/false, /*copy_numpy=*/true, /*type_inference=*/true);
+  // See Note [Ensuring sparse values and indices match devices]
+  Tensor crow_indices = internal_new_from_data(
+      values.options(), kInt, c10::nullopt, r.pyobject(ARG_CROW_INDICES),
+      /*copy_variables=*/false, /*copy_numpy=*/true, /*type_inference=*/true);
+  Tensor col_indices = internal_new_from_data(
+      values.options(), kInt, c10::nullopt, r.pyobject(ARG_COL_INDICES),
+      /*copy_variables=*/false, /*copy_numpy=*/true, /*type_inference=*/true);
+  at::native::_validate_sparse_compressed_tensor_args(crow_indices, col_indices, values, r.intlist(ARG_SIZE), r.layout(ARG_LAYOUT));
+}
+
 void _validate_sparse_csr_tensor_args(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
   auto options = dispatchKeyToTensorOptions(dispatch_key);
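Because the binding routes every argument through internal_new_from_data, plain Python sequences should be accepted as well as tensors; a minimal sketch under that assumption (indices default to int32 in the binding, with the element type of values inferred):

import torch

# Lists converted to tensors by the binding itself; raises on invalid input.
torch._validate_sparse_compressed_tensor_args(
    [0, 2, 4], [0, 1, 0, 1], [1., 2., 3., 4.], (2, 2), torch.sparse_csr)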
1 change: 1 addition & 0 deletions torch/csrc/utils/tensor_new.h
@@ -33,6 +33,7 @@ at::Tensor _sparse_csr_tensor_unsafe_ctor(
     at::ScalarType scalar_type,
     PythonArgs& r);
 void _validate_sparse_csr_tensor_args(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
+void _validate_sparse_compressed_tensor_args(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
 at::Tensor tensor_ctor(
     c10::DispatchKey dispatch_key,
     at::ScalarType scalar_type,
