Set correct device id on efficientzerotensors (pytorch#71611)
Summary:
Pull Request resolved: pytorch#71611

Fixes pytorch#71160 pytorch#69925

Test Plan: Imported from OSS

Reviewed By: george-qi

Differential Revision: D33834916

Pulled By: anjali411

fbshipit-source-id: 11cec343e95e2ee188ab7576f26f64aa19317891
(cherry picked from commit f6e86f8)
anjali411 authored and pytorchmergebot committed Jan 30, 2022
1 parent 784bd92 commit a18cfb7
Showing 6 changed files with 45 additions and 49 deletions.
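As a quick illustration of the behavior this commit targets, here is a sketch based on the new test added in test/test_torch.py below; it assumes a CUDA build with at least one visible GPU:

# Illustrative sketch: the efficient zero tensor should report the same
# device, including the device index, as a regular tensor created the same way.
import torch

device = "cuda:0"
x = torch.zeros(3, device=device)
y = torch._efficientzerotensor(3, device=device)
assert x.device == y.device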
aten/src/ATen/native/TensorFactories.cpp (29 changes: 4 additions & 25 deletions)
@@ -427,27 +427,6 @@ Tensor& eye_out_cpu(int64_t n, int64_t m, Tensor& result) {
 
 namespace {
 
-// The ZeroTensor allocator ignores whatever allocation is requested and always
-// gives you nullptr
-struct ZeroTensorAllocator final : public at::Allocator {
-  ZeroTensorAllocator(at::Device device) : device_(device) {};
-  ~ZeroTensorAllocator() override = default;
-  static void deleter(void* const pointer) {
-    TORCH_INTERNAL_ASSERT(!pointer);
-  }
-  DataPtr allocate(const size_t nbytes) const override {
-    return {nullptr, nullptr, &deleter, device_};
-  }
-  DeleterFnPtr raw_deleter() const override {
-    return deleter;
-  }
-  at::Device device_;
-};
-
-at::Allocator* GetZeroTensorAllocator(ZeroTensorAllocator& zt) {
-  return &zt;
-}
-
 // Performs dtype inference for full
 TensorOptions infer_full_options(
     const Scalar& fill_value,
@@ -1076,11 +1055,11 @@ Tensor _efficientzerotensor(IntArrayRef size,
     c10::optional<Device> device,
     c10::optional<bool> pin_memory) {
   auto device_ = device_or_default(device);
-  auto allocator = ZeroTensorAllocator(device_);
+  auto allocator = at::native::ZeroTensorAllocator(device_);
   auto dtype_ = dtype_or_default(dtype);
-  constexpr auto zero_ks = at::DispatchKeySet(at::DispatchKey::ZeroTensor);
-  return at::detail::empty_generic(
-      size, &allocator, zero_ks, dtype_, c10::nullopt);
+  auto zero_ks = at::DispatchKeySet(c10::DispatchKey::CPU) | at::DispatchKeySet(c10::DispatchKey::ZeroTensor);
+  auto out = at::detail::empty_generic(size, &allocator, zero_ks, dtype_, c10::nullopt);
+  return out;
 }
 
 Tensor& zeros_out(IntArrayRef size, Tensor& result) {
aten/src/ATen/native/TensorFactories.h (17 changes: 17 additions & 0 deletions)
@@ -87,6 +87,23 @@ inline void check_supported_max_int_with_precision(int64_t n, const Tensor& tens
   }
 }
 
+// The ZeroTensor allocator ignores whatever allocation is requested and always
+// gives you nullptr
+struct ZeroTensorAllocator final : public at::Allocator {
+  ZeroTensorAllocator(at::Device device) : device_(device) {};
+  ~ZeroTensorAllocator() override = default;
+  static void deleter(void* const pointer) {
+    TORCH_INTERNAL_ASSERT(!pointer);
+  }
+  DataPtr allocate(const size_t nbytes) const override {
+    return {nullptr, nullptr, &deleter, device_};
+  }
+  DeleterFnPtr raw_deleter() const override {
+    return deleter;
+  }
+  at::Device device_;
+};
+
 using binary_fn = void (*)(TensorIterator&);
 
 DECLARE_DISPATCH(binary_fn, complex_stub);
aten/src/ATen/native/cuda/TensorFactories.cu (17 changes: 17 additions & 0 deletions)
@@ -40,6 +40,23 @@ Tensor empty_cuda(IntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::op
   return at::detail::empty_cuda(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
 }
 
+Tensor _efficientzerotensor_cuda(IntArrayRef size,
+    c10::optional<ScalarType> dtype,
+    c10::optional<Layout> layout,
+    c10::optional<Device> device,
+    c10::optional<bool> pin_memory) {
+  auto device_ = device_or_default(device);
+  if (!device_.has_index()) {
+    device_.set_index(at::cuda::current_device());
+  }
+  auto allocator = at::native::ZeroTensorAllocator(device_);
+  auto dtype_ = dtype_or_default(dtype);
+  auto zero_ks = at::DispatchKeySet(c10::DispatchKey::CUDA) | at::DispatchKeySet(c10::DispatchKey::ZeroTensor);
+  auto out = at::detail::empty_generic(size, &allocator, zero_ks, dtype_, c10::nullopt);
+  return out;
+}
+
+
 Tensor empty_strided_cuda(IntArrayRef size, IntArrayRef stride, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
   return at::detail::empty_strided_cuda(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
 }
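The CUDA variant above also fills in a missing device index from the current CUDA device. Roughly, from the Python side, the effect looks like the sketch below (assuming a CUDA build; the reported index follows torch.cuda.current_device()):

# Sketch: an index-less "cuda" device should resolve to the current CUDA device.
import torch

torch.cuda.set_device(0)                          # current device is now cuda:0
z = torch._efficientzerotensor(3, device="cuda")  # no index given
print(z.device)                                   # expected: cuda:0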
aten/src/ATen/native/native_functions.yaml (3 changes: 2 additions & 1 deletion)
@@ -4807,7 +4807,8 @@
 
 - func: _efficientzerotensor(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
   dispatch:
-    CompositeExplicitAutograd: _efficientzerotensor
+    CPU: _efficientzerotensor
+    CUDA: _efficientzerotensor_cuda
 
 - func: zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
test/test_torch.py (4 changes: 4 additions & 0 deletions)
@@ -5477,6 +5477,10 @@ def test_clamp(self, device, dtype):
         actual = x[..., :1].clamp(lb, ub)
         self.assertEqual(expect, actual)
 
+    def test_cuda_device_idx(self, device):
+        x = torch.zeros(3, device=device)
+        y = torch._efficientzerotensor(3, device=device)
+        self.assertEqual(x.device, y.device)
 
     # we implemented custom deallocation for subclasses, so it behooves
     # us to make sure all of these bits work. We'll use __del__ to
torch/testing/_internal/common_methods_invocations.py (24 changes: 1 addition & 23 deletions)
@@ -9069,9 +9069,6 @@ def ref_pairwise_distance(input1, input2):
            assert_autodiffed=True,
            rhs_make_tensor_kwargs=dict(exclude_zero=True),
            skips=(
-               # 69913: RuntimeError: CUDA error: an illegal memory access was encountered
-               DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_fwgrad_bwgrad',
-                            device_type='cuda', dtypes=[torch.double, torch.cdouble]),
                DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD',
                             device_type='cuda', dtypes=[torch.double, torch.cdouble]),
                DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_inplace_forward_mode_AD',
@@ -9088,9 +9085,6 @@ def ref_pairwise_distance(input1, input2):
            assert_autodiffed=True,
            rhs_make_tensor_kwargs=dict(exclude_zero=True),
            skips=(
-               # 69913: RuntimeError: CUDA error: an illegal memory access was encountered
-               DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_fwgrad_bwgrad',
-                            device_type='cuda', dtypes=[torch.double, torch.cdouble]),
                DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD',
                             device_type='cuda', dtypes=[torch.double, torch.cdouble]),
                DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_inplace_forward_mode_AD',
@@ -9107,9 +9101,6 @@ def ref_pairwise_distance(input1, input2):
            assert_autodiffed=True,
            rhs_make_tensor_kwargs=dict(exclude_zero=True),
            skips=(
-               # 69913: RuntimeError: CUDA error: an illegal memory access was encountered
-               DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_fwgrad_bwgrad',
-                            device_type='cuda', dtypes=[torch.double, torch.cdouble]),
                DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD',
                             device_type='cuda', dtypes=[torch.double, torch.cdouble]),
                DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_inplace_forward_mode_AD',
@@ -9690,11 +9681,6 @@ def ref_pairwise_distance(input1, input2):
                # RuntimeError:
                # Arguments for call are not valid.
                DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32, torch.complex64)),  # noqa: B950
-               # 69925: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
-               DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad', device_type='cuda'),
-               # (ROCm) Memory exception on virtual address 0x7f6f3deb7000, node id 4: Page not present
-               DecorateInfo(unittest.skip("Skipped! ROCm memory exception"), 'TestGradients', 'test_fn_fwgrad_bwgrad',
-                            device_type='cuda', dtypes=[torch.float64, torch.complex128], active_if=TEST_WITH_ROCM),
            ),
            supports_inplace_autograd=False,
            sample_inputs_func=sample_inputs_gradient),
@@ -14032,15 +14018,7 @@ def ref_pairwise_distance(input1, input2):
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
            supports_out=False,
-           sample_inputs_func=sample_cumulative_trapezoid,
-           skips=(
-               # Two failures:
-               # 1. (CUDA) RuntimeError: Expected all tensors to be on the same device, but found at
-               #    least two devices, cuda:0 and cpu!
-               # 2. (ROCm) Memory exception on virtual address 0x7f6a2216f000, node id 4: Page not present
-               DecorateInfo(unittest.skip("Skipped! ROCm memory exception"), 'TestGradients',
-                            'test_fn_fwgrad_bwgrad', device_type='cuda'),
-           )),
+           sample_inputs_func=sample_cumulative_trapezoid,),
     OpInfo('unsqueeze',
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
            supports_out=False,
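The skips removed above cover TestGradients forward-over-reverse (test_fn_fwgrad_bwgrad) cases that previously failed on CUDA with mixed cuda:0/cpu device errors (pytorch#69925). A minimal forward-mode AD sketch of the kind of CUDA computation those tests exercise, illustrative only and assuming a CUDA build (it is not the OpInfo test machinery itself):

# Forward-mode AD on CUDA; tangents and outputs should stay on cuda:0.
import torch
import torch.autograd.forward_ad as fwAD

primal = torch.randn(3, device="cuda", dtype=torch.double)
tangent = torch.randn(3, device="cuda", dtype=torch.double)

with fwAD.dual_level():
    dual = fwAD.make_dual(primal, tangent)
    out = dual * dual                  # any op with forward-mode support
    _, jvp = fwAD.unpack_dual(out)

print(jvp.device)                      # expected: cuda:0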
