Skip to content

Commit

Permalink
[composite compliance] cov, corrcoef (pytorch#82954)
Browse files Browse the repository at this point in the history
Ref: pytorch#69991
Pull Request resolved: pytorch#82954
Approved by: https://github.com/zou3519
  • Loading branch information
kshitij12345 authored and pytorchmergebot committed Aug 26, 2022
1 parent cddf96c commit 65ea3d0
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 15 deletions.
12 changes: 12 additions & 0 deletions aten/src/ATen/TensorSubclassLikeUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,16 @@ inline bool areAnyOptionalTensorSubclassLike(
});
}

// Helper function to deal testing truthfulness of a scalar tensor
// in a Composite Compliant manner.
// NOTE: This function expects a scalar tensor of boolean dtype.
// Eg.
// Non-Composite Compliant Pattern : (t == 0).all().item<bool>()
// Composite Compliant Patter : is_salar_tensor_true((t == 0).all())
inline bool is_scalar_tensor_true(const Tensor& t) {
TORCH_INTERNAL_ASSERT(t.dim() == 0)
TORCH_INTERNAL_ASSERT(t.scalar_type() == kBool)
return at::equal(t, t.new_ones({}, t.options()));
}

} // namespace at
9 changes: 5 additions & 4 deletions aten/src/ATen/native/Correlation.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/TensorSubclassLikeUtils.h>

namespace at {
namespace native {
Expand Down Expand Up @@ -47,7 +48,7 @@ Tensor cov(
" != ",
num_observations);
TORCH_CHECK(
num_observations == 0 || w.min().ge(0).item<bool>(),
num_observations == 0 || at::is_scalar_tensor_true(w.min().ge(0)),
"cov(): fweights cannot be negative");
}

Expand All @@ -70,7 +71,7 @@ Tensor cov(
" != ",
num_observations);
TORCH_CHECK(
num_observations == 0 || aw.min().ge(0).item<bool>(),
num_observations == 0 || at::is_scalar_tensor_true(aw.min().ge(0)),
"cov(): aweights cannot be negative");
w = w.defined() ? w * aw : aw;
}
Expand All @@ -81,7 +82,7 @@ Tensor cov(
: at::scalar_tensor(num_observations, in.options().dtype(kLong));

TORCH_CHECK(
!w.defined() || w_sum.ne(0).item<bool>(),
!w.defined() || at::is_scalar_tensor_true(w_sum.ne(0)),
"cov(): weights sum to zero, can't be normalized");

const auto avg = (w.defined() ? in * w : in).sum(OBSERVATIONS_DIM) / w_sum;
Expand All @@ -95,7 +96,7 @@ Tensor cov(
norm_factor = w_sum - correction;
}

if (norm_factor.le(0).item<bool>()) {
if (at::is_scalar_tensor_true(norm_factor.le(0))) {
TORCH_WARN("cov(): degrees of freedom is <= 0");
norm_factor.zero_();
}
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/Sorting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ Tensor quantile_compute(
// synchronizing an accelerator with the CPU
if (self.device().is_cpu()) {
auto all_q_in_range = q.ge(0).logical_and_(q.le(1)).all();
TORCH_CHECK(at::equal(all_q_in_range, all_q_in_range.new_ones({})),
TORCH_CHECK(at::is_scalar_tensor_true(all_q_in_range),
"quantile() q values must be in the range [0, 1]");
}

Expand Down
3 changes: 2 additions & 1 deletion aten/src/ATen/native/SpectralOps.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <ATen/ATen.h>
#include <ATen/Config.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/native/SpectralOpsUtils.h>
#include <ATen/native/TensorIterator.h>
Expand Down Expand Up @@ -1101,7 +1102,7 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> ho
y = y.slice(2, start, end, 1);
window_envelop = window_envelop.slice(2, start, end, 1);
const auto window_envelop_lowest = window_envelop.abs().min().lt(1e-11);
if (at::equal(window_envelop_lowest, window_envelop_lowest.new_ones({}))) {
if (at::is_scalar_tensor_true(window_envelop_lowest)) {
std::ostringstream ss;
REPR(ss) << "window overlap add min: " << window_envelop_lowest;
AT_ERROR(ss.str());
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/autograd/FunctionsManual.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3489,7 +3489,7 @@ Tensor eig_backward(
}
// No support for complex eigenvalues for real inputs yet.
TORCH_CHECK(
at::equal(is_imag_eigvals_zero, is_imag_eigvals_zero.new_ones({})),
at::is_scalar_tensor_true(is_imag_eigvals_zero),
"eig_backward: Backward calculation does not support complex eigenvalues for real inputs at the moment.");
} else {
// torch.eig returns 2d tensors for eigenvalues,
Expand Down
8 changes: 0 additions & 8 deletions torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8694,8 +8694,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
'TestSchemaCheckModeOpInfo',
'test_schema_correctness',
dtypes=(torch.complex64, torch.complex128)),
# Pre-existing condition (calls .item); needs to be fixed
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'),
),
supports_out=False),
UnaryUfuncInfo('cos',
Expand Down Expand Up @@ -8773,12 +8771,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
'TestSchemaCheckModeOpInfo',
'test_schema_correctness',
dtypes=(torch.complex64, torch.complex128)),
# Pre-existing condition (calls .item); needs to be fixed
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'),
# Pre-existing condition (calls .item); needs to be fixed
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
# Pre-existing condition (calls .item); needs to be fixed
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
# Float did not match double
DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_grad'),
# Jacobian mismatch
Expand Down

0 comments on commit 65ea3d0

Please sign in to comment.