From ad44670fa1ce2dad7e2cdc3f90d27668e88e9548 Mon Sep 17 00:00:00 2001
From: "Edward Z. Yang"
Date: Mon, 29 Aug 2022 06:08:43 -0700
Subject: [PATCH] Back out "Revert D38984222: Don't introduce new overload for SymInt (#83628)" (#84173)

Also Back out "Revert D39075159: [acc_tensor] Use SymIntArrayRef for overloaded empty.memory_format's signature"

Original commit changeset: dab4a9dba4fa
Original commit changeset: dcaf16c037a9

Original Phabricator Diff: D38984222
Original Phabricator Diff: D39075159

Also update Metal registrations for C++ registration changes.

Also update NNPI registration to account for tightened schema checking

Differential Revision: [D39084762](https://our.internmc.facebook.com/intern/diff/D39084762/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D39084762/)!

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84173
Approved by: https://github.com/Krovatkin
---
 .github/ci_commit_pins/xla.txt | 2 +-
 aten/src/ATen/BatchingRegistrations.cpp | 16 +--
 aten/src/ATen/FunctionalInverses.cpp | 19 +--
 aten/src/ATen/core/NamedRegistrations.cpp | 1 -
 .../impl/make_boxed_from_unboxed_functor.h | 9 +-
 aten/src/ATen/core/custom_class.cpp | 1 +
 aten/src/ATen/core/dispatch/OperatorEntry.cpp | 6 +-
 aten/src/ATen/core/dynamic_type.cpp | 4 -
 aten/src/ATen/core/dynamic_type.h | 3 +-
 aten/src/ATen/core/function_schema.cpp | 16 +++
 aten/src/ATen/core/function_schema.h | 6 +-
 aten/src/ATen/core/jit_type.h | 134 ++++++++++++------
 .../core/op_registration/infer_schema.cpp | 2 +-
 .../ATen/core/op_registration/infer_schema.h | 8 +-
 .../src/ATen/native/MathBitFallThroughLists.h | 1 -
 aten/src/ATen/native/MetaTensor.cpp | 12 --
 aten/src/ATen/native/SummaryOps.cpp | 6 +-
 aten/src/ATen/native/TensorFactories.cpp | 44 +++---
 aten/src/ATen/native/TensorShape.cpp | 47 +++---
 aten/src/ATen/native/cuda/SummaryOps.cu | 12 +-
 aten/src/ATen/native/cuda/TensorFactories.cu | 4 -
 aten/src/ATen/native/cudnn/ConvShared.cpp | 4 +-
 aten/src/ATen/native/metal/MetalAten.mm | 3 +-
 .../src/ATen/native/metal/ops/MetalReshape.mm | 5 +-
 aten/src/ATen/native/miopen/Conv_miopen.cpp | 6 +-
 .../ATen/native/mkldnn/TensorFactories.cpp | 4 -
 aten/src/ATen/native/mps/TensorFactory.cpp | 11 --
 aten/src/ATen/native/native_functions.yaml | 104 +++----------
 .../ATen/native/quantized/TensorFactories.cpp | 10 --
 .../ATen/native/sparse/SparseCsrTensor.cpp | 10 --
 aten/src/ATen/native/sparse/SparseTensor.cpp | 11 --
 .../ATen/native/sparse/SparseTensorMath.cpp | 2 +-
 aten/src/ATen/native/ts_native_functions.yaml | 2 -
 aten/src/ATen/native/vulkan/ops/Factory.cpp | 10 +-
 aten/src/ATen/native/vulkan/ops/Shape.cpp | 3 +-
 .../templates/CompositeViewCopyKernels.cpp | 2 -
 aten/src/ATen/test/ExclusivelyOwned_test.cpp | 2 +-
 aten/src/ATen/test/MaybeOwned_test.cpp | 2 +-
 aten/src/ATen/test/extension_backend_test.cpp | 4 +-
 aten/src/ATen/test/math_kernel_test.cpp | 2 +-
 .../functorch/csrc/BatchRulesFactory.cpp | 39 +++--
 functorch/functorch/csrc/BatchRulesViews.cpp | 18 ++-
 test/cpp/jit/test_custom_class.cpp | 14 ++
 .../jit/test_custom_class_registrations.cpp | 10 ++
 .../cpp/jit/test_custom_class_registrations.h | 5 +
 test/cpp/lazy/test_lazy_ops.cpp | 29 ----
 .../open_registration_extension.cpp | 3 +-
 test/cpp_extensions/ort_extension.cpp | 10 +-
 .../check_forward_backward_compatibility.py | 12 ++
 test/test_decomp.py | 3 +-
 test/test_dynamic_shapes.py | 8 +-
 test/test_meta.py | 1 -
 test/test_nn.py | 1 +
 test/test_profiler_tree.py | 22 ++-
 test/test_proxy_tensor.py | 2 +-
 tools/autograd/derivatives.yaml | 18 +--
 tools/autograd/gen_inplace_or_view_type.py | 5 +-
 tools/autograd/gen_python_functions.py | 68 ++++++---
 tools/autograd/gen_trace_type.py | 6 +-
 tools/autograd/gen_variable_factories.py | 58 ++++----
 tools/autograd/gen_variable_type.py | 13 +-
 tools/autograd/load_derivatives.py | 6 +-
 tools/test/test_codegen.py | 1 +
 torch/_subclasses/fake_tensor.py | 2 +-
 torch/csrc/jit/codegen/cuda/interface.cpp | 3 +-
 .../csrc/jit/frontend/schema_type_parser.cpp | 3 +-
 torch/csrc/jit/python/pybind_utils.cpp | 44 ++++--
 torch/csrc/jit/python/pybind_utils.h | 2 +-
 torch/csrc/jit/runtime/static/ops.cpp | 9 +-
 torch/csrc/lazy/core/ir_builder.h | 17 +++
 .../lazy/ts_backend/ts_native_functions.cpp | 36 ++---
 torchgen/api/autograd.py | 4 +-
 torchgen/api/cpp.py | 82 ++++++++---
 torchgen/api/dispatcher.py | 3 +-
 torchgen/api/lazy.py | 38 ++++-
 torchgen/api/native.py | 32 +++--
 torchgen/api/python.py | 79 +++++++----
 torchgen/api/structured.py | 7 +-
 torchgen/api/translate.py | 10 +-
 torchgen/api/types.py | 73 +++++++---
 torchgen/api/ufunc.py | 7 +-
 torchgen/api/unboxing.py | 5 +-
 torchgen/dest/lazy_ir.py | 22 ++-
 torchgen/dest/register_dispatch_key.py | 43 +++---
 torchgen/gen.py | 61 ++++----
 torchgen/gen_backend_stubs.py | 1 +
 torchgen/gen_functionalization_type.py | 44 +-----
 torchgen/model.py | 77 ++--------
 torchgen/static_runtime/generator.py | 7 +-
 89 files changed, 864 insertions(+), 749 deletions(-)

diff --git a/.github/ci_commit_pins/xla.txt b/.github/ci_commit_pins/xla.txt
index 170afa2afb3c5b..e54b44698a6c86 100644
--- a/.github/ci_commit_pins/xla.txt
+++ b/.github/ci_commit_pins/xla.txt
@@ -1 +1 @@
-9b2f7929c2dae841888a836449c25b04c8cf4045
+95eedc33fb48c2ba72f5efa45daa4941cb069864
diff --git a/aten/src/ATen/BatchingRegistrations.cpp b/aten/src/ATen/BatchingRegistrations.cpp
index 28de70636700cf..fab2c9e607625c 100644
--- a/aten/src/ATen/BatchingRegistrations.cpp
+++ b/aten/src/ATen/BatchingRegistrations.cpp
@@ -186,7 +186,8 @@ Tensor expand_batching_rule(const Tensor& self, IntArrayRef size, bool implicit)
 }
 
 Tensor expand_symint_batching_rule(const Tensor& self, SymIntArrayRef psize, bool implicit) {
-  return self.expand(asIntArrayRefSlow(psize), implicit);
+  // TODO: properly support this
+  return expand_batching_rule(self, asIntArrayRefSlow(psize), implicit);
 }
 
 std::vector chunk_batching_rule(const Tensor& self, int64_t chunks, int64_t dim) {
@@ -469,7 +470,8 @@ Tensor view_batching_rule(const Tensor& self, IntArrayRef size) {
 }
 
 Tensor view_symint_batching_rule(const Tensor& self, c10::SymIntArrayRef size) {
-  return self.view(asIntArrayRefSlow(size));
+  // TODO: properly support this
+  return view_batching_rule(self, asIntArrayRefSlow(size));
 }
 
 Tensor view_as_complex_batching_rule(const Tensor& self) {
@@ -1009,6 +1011,7 @@ Tensor new_empty_symint_batching_rule(
     c10::optional layout,
     c10::optional device,
     c10::optional pin_memory) {
+  // TODO: properly support this
   return new_empty_batching_rule(self, asIntArrayRefSlow(size), dtype, layout, device, pin_memory);
 }
 
@@ -1109,8 +1112,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
   m.impl("tensor_split.sections", tensor_split_sections_batching_rule);
   m.impl("tensor_split.indices", tensor_split_indices_batching_rule);
   m.impl("diagonal", diagonal_batching_rule);
-  m.impl("expand", expand_batching_rule);
-  m.impl("expand.SymInt", expand_symint_batching_rule);
+  m.impl("expand", expand_symint_batching_rule);
   m.impl("expand_as", native::expand_as); // composite wrt autograd
   m.impl("movedim.intlist", movedim_batching_rule);
   m.impl("movedim.int", static_cast(native::movedim)); // composite wrt autograd
@@ -1138,8 +1140,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
   m.impl("unbind.int", unbind_batching_rule);
   m.impl("unfold", unfold_batching_rule);
   m.impl("unsqueeze", unsqueeze_batching_rule);
-  m.impl("view", view_batching_rule);
-  m.impl("view.SymInt", view_symint_batching_rule);
+  m.impl("view", view_symint_batching_rule);
   m.impl("view_as", native::view_as); // composite wrt autograd
 
   // clamp operations
@@ -1277,8 +1278,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
   m.impl("diagonal_backward", diagonal_backward_batching_rule);
 
   // Tensor.new_* operators
-  m.impl("new_empty", new_empty_batching_rule);
-  m.impl("new_empty.SymInt", new_empty_symint_batching_rule);
+  m.impl("new_empty", new_empty_symint_batching_rule);
   m.impl("new_empty_strided", new_empty_strided_batching_rule);
   m.impl("new_zeros", new_zeros_batching_rule);
 
diff --git a/aten/src/ATen/FunctionalInverses.cpp b/aten/src/ATen/FunctionalInverses.cpp
index 471c74a73c9524..41c4e22a33deb6 100644
--- a/aten/src/ATen/FunctionalInverses.cpp
+++ b/aten/src/ATen/FunctionalInverses.cpp
@@ -137,12 +137,8 @@ Tensor FunctionalInverses::diagonal_copy_inverse(const Tensor& base, const Tenso
   return base.diagonal_scatter(mutated_view, offset, dim1, dim2);
 }
 
-Tensor FunctionalInverses::expand_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::IntArrayRef size, bool implicit) {
-  return at::sum_to(mutated_view, base.sizes(),/*always_return_non_view=*/!reapply_views);
-}
-
-Tensor FunctionalInverses::expand_copy_SymInt_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, c10::SymIntArrayRef size, bool implicit) {
-  return at::sum_to(mutated_view, c10::asIntArrayRefSlow(base.sym_sizes()),/*always_return_non_view=*/!reapply_views);
+Tensor FunctionalInverses::expand_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::SymIntArrayRef size, bool implicit) {
+  return at::sum_to(mutated_view, base.sym_sizes(),/*always_return_non_view=*/!reapply_views);
 }
 
 Tensor FunctionalInverses::permute_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::IntArrayRef dims) {
@@ -291,15 +287,7 @@ Tensor FunctionalInverses::unbind_copy_int_inverse(const Tensor& base, const Ten
   return base.select_scatter(mutated_view, dim, mutated_view_idx);
 }
 
-Tensor FunctionalInverses::view_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::IntArrayRef size) {
-  if (reapply_views) {
-    return mutated_view.view(base.sizes());
-  } else {
-    return at::view_copy(mutated_view, base.sizes());
-  }
-}
-
-Tensor FunctionalInverses::view_copy_SymInt_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, c10::SymIntArrayRef size) {
+Tensor FunctionalInverses::view_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::SymIntArrayRef size) {
   if (reapply_views) {
     return mutated_view.view_symint(base.sym_sizes());
   } else {
@@ -307,6 +295,7 @@ Tensor FunctionalInverses::view_copy_SymInt_inverse(const Tensor& base, const Te
   }
 }
 
+
 Tensor FunctionalInverses::view_copy_dtype_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::ScalarType dtype) {
   if (reapply_views) {
     return mutated_view.view(base.scalar_type());
diff --git a/aten/src/ATen/core/NamedRegistrations.cpp b/aten/src/ATen/core/NamedRegistrations.cpp
index
bb675939b27c67..b78a563b673b0c 100644 --- a/aten/src/ATen/core/NamedRegistrations.cpp +++ b/aten/src/ATen/core/NamedRegistrations.cpp @@ -179,7 +179,6 @@ TORCH_LIBRARY_IMPL(aten, Named, m) { m.impl("exp.out", CppFunction::makeFallthrough()); m.impl("exp_", CppFunction::makeFallthrough()); m.impl("expand", CppFunction::makeFallthrough()); - m.impl("expand.SymInt", CppFunction::makeFallthrough()); m.impl("expm1", CppFunction::makeFallthrough()); m.impl("expm1.out", CppFunction::makeFallthrough()); m.impl("expm1_", CppFunction::makeFallthrough()); diff --git a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h index 0a28330a0bfb5c..8c5ced3d462bd5 100644 --- a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h +++ b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h @@ -353,7 +353,14 @@ namespace impl { template struct ivalue_to_arg final { static std::vector call(IValue& v) { - return ivalue_to_arg, AllowDeprecatedTypes>::call(v); + if (v.isIntList()) { + std::vector r; + auto src = v.toIntList(); + std::transform(src.begin(), src.end(), std::back_inserter(r), [](int64_t i) { return c10::SymInt(i); }); + return r; + } else { + return ivalue_to_arg, AllowDeprecatedTypes>::call(v); + } } }; template diff --git a/aten/src/ATen/core/custom_class.cpp b/aten/src/ATen/core/custom_class.cpp index 2bba7e6df62fc7..d719dde6ea0cec 100644 --- a/aten/src/ATen/core/custom_class.cpp +++ b/aten/src/ATen/core/custom_class.cpp @@ -143,6 +143,7 @@ c10::FunctionSchema class_base::withNewArguments( new_args.emplace_back( default_arg.name_, old_arg.type(), + old_arg.real_type(), old_arg.N(), default_arg.value_); } diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.cpp b/aten/src/ATen/core/dispatch/OperatorEntry.cpp index 5c1c42bb62260d..e10de4d9f85b4d 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.cpp +++ b/aten/src/ATen/core/dispatch/OperatorEntry.cpp @@ -35,7 +35,11 @@ OperatorEntry::OperatorEntry(OperatorName&& operator_name) namespace { void checkSchema(const OperatorName& name, const FunctionSchema& from_def, const std::string& from_def_debug, const FunctionSchema& inferred, const std::string& inferred_debug) { - c10::optional schema_difference = findSchemaDifferences(from_def, inferred); + // TODO: figure out if we can just directly save real schema at def time + c10::optional schema_difference = findSchemaDifferences( + from_def.cloneWithRealTypes(), + inferred.cloneWithRealTypes() + ); if (schema_difference.has_value()) { TORCH_CHECK(false, "Inferred operator schema for a C++ kernel function doesn't match the expected function schema.\n" diff --git a/aten/src/ATen/core/dynamic_type.cpp b/aten/src/ATen/core/dynamic_type.cpp index 5920d7c05f1fbf..49dd593e38d3f1 100644 --- a/aten/src/ATen/core/dynamic_type.cpp +++ b/aten/src/ATen/core/dynamic_type.cpp @@ -231,8 +231,6 @@ TypePtr DynamicType::fallback() const { return BoolType::get(); case Tag::Int: return IntType::get(); - case Tag::SymInt: - return SymIntType::get(); case Tag::Float: return FloatType::get(); case Tag::Complex: @@ -326,8 +324,6 @@ DynamicType::Ptr IValue::TagType::get(const c10::IValue& v) { return DynamicTypeTrait::getBaseType(); case Tag::Int: return DynamicTypeTrait::getBaseType(); - case Tag::SymInt: - return DynamicTypeTrait::getBaseType(); case Tag::Bool: return DynamicTypeTrait::getBaseType(); case Tag::String: diff --git a/aten/src/ATen/core/dynamic_type.h b/aten/src/ATen/core/dynamic_type.h index 
a84644ddde0478..1f649c8217cbe2 100644 --- a/aten/src/ATen/core/dynamic_type.h +++ b/aten/src/ATen/core/dynamic_type.h @@ -16,7 +16,6 @@ constexpr DynamicTypeBits kDynamicAnyTypeBit = DYNAMIC_TYPE_BIT(30); constexpr DynamicTypeBits kDynamicNoneTypeBit = DYNAMIC_TYPE_BIT(1); constexpr DynamicTypeBits kDynamicIntTypeBit = DYNAMIC_TYPE_BIT(3); -constexpr DynamicTypeBits kDynamicSymIntTypeBit = DYNAMIC_TYPE_BIT(23); constexpr DynamicTypeBits kDynamicFloatTypeBit = DYNAMIC_TYPE_BIT(4); constexpr DynamicTypeBits kDynamicComplexTypeBit = DYNAMIC_TYPE_BIT(5); constexpr DynamicTypeBits kDynamicListTypeBit = DYNAMIC_TYPE_BIT(7); @@ -29,7 +28,6 @@ constexpr DynamicTypeBits kDynamicClassTypeBit = DYNAMIC_TYPE_BIT(10); _(Bool, DYNAMIC_TYPE_BIT(2), 1) \ _(Int, kDynamicIntTypeBit, 1) \ _(Float, kDynamicFloatTypeBit, 1) \ - _(SymInt, kDynamicSymIntTypeBit, 1) \ _(Complex, kDynamicComplexTypeBit, 1) \ _(Number, \ (kDynamicIntTypeBit | kDynamicFloatTypeBit | kDynamicComplexTypeBit), \ @@ -63,6 +61,7 @@ constexpr DynamicTypeBits kDynamicClassTypeBit = DYNAMIC_TYPE_BIT(10); #define FORALL_DYNAMIC_TYPES_FAKE(_) \ _(ScalarType, kDynamicIntTypeBit, 1) \ _(Layout, kDynamicIntTypeBit, 1) \ + _(SymInt, kDynamicIntTypeBit, 1) \ _(MemoryFormat, kDynamicIntTypeBit, 1) #define FORWARD_DECL_TYPE(NAME, _, __) struct NAME ## Type; diff --git a/aten/src/ATen/core/function_schema.cpp b/aten/src/ATen/core/function_schema.cpp index a3a10862178c8d..00a31224a48393 100644 --- a/aten/src/ATen/core/function_schema.cpp +++ b/aten/src/ATen/core/function_schema.cpp @@ -17,6 +17,22 @@ const std::vector& FunctionSchema::getCorrectList(SchemaArgType type) } } +FunctionSchema FunctionSchema::cloneWithRealTypes() const { + auto cloneWithRealTypes = [](const Argument& a) { + return a.cloneWithType(a.real_type()); + }; + std::vector new_arguments, new_returns; + std::transform(arguments().begin(), arguments().end(), std::back_inserter(new_arguments), cloneWithRealTypes); + std::transform(returns().begin(), returns().end(), std::back_inserter(new_returns), cloneWithRealTypes); + return FunctionSchema( + name(), + overload_name(), + std::move(new_arguments), + std::move(new_returns), + is_vararg(), + is_varret()); +} + bool FunctionSchema::canAliasTypeSetsAlias(const c10::optional &lhs, const c10::optional &rhs) const { if (!lhs || !rhs) { return false; diff --git a/aten/src/ATen/core/function_schema.h b/aten/src/ATen/core/function_schema.h index 16083820a1d818..bafc0d81032036 100644 --- a/aten/src/ATen/core/function_schema.h +++ b/aten/src/ATen/core/function_schema.h @@ -44,7 +44,7 @@ struct Argument { c10::optional alias_info = c10::nullopt) : name_(std::move(name)), type_(fake_type ? std::move(fake_type) : TensorType::get()), - real_type_(real_type ? std::move(real_type) : TensorType::get()), + real_type_(real_type ? std::move(real_type) : type_), N_(std::move(N)), default_value_(std::move(default_value)), alias_info_(alias_info ? std::make_unique(std::move(*alias_info)) : nullptr), @@ -88,6 +88,8 @@ struct Argument { const TypePtr& type() const { return type_; } + // if type() is non-null, this is guaranteed to be non-null (if no real + // type was provided, this takes on type()'s value) const TypePtr& real_type() const { return real_type_; } @@ -472,6 +474,8 @@ struct TORCH_API FunctionSchema { FunctionSchema cloneWithRemappedTypes( const std::function type_map) const; + FunctionSchema cloneWithRealTypes() const; + // Check that inputs have the correct types and appends any missing default // values. 
template diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index 50b27a0e8fd8b7..ce698761dad7eb 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -1738,6 +1738,13 @@ struct getTypePtr_ final { } }; +template +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return getTypePtr_::call(); + } +}; + template <> struct getTypePtr_ final { static decltype(auto) call() { @@ -1783,13 +1790,13 @@ struct getTypePtr_ final { }; template <> -struct getTypePtr_ final { +struct getMaybeFakeTypePtr_ final { static decltype(auto) call() { return SymIntType::get(); } }; template <> -struct getTypePtr_ final { +struct getMaybeFakeTypePtr_ final { static decltype(auto) call() { return IntType::get(); } @@ -1801,18 +1808,6 @@ struct getTypePtr_ final { } }; template <> -struct getTypePtr_ final { - static decltype(auto) call() { - return IntType::get(); - } -}; -template <> -struct getTypePtr_ final { - static decltype(auto) call() { - return IntType::get(); - } -}; -template <> struct getTypePtr_ final { static decltype(auto) call() { return BoolType::get(); @@ -1855,47 +1850,47 @@ struct getTypePtr_ final { return StringType::get(); } }; -template -struct getTypePtr_> final { +template +struct getMaybeFakeTypePtr_, fake> final { static const auto& call() { - static auto inner_type = getTypePtr_::call(); + static auto inner_type = getMaybeFakeTypePtr_::call(); // The "per vector" static singleton needs to live in a .cpp file, // otherwise we'll end up with one singleton instance per shared library. static auto type = ListType::get("vector", inner_type); return type; } }; -template -struct getTypePtr_> final { +template +struct getMaybeFakeTypePtr_, fake> final { static const auto& call() { - static auto inner_type = getTypePtr_::call(); + static auto inner_type = getMaybeFakeTypePtr_::call(); // The "per ArrayRef" static singleton needs to live in a .cpp file, // otherwise we'll end up with one singleton instance per shared library. static auto type = ListType::get("ArrayRef", inner_type); return type; } }; -template <> -struct getTypePtr_ final { +template +struct getMaybeFakeTypePtr_ final { static const auto& call() { - static auto type = ListType::create(getTypePtr_::call()); + static auto type = ListType::create(getMaybeFakeTypePtr_::call()); return type; } }; -template -struct getTypePtr_> final { +template +struct getMaybeFakeTypePtr_, fake> final { static const auto& call() { - static auto inner_type = getTypePtr_::call(); + static auto inner_type = getMaybeFakeTypePtr_::call(); // The "per List" static singleton needs to live in a .cpp file, // otherwise we'll end up with one singleton instance per shared library. static auto type = ListType::get("List", inner_type); return type; } }; -template -struct getTypePtr_> final { +template +struct getMaybeFakeTypePtr_, fake> final { static const auto& call() { - static auto inner_type = getTypePtr_::call(); + static auto inner_type = getMaybeFakeTypePtr_::call(); // The "per array" static singleton needs to live in a .cpp file, // otherwise we'll end up with one singleton instance per shared library. 
// (Concatenating the length onto the end of the string because we want a unique @@ -1904,22 +1899,22 @@ struct getTypePtr_> final { return type; } }; -template -struct getTypePtr_> final { +template +struct getMaybeFakeTypePtr_, fake> final { static const auto& call() { - static auto inner_key_type = getTypePtr_::call(); - static auto inner_val_type = getTypePtr_::call(); + static auto inner_key_type = getMaybeFakeTypePtr_::call(); + static auto inner_val_type = getMaybeFakeTypePtr_::call(); // The "per unordered_map" static singleton needs to live in a .cpp file, // otherwise we'll end up with one singleton instance per shared library. static auto type = DictType::get("unordered_map", inner_key_type, inner_val_type); return type; } }; -template -struct getTypePtr_> final { +template +struct getMaybeFakeTypePtr_, fake> final { static const auto& call() { - static auto inner_key_type = getTypePtr_::call(); - static auto inner_val_type = getTypePtr_::call(); + static auto inner_key_type = getMaybeFakeTypePtr_::call(); + static auto inner_val_type = getMaybeFakeTypePtr_::call(); // The "per Dict" static singleton needs to live in a .cpp file, // otherwise we'll end up with one singleton instance per shared library. static auto type = DictType::get("Dict", inner_key_type, inner_val_type); @@ -1927,10 +1922,10 @@ struct getTypePtr_> final { } }; -template -struct getTypePtr_> final { +template +struct getMaybeFakeTypePtr_, fake> final { static const auto& call() { - static auto inner_type = getTypePtr_::call(); + static auto inner_type = getMaybeFakeTypePtr_::call(); // The "per optional" static singleton needs to live in a .cpp file, // otherwise we'll end up with one singleton instance per shared library. static auto type = OptionalType::get(inner_type); @@ -1942,17 +1937,17 @@ struct getTypePtr_> final { template<> struct getTypePtr_ final { static const auto& call() { - static auto type = OptionalType::create(getTypePtr_::call()); + static auto type = OptionalType::create(getMaybeFakeTypePtr_::call()); return type; } }; -template -struct getTypePtr_> final { +template +struct getMaybeFakeTypePtr_, fake> final { static const auto& call() { static auto type = ([]() { std::vector contained_types = { - (getTypePtr_::call())... + (getMaybeFakeTypePtr_::call())... 
}; return TupleType::create(std::move(contained_types)); })(); @@ -1970,7 +1965,7 @@ template inline decltype(auto) getTypePtr() { // TODO: static_assert that a templated function exists, and throw a friendly // error message if not - return detail::getTypePtr_::call(); + return detail::getMaybeFakeTypePtr_::call(); } template @@ -1980,6 +1975,16 @@ inline TypePtr getTypePtrCopy() { return getTypePtr(); } +template +inline decltype(auto) getFakeTypePtr() { + return detail::getMaybeFakeTypePtr_::call(); +} + +template +inline TypePtr getFakeTypePtrCopy() { + return getFakeTypePtr(); +} + using TypeEnv = std::unordered_map; struct MatchTypeReturn { MatchTypeReturn(std::string reason) : reason_(std::move(reason)) {} @@ -2133,6 +2138,45 @@ static LayoutTypePtr get(); LayoutType() : EnumerationType() {} }; +namespace detail { +template <> +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return ScalarTypeType::get(); + } +}; +template <> +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return LayoutType::get(); + } +}; +template <> +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return MemoryFormatType::get(); + } +}; +template <> +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return IntType::get(); + } +}; +template <> +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return IntType::get(); + } +}; +template <> +struct getMaybeFakeTypePtr_ final { + static decltype(auto) call() { + return IntType::get(); + } +}; +} // namespace detail + // the common supertype of all lists, // List[T] <: AnyList for all T struct AnyListType; diff --git a/aten/src/ATen/core/op_registration/infer_schema.cpp b/aten/src/ATen/core/op_registration/infer_schema.cpp index df1925aba5ed1a..e9e93a2556e0b9 100644 --- a/aten/src/ATen/core/op_registration/infer_schema.cpp +++ b/aten/src/ATen/core/op_registration/infer_schema.cpp @@ -23,7 +23,7 @@ std::vector createArgumentVector(c10::ArrayRef args) { result.reserve(args.size()); for (const auto i : c10::irange(args.size())) { // Arguments are named "_" - result.emplace_back(fastToString(i), (*args[i].getTypeFn)()); + result.emplace_back(fastToString(i), (*args[i].getFakeTypeFn)(), (*args[i].getTypeFn)()); } return result; } diff --git a/aten/src/ATen/core/op_registration/infer_schema.h b/aten/src/ATen/core/op_registration/infer_schema.h index 7539cd59cac9b2..2938e2a8d564ed 100644 --- a/aten/src/ATen/core/op_registration/infer_schema.h +++ b/aten/src/ATen/core/op_registration/infer_schema.h @@ -22,8 +22,9 @@ namespace infer_schema { struct ArgumentDef final { using GetTypeFn = TypePtr(); GetTypeFn* getTypeFn; - constexpr ArgumentDef(): getTypeFn(nullptr) {} - explicit constexpr ArgumentDef(GetTypeFn *getTypeFn): getTypeFn(getTypeFn) {} + GetTypeFn* getFakeTypeFn; + constexpr ArgumentDef(): getTypeFn(nullptr), getFakeTypeFn(nullptr) {} + explicit constexpr ArgumentDef(GetTypeFn *getTypeFn, GetTypeFn *getFakeTypeFn): getTypeFn(getTypeFn), getFakeTypeFn(getFakeTypeFn) {} }; template @@ -52,7 +53,8 @@ constexpr std::array createArgumentVectorFromTypes(s checkStaticTypes(), // Create the return value - std::array{ArgumentDef(&getTypePtrCopy>)...} + std::array{ + ArgumentDef(&getTypePtrCopy>, &getFakeTypePtrCopy>)...} ); } diff --git a/aten/src/ATen/native/MathBitFallThroughLists.h b/aten/src/ATen/native/MathBitFallThroughLists.h index 025c25bcbe7b98..97b0854d82d0a2 100644 --- a/aten/src/ATen/native/MathBitFallThroughLists.h +++ 
b/aten/src/ATen/native/MathBitFallThroughLists.h @@ -54,7 +54,6 @@ namespace at { #define TENSOR_UTILITIES_AND_CONSTRUCTORS(m) \ m.impl("empty_like", torch::CppFunction::makeFallthrough()); \ m.impl("empty.memory_format", torch::CppFunction::makeFallthrough()); \ - m.impl("empty.SymInt", torch::CppFunction::makeFallthrough()); \ m.impl("empty.out", torch::CppFunction::makeFallthrough()); \ m.impl("empty_strided", torch::CppFunction::makeFallthrough()); \ m.impl("full_like", torch::CppFunction::makeFallthrough()); \ diff --git a/aten/src/ATen/native/MetaTensor.cpp b/aten/src/ATen/native/MetaTensor.cpp index 0b3bb3e04c7b94..a58b18c786e8c4 100644 --- a/aten/src/ATen/native/MetaTensor.cpp +++ b/aten/src/ATen/native/MetaTensor.cpp @@ -13,18 +13,6 @@ namespace at { namespace native { Tensor empty_meta( - IntArrayRef size, - c10::optional dtype_opt, - c10::optional layout_opt, - c10::optional device_opt, - c10::optional pin_memory_opt, - c10::optional memory_format_opt -) { - return at::detail::empty_meta( - size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt); -} - -Tensor empty_symint_meta( SymIntArrayRef size, c10::optional dtype_opt, c10::optional layout_opt, diff --git a/aten/src/ATen/native/SummaryOps.cpp b/aten/src/ATen/native/SummaryOps.cpp index cf86225460ea0a..fdbd2a9fc82291 100644 --- a/aten/src/ATen/native/SummaryOps.cpp +++ b/aten/src/ATen/native/SummaryOps.cpp @@ -20,7 +20,7 @@ Tensor _bincount_cpu_template( AT_ERROR("minlength should be >= 0"); } if (self.dim() == 1 && self.numel() == 0) { - return native::zeros({minlength}, kLong); + return at::zeros({minlength}, kLong); } if (self.dim() != 1 || *self.min().data_ptr() < 0) { AT_ERROR("bincount only supports 1-d non-negative integral inputs."); @@ -38,7 +38,7 @@ Tensor _bincount_cpu_template( const input_t* self_p = self.data_ptr(); if (has_weights) { - output = native::zeros( + output = at::zeros( {nbins}, optTypeMetaToScalarType(weights.options().dtype_opt()), weights.options().layout_opt(), @@ -50,7 +50,7 @@ Tensor _bincount_cpu_template( output_p[self_p[i]] += weights_p[i]; } } else { - output = native::zeros({nbins}, kLong); + output = at::zeros({nbins}, kLong); int64_t* output_p = output.data_ptr(); for (const auto i : c10::irange(self_size)) { output_p[self_p[i]] += 1L; diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp index c9cc522e06b837..6ccbbbac03a750 100644 --- a/aten/src/ATen/native/TensorFactories.cpp +++ b/aten/src/ATen/native/TensorFactories.cpp @@ -186,12 +186,7 @@ Tensor empty_cpu(IntArrayRef size, c10::optional dtype_opt, c10::opt return at::detail::empty_cpu(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt); } -Tensor empty_symint_cpu(c10::SymIntArrayRef size, c10::optional dtype_opt, c10::optional layout_opt, - c10::optional device_opt, c10::optional pin_memory_opt, c10::optional memory_format_opt) { - return at::native::empty_cpu(c10::asIntArrayRefSlow(size), dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt); -} - -Tensor empty( +Tensor empty_names( IntArrayRef size, c10::optional names, c10::optional dtype, @@ -219,9 +214,12 @@ Tensor empty_strided_cpu(IntArrayRef size, IntArrayRef stride, c10::optional optional_memory_format, Tensor& result) { + // TODO: support empty_out properly (I was forced to change this immediately + // with empty so that empty/empty.out had the same type signature) + auto size = c10::asIntArrayRefSlow(sym_size); // Preferably, this argument would not be accepted by _out, 
but the code // generator requires the out and non-out overloads to match exactly TORCH_CHECK( @@ -389,17 +387,6 @@ Tensor empty_like_quantized( } Tensor new_empty( - const Tensor& self, - IntArrayRef size, - c10::optional dtype_opt, - c10::optional layout_opt, - c10::optional device_opt, - c10::optional pin_memory_opt - ) { - return self.new_empty_symint(c10::SymIntArrayRef::fromIntArrayRef(size), dtype_opt, layout_opt, device_opt, pin_memory_opt); -} - -Tensor new_empty_symint( const Tensor& self, SymIntArrayRef size, c10::optional dtype_opt, @@ -1090,15 +1077,7 @@ Tensor triu_indices_cpu( // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ zeros ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Tensor zeros(IntArrayRef size, - c10::optional dtype, - c10::optional layout, - c10::optional device, - c10::optional pin_memory) { - return at::zeros_symint(c10::SymIntArrayRef::fromIntArrayRef(size), dtype, layout, device, pin_memory); -} - -Tensor zeros_symint(SymIntArrayRef size, +Tensor zeros(SymIntArrayRef size, c10::optional dtype, c10::optional layout, c10::optional device, @@ -1123,8 +1102,17 @@ Tensor _efficientzerotensor(IntArrayRef size, return out; } -Tensor& zeros_out(IntArrayRef size, Tensor& result) { +Tensor& zeros_sparse_out(IntArrayRef size, Tensor& result) { + result.sparse_resize_and_clear_(size, size.size(), 0.); + return result; +} + +Tensor& zeros_out(SymIntArrayRef sym_size, Tensor& result) { + auto size = c10::asIntArrayRefSlow(sym_size); if (result.is_sparse()) { + // TODO: I think this branch should be dead, but we don't have an easy + // way to cover all sparse kernels with zeros_sparse_out, so retain this + // for now result.sparse_resize_and_clear_(size, size.size(), 0.); return result; } else { diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 85fb9fb627efbe..6d332c3cc158d7 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -300,7 +300,7 @@ Tensor sparse_broadcast_to(const Tensor& self, IntArrayRef size) { new_values_size[0] = new_indices_size[1]; Tensor new_values = values.expand(broadcast_dense_sizes).repeat_interleave(nnz_factor, 0); - Tensor new_indices = at::native::new_empty(indices, new_indices_size); + Tensor new_indices = indices.new_empty(new_indices_size); if (broadcast_sizes.size()>0) { // ones(broadcast_sizes).nonzero() is equivalent to // product(map(arange, broadcast_sizes)) but avoids creating @@ -542,14 +542,14 @@ static Tensor cat_sparse_impl(TensorList tensors, int64_t dim) { zeros_sizes[0] = t._values().size(0); zeros_sizes[values_dim] = cumulative_size; cumulative_size += t._values().size(values_dim); - auto z1 = native::zeros( + auto z1 = at::zeros( zeros_sizes, optTypeMetaToScalarType(t._values().options().dtype_opt()), t._values().options().layout_opt(), t._values().options().device_opt(), t._values().options().pinned_memory_opt()); zeros_sizes[values_dim] = total_size - cumulative_size; - auto z2 = native::zeros( + auto z2 = at::zeros( zeros_sizes, optTypeMetaToScalarType(t._values().options().dtype_opt()), t._values().options().layout_opt(), @@ -843,12 +843,9 @@ Tensor diag_embed(const Tensor& self, int64_t offset, int64_t dim1_, int64_t dim return result; } -Tensor expand_symint(const Tensor& self, c10::SymIntArrayRef packed_size, bool implicit) { - auto size = asIntArrayRefSlow(packed_size); - return self.expand(size, implicit); -} - -Tensor expand(const Tensor& self, IntArrayRef size, bool /*unused*/) { +Tensor expand(const Tensor& self, c10::SymIntArrayRef sym_size, bool 
/*unused*/) { + // TODO: properly support SymInt expand + auto size = asIntArrayRefSlow(sym_size); TORCH_CHECK(size.size() >= (size_t)self.dim(), "expand(", self.toString(), "{", self.sizes(), "}, size=", size, "): the number of sizes provided (", size.size(), ") ", @@ -927,12 +924,9 @@ const Tensor &as_strided_(const Tensor& self, IntArrayRef size, IntArrayRef stri return self; } -Tensor narrow_copy_symint(const Tensor& self, int64_t dim, int64_t start, SymInt sym_length) { - return self.narrow_copy(dim, start, sym_length.expect_int()); -} - -Tensor narrow_copy_dense(const Tensor& self, int64_t dim, int64_t start, int64_t length) { - return self.narrow(dim, start, length).clone(at::MemoryFormat::Contiguous); +Tensor narrow_copy_dense(const Tensor& self, int64_t dim, SymInt start, SymInt length) { + // TODO: properly support SymInt narrow_copy + return self.narrow(dim, start.expect_int(), length.expect_int()).clone(at::MemoryFormat::Contiguous); } Tensor narrow_copy_dense_cpu(const Tensor& self, int64_t dim, int64_t start, int64_t length){ @@ -2782,7 +2776,7 @@ Tensor unsqueeze_sparse(Tensor const &self, int64_t dim) { if (dim <= sparse_dim) { auto new_indices = at::cat( {indices.narrow(0, 0, dim), - native::zeros( + at::zeros( {1, indices.size(1)}, kLong, indices.options().layout_opt(), @@ -3118,14 +3112,15 @@ Tensor adjoint(const Tensor &self) { return _adjoint(self, /*transpose=*/false, "adjoint()"); } -Tensor view(const Tensor& self, - IntArrayRef size) { - return view_impl(self, size); +Tensor view_meta(const Tensor& self, + at::SymIntArrayRef size) { + // TODO: Properly support SymInt view + return view_impl(self, c10::asIntArrayRefSlow(size)); } -Tensor view_symint(const Tensor& self, - c10::SymIntArrayRef size) { - return self.view(c10::asIntArrayRefSlow(size)); +Tensor view(const Tensor& self, + at::IntArrayRef size) { + return view_impl(self, size); } Tensor alias(const Tensor& self) { @@ -3505,8 +3500,8 @@ at::Tensor& expand_copy_SymInt_out(const at::Tensor & self, c10::SymIntArrayRef } -at::Tensor& expand_copy_out(const at::Tensor & self, at::IntArrayRef size, bool implicit, at::Tensor & out) { - auto tmp = self.expand(size, implicit); +at::Tensor& expand_copy_out(const at::Tensor & self, at::SymIntArrayRef size, bool implicit, at::Tensor & out) { + auto tmp = self.expand_symint(size, implicit); out.copy_(tmp); return out; } @@ -3661,8 +3656,8 @@ void unbind_copy_int_out(const at::Tensor & self, int64_t dim, at::TensorList o } -at::Tensor& view_copy_out(const at::Tensor & self, at::IntArrayRef size, at::Tensor & out) { - auto tmp = self.view(size); +at::Tensor& view_copy_out(const at::Tensor & self, at::SymIntArrayRef size, at::Tensor & out) { + auto tmp = self.view_symint(size); out.copy_(tmp); return out; } diff --git a/aten/src/ATen/native/cuda/SummaryOps.cu b/aten/src/ATen/native/cuda/SummaryOps.cu index 5476682d7c4d97..eb47556a6fe7a5 100644 --- a/aten/src/ATen/native/cuda/SummaryOps.cu +++ b/aten/src/ATen/native/cuda/SummaryOps.cu @@ -15,7 +15,7 @@ #include #include #include -#include +#include #endif namespace at { @@ -271,7 +271,7 @@ bool CUDA_tensor_histogram( detail::TensorInfo pInfo(nullptr, 0, {}, {}); Tensor partial_output; if (memType == CUDAHistogramMemoryType::MULTI_BLOCK) { - partial_output = native::zeros( + partial_output = at::zeros( {grid.x, nbins}, optTypeMetaToScalarType(a.options().dtype_opt()), a.options().layout_opt(), @@ -313,7 +313,7 @@ Tensor _bincount_cuda_template( AT_ERROR("minlength should be >= 0"); } if (self.dim() == 1 && self.numel() == 0) { 
- return native::zeros( + return at::zeros( {minlength}, kLong, c10::nullopt /* layout */, @@ -342,7 +342,7 @@ Tensor _bincount_cuda_template( // alloc output counter on GPU Tensor output; if (has_weights) { - output = native::zeros( + output = at::zeros( {nbins}, optTypeMetaToScalarType(weights.options().dtype_opt()), weights.options().layout_opt(), @@ -351,7 +351,7 @@ Tensor _bincount_cuda_template( cuda::CUDA_tensor_histogram( output, self, weights, nbins, minvalue, maxvalue); } else { - output = native::zeros( + output = at::zeros( {nbins}, kLong, c10::nullopt /* layout */, @@ -373,7 +373,7 @@ Tensor _histc_cuda_template( if (nbins <= 0) { AT_ERROR("bins must be > 0"); } - Tensor output = native::zeros( + Tensor output = at::zeros( {nbins}, self.scalar_type(), c10::nullopt /* layout */, diff --git a/aten/src/ATen/native/cuda/TensorFactories.cu b/aten/src/ATen/native/cuda/TensorFactories.cu index 03711b194a983d..e880b21d650dea 100644 --- a/aten/src/ATen/native/cuda/TensorFactories.cu +++ b/aten/src/ATen/native/cuda/TensorFactories.cu @@ -55,10 +55,6 @@ Tensor empty_cuda(IntArrayRef size, c10::optional dtype_opt, c10::op return at::detail::empty_cuda(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt); } -Tensor empty_symint_cuda(c10::SymIntArrayRef size, c10::optional dtype_opt, c10::optional layout_opt, c10::optional device_opt, c10::optional pin_memory_opt, c10::optional memory_format_opt) { - return at::native::empty_cuda(asIntArrayRefSlow(size), dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt); -} - Tensor _efficientzerotensor_cuda(IntArrayRef size, c10::optional dtype, c10::optional layout, diff --git a/aten/src/ATen/native/cudnn/ConvShared.cpp b/aten/src/ATen/native/cudnn/ConvShared.cpp index 9f921faf0320d3..afca16ea4af8a8 100644 --- a/aten/src/ATen/native/cudnn/ConvShared.cpp +++ b/aten/src/ATen/native/cudnn/ConvShared.cpp @@ -436,7 +436,7 @@ Tensor cudnn_convolution_relu( bool allow_tf32 = ctx.allowTF32CuDNN(); auto _bias = bias_t.has_value() ? bias_t.value() - : at::native::zeros( + : at::zeros( {output_t.size(1)}, optTypeMetaToScalarType(output_t.options().dtype_opt()), output_t.options().layout_opt(), @@ -514,7 +514,7 @@ Tensor cudnn_convolution_add_relu( auto _alpha = alpha.has_value() ? alpha.value().to() : 1.0; auto _bias = bias_t.has_value() ? 
bias_t.value() - : at::native::zeros( + : at::zeros( {output_t.size(1)}, optTypeMetaToScalarType(output_t.options().dtype_opt()), output_t.options().layout_opt(), diff --git a/aten/src/ATen/native/metal/MetalAten.mm b/aten/src/ATen/native/metal/MetalAten.mm index c1c34217c3740e..f100f473f05506 100644 --- a/aten/src/ATen/native/metal/MetalAten.mm +++ b/aten/src/ATen/native/metal/MetalAten.mm @@ -70,12 +70,13 @@ #pragma mark - ATen Ops Tensor empty( - IntArrayRef size, + c10::SymIntArrayRef sym_size, optional dtype, optional layout, optional device, optional pin_memory, c10::optional memory_format) { + auto size = c10::asIntArrayRefSlow(sym_size); TORCH_CHECK( !pin_memory.has_value(), "'pin_memory' argument is incompatible with Metal tensor"); diff --git a/aten/src/ATen/native/metal/ops/MetalReshape.mm b/aten/src/ATen/native/metal/ops/MetalReshape.mm index 37842ee3be59b9..1001b6690ad80c 100644 --- a/aten/src/ATen/native/metal/ops/MetalReshape.mm +++ b/aten/src/ATen/native/metal/ops/MetalReshape.mm @@ -16,7 +16,8 @@ namespace metal { API_AVAILABLE(ios(11.0), macos(10.13)) -Tensor view(const Tensor& input, IntArrayRef size) { +Tensor view(const Tensor& input, c10::SymIntArrayRef sym_size) { + auto size = c10::asIntArrayRefSlow(sym_size); TORCH_CHECK(input.is_metal()); auto inferred_size = at::infer_size(size, input.numel()); auto stride = @@ -63,7 +64,7 @@ Tensor view(const Tensor& input, IntArrayRef size) { Tensor reshape(const Tensor& input, IntArrayRef shape) { TORCH_CHECK(input.is_metal()); - return view(input, shape); + return view(input, c10::SymIntArrayRef::fromIntArrayRef(shape)); } Tensor flatten_using_ints( diff --git a/aten/src/ATen/native/miopen/Conv_miopen.cpp b/aten/src/ATen/native/miopen/Conv_miopen.cpp index be92f5a311a558..3d15c9e914d0e9 100644 --- a/aten/src/ATen/native/miopen/Conv_miopen.cpp +++ b/aten/src/ATen/native/miopen/Conv_miopen.cpp @@ -1570,7 +1570,7 @@ Tensor miopen_convolution_add_relu( auto _alpha = alpha.has_value() ? alpha.value().to() : 1.0; auto _bias = bias.has_value() ? bias.value() - : at::native::zeros( + : at::zeros( {contig_output.size(1)}, optTypeMetaToScalarType(contig_output.options().dtype_opt()), contig_output.options().layout_opt(), @@ -1614,7 +1614,7 @@ Tensor miopen_convolution_relu( auto _bias = bias.has_value() ? bias.value() - : at::native::zeros( + : at::zeros( {output_t.size(1)}, optTypeMetaToScalarType(output_t.options().dtype_opt()), output_t.options().layout_opt(), @@ -1661,7 +1661,7 @@ Tensor miopen_convolution_relu( auto _bias = bias.has_value() ? 
bias.value() - : at::native::zeros( + : at::zeros( {contig_output.size(1)}, optTypeMetaToScalarType(contig_output.options().dtype_opt()), contig_output.options().layout_opt(), diff --git a/aten/src/ATen/native/mkldnn/TensorFactories.cpp b/aten/src/ATen/native/mkldnn/TensorFactories.cpp index a944d4db19b62b..dc34281d25cac1 100644 --- a/aten/src/ATen/native/mkldnn/TensorFactories.cpp +++ b/aten/src/ATen/native/mkldnn/TensorFactories.cpp @@ -2,10 +2,6 @@ namespace at { namespace native { -Tensor empty_symint_mkldnn(c10::SymIntArrayRef sizes, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, c10::optional optional_memory_format) { - return at::native::empty_mkldnn(c10::asIntArrayRefSlow(sizes), dtype, layout, device, pin_memory, optional_memory_format); -} - #if AT_MKLDNN_ENABLED() Tensor empty_mkldnn(IntArrayRef sizes, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, c10::optional optional_memory_format) { diff --git a/aten/src/ATen/native/mps/TensorFactory.cpp b/aten/src/ATen/native/mps/TensorFactory.cpp index d280da4d9c650a..78899fc8fa3c3e 100644 --- a/aten/src/ATen/native/mps/TensorFactory.cpp +++ b/aten/src/ATen/native/mps/TensorFactory.cpp @@ -71,17 +71,6 @@ Tensor empty_mps( return at::detail::empty_mps(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt); } -Tensor empty_symint_mps( - c10::SymIntArrayRef size, - c10::optional dtype_opt, - c10::optional layout_opt, - c10::optional device_opt, - c10::optional pin_memory_opt, - c10::optional memory_format_opt) { - - return at::native::empty_mps(c10::asIntArrayRefSlow(size), dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt); -} - Tensor empty_strided_mps( IntArrayRef size, IntArrayRef stride, diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 848a84a7a55e1f..c02c48b4073594 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -2046,10 +2046,10 @@ device_check: NoCheck device_guard: False dispatch: - CompositeExplicitAutograd: empty + CompositeExplicitAutograd: empty_names autogen: empty.names_out -- func: empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor +- func: empty.memory_format(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor dispatch: CPU: empty_cpu CUDA: empty_cuda @@ -2060,39 +2060,14 @@ SparseCsrCPU, SparseCsrCUDA: empty_sparse_compressed QuantizedCPU, QuantizedCUDA, QuantizedMeta: empty_unknown_quantized -# all calls to empty() in python used to go through the symint overload -# even if all arguments were concerete integers. -# adding symint overloads of kernels for every dispatch key allowed us -# to skip redispatching to `empty.memory_format` and hit backend kernels directly -# we recently updated signature parsing to dispath `empty()` calls in python -# to `empty.SymInt` iff there's is a symint node argument -# hopefully, we could simplify this entry soon -- func: empty.SymInt(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? 
memory_format=None) -> Tensor - dispatch: - CPU: empty_symint_cpu - CUDA: empty_symint_cuda - MPS: empty_symint_mps - Meta: empty_symint_meta - MkldnnCPU: empty_symint_mkldnn - SparseCPU, SparseCUDA, SparseMeta: empty_symint_sparse - SparseCsrCPU, SparseCsrCUDA: empty_symint_sparse_compressed - QuantizedCPU, QuantizedCUDA: empty_symint_unknown_quantized - autogen: empty.SymInt_out - # We do not make new_empty a composite that calls into new_empty_strided, as the strided version # is significantly more difficult to implement by different backends -- func: new_empty(Tensor self, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +- func: new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor variants: method dispatch: CompositeExplicitAutograd: new_empty autogen: new_empty.out -- func: new_empty.SymInt(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - variants: method - dispatch: - CompositeExplicitAutograd: new_empty_symint - autogen: new_empty.SymInt_out - - func: new_empty_strided(Tensor self, int[] size, int[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor variants: method dispatch: @@ -2170,7 +2145,7 @@ QuantizedCPU, QuantizedCUDA: empty_quantized autogen: empty_quantized.out -- func: empty.out(int[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) +- func: empty.out(SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck device_guard: False @@ -2294,14 +2269,7 @@ SparseCPU, SparseCUDA: expm1_sparse_out SparseCsrCPU, SparseCsrCUDA: expm1_sparse_csr_out -- func: expand.SymInt(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a) - variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too. - device_check: NoCheck - device_guard: False - dispatch: - CompositeExplicitAutograd: expand_symint - -- func: expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a) +- func: expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a) variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too. device_check: NoCheck device_guard: False @@ -3697,7 +3665,7 @@ dispatch: CompositeExplicitAutograd: mvlgamma_ -- func: narrow_copy(Tensor self, int dim, int start, int length) -> Tensor +- func: narrow_copy(Tensor self, int dim, SymInt start, SymInt length) -> Tensor variants: function, method dispatch: CPU: narrow_copy_dense_cpu @@ -3705,13 +3673,7 @@ CompositeExplicitAutogradNonFunctional: narrow_copy_dense tags: view_copy -- func: narrow_copy.SymInt(Tensor self, int dim, int start, SymInt length) -> Tensor - variants: function, method - dispatch: - CompositeExplicitAutograd: narrow_copy_symint - autogen: narrow_copy.SymInt_out - -- func: narrow_copy.out(Tensor self, int dim, int start, int length, *, Tensor(a!) out) -> Tensor(a!) +- func: narrow_copy.out(Tensor self, int dim, SymInt start, SymInt length, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU: narrow_copy_dense_cpu_out @@ -5581,19 +5543,14 @@ CUDA: _efficientzerotensor_cuda autogen: _efficientzerotensor.out -- func: zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor +- func: zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor dispatch: CompositeExplicitAutograd: zeros -- func: zeros.SymInt(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - dispatch: - CompositeExplicitAutograd: zeros_symint - autogen: zeros.SymInt_out - -- func: zeros.out(int[] size, *, Tensor(a!) out) -> Tensor(a!) +- func: zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) dispatch: CompositeExplicitAutograd: zeros_out - SparseCPU, SparseCUDA, SparseMeta: zeros_out + SparseCPU, SparseCUDA, SparseMeta: zeros_sparse_out - func: zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor dispatch: @@ -6920,20 +6877,13 @@ CPU: masked_softmax_backward_cpu autogen: _masked_softmax_backward.out -- func: view.SymInt(Tensor(a) self, SymInt[] size) -> Tensor(a) - variants: method - device_check: NoCheck - device_guard: False - dispatch: - CompositeExplicitAutograd: view_symint - MkldnnCPU: mkldnn_view_symint - -- func: view(Tensor(a) self, int[] size) -> Tensor(a) +- func: view(Tensor(a) self, SymInt[] size) -> Tensor(a) variants: method device_check: NoCheck device_guard: False dispatch: - ZeroTensor, CPU, CUDA, Meta, QuantizedCPU, QuantizedCUDA, MPS: view + Meta: view_meta + ZeroTensor, CPU, CUDA, QuantizedCPU, QuantizedCUDA, MPS: view MkldnnCPU: mkldnn_view # Warning: If you want to change the name or overload name of this @@ -12765,18 +12715,12 @@ CompositeExplicitAutogradNonFunctional: diagonal_copy tags: view_copy -- func: expand_copy(Tensor self, int[] size, *, bool implicit=False) -> Tensor +- func: expand_copy(Tensor self, SymInt[] size, *, bool implicit=False) -> Tensor variants: function dispatch: CompositeExplicitAutogradNonFunctional: expand_copy tags: view_copy -- func: expand_copy.SymInt(Tensor self, SymInt[] size, *, bool implicit=False) -> Tensor - variants: function - dispatch: - CompositeExplicitAutograd: expand_copy_SymInt - tags: view_copy - - func: permute_copy(Tensor self, int[] dims) -> Tensor variants: function dispatch: @@ -12905,7 +12849,7 @@ CompositeExplicitAutogradNonFunctional: unbind_copy_int tags: view_copy -- func: view_copy(Tensor self, int[] size) -> Tensor +- func: view_copy(Tensor self, SymInt[] size) -> Tensor variants: function dispatch: CompositeExplicitAutogradNonFunctional: view_copy @@ -12965,14 +12909,6 @@ CompositeExplicitAutograd: _neg_view_copy_out -- func: view_copy.SymInt(Tensor self, SymInt[] size) -> Tensor - variants: function - dispatch: - CompositeExplicitAutograd: view_copy_SymInt - tags: view_copy - autogen: view_copy.SymInt_out - - - func: as_strided_copy.out(Tensor self, int[] size, int[] stride, int? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!) variants: function dispatch: @@ -12991,13 +12927,7 @@ CompositeExplicitAutograd: diagonal_copy_out -- func: expand_copy.SymInt_out(Tensor self, SymInt[] size, *, bool implicit=False, Tensor(a!) out) -> Tensor(a!) - variants: function - dispatch: - CompositeExplicitAutograd: expand_copy_SymInt_out - - -- func: expand_copy.out(Tensor self, int[] size, *, bool implicit=False, Tensor(a!) out) -> Tensor(a!) +- func: expand_copy.out(Tensor self, SymInt[] size, *, bool implicit=False, Tensor(a!) out) -> Tensor(a!) 
variants: function dispatch: CompositeExplicitAutograd: expand_copy_out @@ -13117,7 +13047,7 @@ CompositeExplicitAutograd: unbind_copy_int_out -- func: view_copy.out(Tensor self, int[] size, *, Tensor(a!) out) -> Tensor(a!) +- func: view_copy.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) variants: function dispatch: CompositeExplicitAutograd: view_copy_out diff --git a/aten/src/ATen/native/quantized/TensorFactories.cpp b/aten/src/ATen/native/quantized/TensorFactories.cpp index 66c48f4ce75286..aa0fef5df9dc02 100644 --- a/aten/src/ATen/native/quantized/TensorFactories.cpp +++ b/aten/src/ATen/native/quantized/TensorFactories.cpp @@ -66,16 +66,6 @@ Tensor empty_per_channel_affine_quantized( quantizer); } -Tensor empty_symint_unknown_quantized( - c10::SymIntArrayRef size, - c10::optional dtype, - c10::optional layout, - c10::optional device, - c10::optional pin_memory, - c10::optional optional_memory_format) { - return at::native::empty_unknown_quantized(c10::asIntArrayRefSlow(size), dtype, layout, device, pin_memory, optional_memory_format); -} - Tensor empty_unknown_quantized( IntArrayRef size, c10::optional dtype, diff --git a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp index 062cc3d1262935..7ed70e99add113 100644 --- a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp @@ -488,16 +488,6 @@ SPARSE_COMPRESSED_TENSOR(csc, kSparseCsc) SPARSE_COMPRESSED_TENSOR(bsr, kSparseBsr) SPARSE_COMPRESSED_TENSOR(bsc, kSparseBsc) -Tensor empty_symint_sparse_compressed( - c10::SymIntArrayRef size, - c10::optional dtype, - c10::optional layout, - c10::optional device, - c10::optional pin_memory, - c10::optional optional_memory_format) { - return at::native::empty_sparse_compressed(c10::asIntArrayRefSlow(size), dtype, layout, device, pin_memory, optional_memory_format); -} - Tensor empty_sparse_compressed( IntArrayRef size, c10::optional dtype, diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp index a162689eb5fb95..815d264aa7027a 100644 --- a/aten/src/ATen/native/sparse/SparseTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseTensor.cpp @@ -207,17 +207,6 @@ Tensor empty_sparse( size.size(), 0, size, dtype, layout, device, pin_memory); } -/** Empty init **/ -Tensor empty_symint_sparse( - c10::SymIntArrayRef size, - c10::optional dtype, - c10::optional layout, - c10::optional device, - c10::optional pin_memory, - c10::optional optional_memory_format) { - return at::native::empty_sparse(c10::asIntArrayRefSlow(size), dtype, layout, device, pin_memory, optional_memory_format); -} - /* Shape init */ Tensor sparse_coo_tensor(IntArrayRef size, c10::optional dtype, diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.cpp b/aten/src/ATen/native/sparse/SparseTensorMath.cpp index a25083c6fae887..7d62ea22b0f537 100644 --- a/aten/src/ATen/native/sparse/SparseTensorMath.cpp +++ b/aten/src/ATen/native/sparse/SparseTensorMath.cpp @@ -1524,7 +1524,7 @@ SparseTensor& _sspaddmm_out_cpu( int64_t t_nnz = t._nnz(); int64_t r_nnz = nnz * dim_k + t_nnz; Tensor newi = at::empty({2, r_nnz}, kLong); - Tensor newv = native::zeros( + Tensor newv = at::zeros( {r_nnz}, optTypeMetaToScalarType(values.options().dtype_opt()), values.options().layout_opt(), diff --git a/aten/src/ATen/native/ts_native_functions.yaml b/aten/src/ATen/native/ts_native_functions.yaml index a6d26b3ad75b67..b110aa75c83d2b 100644 --- a/aten/src/ATen/native/ts_native_functions.yaml +++ 
b/aten/src/ATen/native/ts_native_functions.yaml @@ -141,7 +141,6 @@ full_codegen: - upsample_nearest2d - upsample_nearest2d_backward - zero - - narrow_copy.SymInt - alias_copy - as_strided_copy - diagonal_copy @@ -175,7 +174,6 @@ supported: - _copy_from - _copy_from_and_resize - empty.memory_format - - empty.SymInt - empty_strided - fill_.Scalar - normal_ diff --git a/aten/src/ATen/native/vulkan/ops/Factory.cpp b/aten/src/ATen/native/vulkan/ops/Factory.cpp index 06d44ec0619353..ce09521668f4f3 100644 --- a/aten/src/ATen/native/vulkan/ops/Factory.cpp +++ b/aten/src/ATen/native/vulkan/ops/Factory.cpp @@ -29,12 +29,13 @@ Tensor _empty_affine_quantized( } Tensor empty_memory_format( - const IntArrayRef sizes, + const SymIntArrayRef sym_sizes, const c10::optional dtype, const c10::optional layout, const c10::optional device, const c10::optional pin_memory, const optional memory_format) { + auto sizes = c10::asIntArrayRefSlow(sym_sizes); return convert(vTensor{ api::context(), sizes, @@ -55,7 +56,12 @@ Tensor empty_strided( const optional device, const optional pin_memory) { return empty_memory_format( - sizes, dtype, layout, device, pin_memory, c10::MemoryFormat::Contiguous); + c10::SymIntArrayRef::fromIntArrayRef(sizes), + dtype, + layout, + device, + pin_memory, + c10::MemoryFormat::Contiguous); } #ifdef USE_VULKAN_API diff --git a/aten/src/ATen/native/vulkan/ops/Shape.cpp b/aten/src/ATen/native/vulkan/ops/Shape.cpp index 14b32c2eea1791..e1bda761749d7d 100644 --- a/aten/src/ATen/native/vulkan/ops/Shape.cpp +++ b/aten/src/ATen/native/vulkan/ops/Shape.cpp @@ -42,7 +42,8 @@ Tensor view_internal(const Tensor& self_arg, const IntArrayRef shape) { return convert(v_output); } -inline Tensor view(const Tensor& self_arg, const IntArrayRef shape) { +inline Tensor view(const Tensor& self_arg, const SymIntArrayRef sym_shape) { + auto shape = c10::asIntArrayRefSlow(sym_shape); return view_internal(self_arg, shape); } diff --git a/aten/src/ATen/templates/CompositeViewCopyKernels.cpp b/aten/src/ATen/templates/CompositeViewCopyKernels.cpp index dd4c3270843ff8..d6a7266952e9bc 100644 --- a/aten/src/ATen/templates/CompositeViewCopyKernels.cpp +++ b/aten/src/ATen/templates/CompositeViewCopyKernels.cpp @@ -55,8 +55,6 @@ void resize_out_helper(const at::TensorList& dst, const at::TensorList& src) { ${CompositeViewCopyKernel_Definitions} -${SymIntViewCopyKernel_Definitions} - ${GeneratedCompositeFunctional_Definitions} ${GeneratedCompositeOut_Definitions} diff --git a/aten/src/ATen/test/ExclusivelyOwned_test.cpp b/aten/src/ATen/test/ExclusivelyOwned_test.cpp index 5d1dcf7127d7ca..8e93ec9d14f918 100644 --- a/aten/src/ATen/test/ExclusivelyOwned_test.cpp +++ b/aten/src/ATen/test/ExclusivelyOwned_test.cpp @@ -28,7 +28,7 @@ T getSampleValue(); template <> at::Tensor getSampleValue() { - return at::native::zeros({2, 2}).to(at::kCPU); + return at::zeros({2, 2}).to(at::kCPU); } template <> diff --git a/aten/src/ATen/test/MaybeOwned_test.cpp b/aten/src/ATen/test/MaybeOwned_test.cpp index e57bf3a4d9d402..d2d6a3a7059711 100644 --- a/aten/src/ATen/test/MaybeOwned_test.cpp +++ b/aten/src/ATen/test/MaybeOwned_test.cpp @@ -105,7 +105,7 @@ void assertOwn( template<> Tensor getSampleValue() { - return at::native::zeros({2, 2}).to(at::kCPU); + return at::zeros({2, 2}).to(at::kCPU); } template<> diff --git a/aten/src/ATen/test/extension_backend_test.cpp b/aten/src/ATen/test/extension_backend_test.cpp index 9b215a90ae74a4..a5c5868153f74c 100644 --- a/aten/src/ATen/test/extension_backend_test.cpp +++ 
b/aten/src/ATen/test/extension_backend_test.cpp @@ -15,7 +15,7 @@ using namespace at; static int test_int; -Tensor empty_override(IntArrayRef size, c10::optional dtype, c10::optional layout, +Tensor empty_override(SymIntArrayRef size, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, c10::optional optional_memory_format) { test_int = 1; auto tensor_impl = c10::make_intrusive( @@ -44,7 +44,7 @@ Tensor empty_strided_override( c10::optional device, c10::optional pin_memory) { - return empty_override(size, dtype, layout, device, pin_memory, c10::nullopt); + return empty_override(SymIntArrayRef::fromIntArrayRef(size), dtype, layout, device, pin_memory, c10::nullopt); } TORCH_LIBRARY_IMPL(aten, ORT, m) { diff --git a/aten/src/ATen/test/math_kernel_test.cpp b/aten/src/ATen/test/math_kernel_test.cpp index 15ce0af4001d5a..29c3388990919d 100644 --- a/aten/src/ATen/test/math_kernel_test.cpp +++ b/aten/src/ATen/test/math_kernel_test.cpp @@ -119,7 +119,7 @@ TEST(MathKernelTest, NarrowCopy) { for (const auto dim : c10::irange(3)) { const int64_t start = 1, length = 4; auto y_ref = x.narrow(dim, start, length); - auto y_test = at::native::narrow_copy_dense(x, dim, start, length); + auto y_test = at::native::narrow_copy_dense(x, dim, c10::SymInt(start), c10::SymInt(length)); ASSERT_ALLCLOSE_TOLERANCES(y_ref, y_test, 0, 0); } } diff --git a/functorch/functorch/csrc/BatchRulesFactory.cpp b/functorch/functorch/csrc/BatchRulesFactory.cpp index 3f63d27a0c8e18..160b4c752f57a8 100644 --- a/functorch/functorch/csrc/BatchRulesFactory.cpp +++ b/functorch/functorch/csrc/BatchRulesFactory.cpp @@ -9,6 +9,25 @@ namespace at { namespace functorch { +template +struct NewBlahBatchRuleHelperSymInt; + +template +struct NewBlahBatchRuleHelperSymInt> { + static std::tuple> apply( + const Tensor& tensor, + optional batch_dim, + SymIntArrayRef shape, + T... 
extra_args) { + const auto bdim_size = tensor.sym_size(batch_dim.value()); + c10::SmallVector new_shape; + new_shape.reserve(shape.size() + 1); + new_shape.emplace_back(bdim_size); + new_shape.insert(new_shape.end(), shape.begin(), shape.end()); + return std::make_tuple(Func(tensor, new_shape, std::forward(extra_args)...), 0); + } +}; + template struct NewBlahBatchRuleHelper; @@ -37,6 +56,12 @@ struct NewBlahBatchRuleHelper> { &fn,\ c10::guts::function_traits::parameter_types>::apply) +#define NEW_BLAH_BATCH_RULE_SYMINT(fn) SINGLE_ARG(\ + NewBlahBatchRuleHelperSymInt<\ + decltype(&fn),\ + &fn,\ + c10::guts::function_traits::parameter_types>::apply) + std::tuple> _new_zeros_with_same_feature_meta_batch_rule( const Tensor& self, optional self_bdim, const Tensor& other, optional other_bdim, @@ -82,17 +107,6 @@ bool _has_same_storage_numel_batch_rule(const Tensor& a, const Tensor& b) { return true; } -Tensor new_empty_symint_decomp( - const Tensor& self, - SymIntArrayRef size, - c10::optional dtype_opt, - c10::optional layout_opt, - c10::optional device_opt, - c10::optional pin_memory_opt - ) { - return self.new_empty(c10::asIntArrayRefSlow(size), dtype_opt, layout_opt, device_opt, pin_memory_opt); -} - TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { m.impl("_has_same_storage_numel", _has_same_storage_numel_batch_rule); VMAP_SUPPORT(ones_like, BASIC_UNARY_BATCH_RULE(ATEN_FN(ones_like))); @@ -101,8 +115,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { VMAP_SUPPORT(randn_like, BASIC_UNARY_BATCH_RULE(ATEN_FN(randn_like))); VMAP_SUPPORT(rand_like, BASIC_UNARY_BATCH_RULE(ATEN_FN(rand_like))); VMAP_SUPPORT(full_like, BASIC_UNARY_BATCH_RULE(ATEN_FN(full_like))); - VMAP_SUPPORT(new_empty, NEW_BLAH_BATCH_RULE(ATEN_FN(new_empty))); - m.impl("new_empty.SymInt", new_empty_symint_decomp); + VMAP_SUPPORT(new_empty, NEW_BLAH_BATCH_RULE_SYMINT(ATEN_FN(new_empty))); VMAP_SUPPORT(new_zeros, NEW_BLAH_BATCH_RULE(ATEN_FN(new_zeros))); VMAP_SUPPORT(new_ones, NEW_BLAH_BATCH_RULE(ATEN_FN(new_ones))); VMAP_SUPPORT(new_full, NEW_BLAH_BATCH_RULE(ATEN_FN(new_full))); diff --git a/functorch/functorch/csrc/BatchRulesViews.cpp b/functorch/functorch/csrc/BatchRulesViews.cpp index e4160ea4c98f11..9c382cbaf207ea 100644 --- a/functorch/functorch/csrc/BatchRulesViews.cpp +++ b/functorch/functorch/csrc/BatchRulesViews.cpp @@ -427,15 +427,15 @@ std::tuple> slice_backward_batch_rule( } std::tuple> view_batching_rule( - const Tensor &self, optional self_bdim, IntArrayRef size) + const Tensor &self, optional self_bdim, SymIntArrayRef sym_size) { TORCH_INTERNAL_ASSERT(self_bdim.has_value()); auto self_ = moveBatchDimToFront(self, self_bdim); - VmapDimVector size_(size.size() + 1); + c10::SmallVector size_(sym_size.size() + 1); // copy batch size size_[0] = self_.size(0); - std::copy(size.cbegin(), size.cend(), size_.begin() + 1); - return std::make_tuple(self_.view(size_), 0); + std::copy(sym_size.cbegin(), sym_size.cend(), size_.begin() + 1); + return std::make_tuple(self_.view_symint(size_), 0); } Tensor view_symint_decomposition(const Tensor& self, @@ -446,7 +446,7 @@ Tensor view_symint_decomposition(const Tensor& self, template std::tuple> expand_batch_rule( - const Tensor &self, optional self_bdim, IntArrayRef size, bool implicit) + const Tensor &self, optional self_bdim, SymIntArrayRef size, bool implicit) { auto self_dim = self.dim(); TORCH_CHECK(static_cast(self_dim - 1) <= size.size(), @@ -457,7 +457,7 @@ std::tuple> expand_batch_rule( auto self_sizes = self_.sizes(); auto batch_size = self_sizes[0]; - c10::SmallBuffer 
size_(size.size() + 1); + c10::SmallVector size_(size.size() + 1); size_[0] = batch_size; std::copy(size.cbegin(), size.cend(), size_.begin() + 1); @@ -471,12 +471,12 @@ std::tuple> expand_batch_rule( // so the strategy here is to view it first as a tensor of size [B0, 1, 3] and // then expand. auto extra_dims = size.size() - (self_dim - 1); - VmapDimVector view_shape(size_.size(), /*init_value*/1); + c10::SmallVector view_shape(size_.size(), /*init_value*/1); view_shape[0] = batch_size; std::copy(self_sizes.cbegin() + 1, self_sizes.cend(), view_shape.begin() + 1 + extra_dims); - return std::make_tuple(Func(self_.view(view_shape), size_, implicit), 0); + return std::make_tuple(Func(self_.view_symint(view_shape), size_, implicit), 0); } std::tuple> unfold_batch_rule( @@ -549,8 +549,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { VMAP_SUPPORT2(slice, Tensor, slice_batch_rule); VMAP_SUPPORT2(transpose, int, transpose_int_batch_rule); VMAP_SUPPORT(diag_embed, diag_embed_batch_rule); - m.impl("expand.SymInt", expand_symint_decomp_hack); - m.impl("view.SymInt", view_symint_decomposition); } }} diff --git a/test/cpp/jit/test_custom_class.cpp b/test/cpp/jit/test_custom_class.cpp index 9a9b5ce956c6fc..f7fe339b561435 100644 --- a/test/cpp/jit/test_custom_class.cpp +++ b/test/cpp/jit/test_custom_class.cpp @@ -45,6 +45,20 @@ TEST(CustomClassTest, TorchbindIValueAPI) { test_with_obj(new_stack_ivalue, "boo"); } +TEST(CustomClassTest, ScalarTypeClass) { + script::Module m("m"); + + // test make_custom_class API + auto cc = make_custom_class(at::kFloat); + m.register_attribute("s", cc.type(), cc, false); + + std::ostringstream oss; + m.save(oss); + std::istringstream iss(oss.str()); + caffe2::serialize::IStreamAdapter adapter{&iss}; + auto loaded_module = torch::jit::load(iss, torch::kCPU); +} + class TorchBindTestClass : public torch::jit::CustomClassHolder { public: std::string get() { diff --git a/test/cpp/jit/test_custom_class_registrations.cpp b/test/cpp/jit/test_custom_class_registrations.cpp index 35d35f4cf640ff..63c6b701330629 100644 --- a/test/cpp/jit/test_custom_class_registrations.cpp +++ b/test/cpp/jit/test_custom_class_registrations.cpp @@ -275,6 +275,16 @@ struct ReLUClass : public torch::CustomClassHolder { }; TORCH_LIBRARY(_TorchScriptTesting, m) { + m.class_("_ScalarTypeClass") + .def(torch::init()) + .def_pickle( + [](const c10::intrusive_ptr& self) { + return std::make_tuple(self->scalar_type_); + }, + [](std::tuple s) { + return c10::make_intrusive(std::get<0>(s)); + }); + m.class_("_ReLUClass") .def(torch::init<>()) .def("run", &ReLUClass::run); diff --git a/test/cpp/jit/test_custom_class_registrations.h b/test/cpp/jit/test_custom_class_registrations.h index 4e6b7bd43883cb..59ee7c9fe15f87 100644 --- a/test/cpp/jit/test_custom_class_registrations.h +++ b/test/cpp/jit/test_custom_class_registrations.h @@ -4,6 +4,11 @@ namespace torch { namespace jit { +struct ScalarTypeClass : public torch::CustomClassHolder { + ScalarTypeClass(at::ScalarType s) : scalar_type_(s) {} + at::ScalarType scalar_type_; +}; + template struct MyStackClass : torch::CustomClassHolder { std::vector stack_; diff --git a/test/cpp/lazy/test_lazy_ops.cpp b/test/cpp/lazy/test_lazy_ops.cpp index b20ccc839e742b..6198940b31001b 100644 --- a/test/cpp/lazy/test_lazy_ops.cpp +++ b/test/cpp/lazy/test_lazy_ops.cpp @@ -84,35 +84,6 @@ static inline at::DeviceType DefaultDevice() { } // namespace -#ifndef C10_MOBILE -TEST(LazyDynamicOpsTest, NarrowCopy) { - auto x = torch::rand({5, 10, 10}).to(kLazy); - const size_t Y_DIM = 3; 
- const size_t X_DIM_INDEX = 2; - auto y = torch::rand({Y_DIM}).to(kLazy); - auto ly = torch::lazy::TryGetLtcTensor(y); - auto dim_node = MakeNode(ly->GetIrValue(), 0); - auto lmn = c10::make_intrusive(dim_node); - auto z = x.narrow_copy_symint(X_DIM_INDEX, 0, lmn->toSymInt()); - AllClose(z.cpu(), x.cpu().narrow_copy(X_DIM_INDEX, 0, Y_DIM)); -} - -TEST(LazyDynamicOpsTest, NarrowCopyViaSymSizes) { - FLAGS_ltc_enable_symbolic_shapes = true; - auto xc = torch::rand({10}); - auto x = xc.to(kLazy); - const size_t Y_DIM = 3; - const size_t X_DIM_INDEX = 0; - auto y = torch::rand({Y_DIM}).to(kLazy); - auto z = x.narrow_copy_symint(X_DIM_INDEX, 0, y.sym_sizes()[0]); - auto zc = xc.narrow_copy(X_DIM_INDEX, 0, Y_DIM); - ASSERT_EQ(z.sizes()[0], xc.sizes()[0]); // note, xc not zc - // shape inference assumes narrow_copy can copy the whole tensor - AllClose(z.cpu(), zc); - FLAGS_ltc_enable_symbolic_shapes = false; -} -#endif - TEST_F(LazyOpsTest, TestScalarTensor) { torch::Tensor scalar_tensor = torch::scalar_tensor( 1., torch::TensorOptions(torch::kFloat).device(DefaultDevice())); diff --git a/test/cpp_extensions/open_registration_extension.cpp b/test/cpp_extensions/open_registration_extension.cpp index c6d05042ef76f3..7f43e60a6b393d 100644 --- a/test/cpp_extensions/open_registration_extension.cpp +++ b/test/cpp_extensions/open_registration_extension.cpp @@ -85,8 +85,7 @@ at::Tensor custom__copy_from(const at::Tensor& self, const at::Tensor& dst, bool // More details on the dispatcher can be found at http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/. TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { m.impl("add.Tensor", &custom_add_Tensor); - m.impl("empty.memory_format", &custom_empty_memory_format); - m.impl("empty.SymInt", &custom_empty_symint); + m.impl("empty.memory_format", &custom_empty_symint); m.impl("fill_.Scalar", &custom_fill__scalar); m.impl("_copy_from", &custom__copy_from); } diff --git a/test/cpp_extensions/ort_extension.cpp b/test/cpp_extensions/ort_extension.cpp index 24617abeb06d58..3422bccd6d38c6 100644 --- a/test/cpp_extensions/ort_extension.cpp +++ b/test/cpp_extensions/ort_extension.cpp @@ -20,15 +20,10 @@ Tensor get_tensor(caffe2::TypeMeta dtype, IntArrayRef size) { return Tensor(std::move(tensor_impl)); } -Tensor empty_override(IntArrayRef size, c10::optional dtype, c10::optional layout, c10::optional device, +Tensor empty_override(SymIntArrayRef size, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, c10::optional optional_memory_format) { test_int = 0; - return get_tensor(scalarTypeToTypeMeta(dtype_or_default(dtype)), size); -} - -Tensor empty_symint_override(c10::SymIntArrayRef size, c10::optional dtype_opt, c10::optional layout_opt, - c10::optional device_opt, c10::optional pin_memory_opt, c10::optional memory_format_opt) { - return empty_override(c10::asIntArrayRefSlow(size), dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt); + return get_tensor(scalarTypeToTypeMeta(dtype_or_default(dtype)), c10::asIntArrayRefSlow(size)); } Tensor& add_out_override(const Tensor & a, const Tensor & b , const Scalar& c, Tensor & out) { @@ -58,7 +53,6 @@ std::tuple fake_convolution_backward( } TORCH_LIBRARY_IMPL(aten, ORT, m) { - m.impl("empty.SymInt", empty_symint_override); m.impl("empty.memory_format", empty_override); m.impl("add.out", add_out_override); m.impl("convolution_overrideable", fake_convolution); diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py 
b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index f5a0d3d26254b9..c37c795ec96639 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -131,6 +131,18 @@ ("aten::sum.SymInt", datetime.date(2022, 11, 30)), ("aten::mps_linear", datetime.date(9999, 1, 1)), ("aten::_mps_linear", datetime.date(9999, 1, 1)), + ("aten::view_copy.SymInt", datetime.date(2022, 11, 30)), + ("aten::view_copy.SymInt_out", datetime.date(2022, 11, 30)), + ("aten::expand_copy.SymInt", datetime.date(2022, 11, 30)), + ("aten::expand_copy.SymInt_out", datetime.date(2022, 11, 30)), + ("aten::expand.SymInt", datetime.date(2022, 11, 30)), + ("aten::narrow_copy.SymInt", datetime.date(2022, 11, 30)), + ("aten::narrow_copy.SymInt_out", datetime.date(2022, 11, 30)), + ("aten::view.SymInt", datetime.date(2022, 11, 30)), + ("aten::new_empty.SymInt", datetime.date(2022, 11, 30)), + ("aten::new_empty.SymInt_out", datetime.date(2022, 11, 30)), + ("aten::zeros.SymInt", datetime.date(2022, 11, 30)), + ("aten::zeros.SymInt_out", datetime.date(2022, 11, 30)), # TODO: FIXME: prims shouldn't be checked ("prims::.*", datetime.date(9999, 1, 1)), ("aten::_amp_foreach_non_finite_check_and_unscale.out", datetime.date(2022, 9, 1)), diff --git a/test/test_decomp.py b/test/test_decomp.py index 8077aa86d4abc6..1e31b07d4e7961 100644 --- a/test/test_decomp.py +++ b/test/test_decomp.py @@ -392,8 +392,7 @@ def _torch_dispatch(cls, func, types, args=(), kwargs=None): if func not in decomposition_table or func in [ torch.ops.aten.detach.default, # non-deterministic ops - torch.ops.aten.new_empty.default, - torch.ops.aten.new_empty.SymInt + torch.ops.aten.new_empty.default ] or any_unsupported(args, kwargs): return func(*args, **kwargs) diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 56e7e9b1219256..1f2c5adca66634 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -54,7 +54,7 @@ def cat_meta(tensors, dim=0): return tensors[0].new_empty(new_shape) -@register_meta([aten.narrow_copy.SymInt]) +@register_meta([aten.narrow_copy.default]) def narrow_copy_symint_meta(a, dim, start, length, **kwargs): shape = [] for i, x in enumerate(a.shape): @@ -65,7 +65,7 @@ def narrow_copy_symint_meta(a, dim, start, length, **kwargs): return a.new_empty(tuple(shape)) -@register_meta([aten.expand.SymInt]) +@register_meta([aten.expand.default]) def expand_symint_meta(a, size, implicit=False): return a.new_empty(size) @@ -293,11 +293,11 @@ def test_aten_ops(self): shape_env = ShapeEnv() x = create_symbolic_tensor("x", torch.randn(5), shape_env) - torch.ops.aten.narrow_copy.SymInt(x, 0, 0, x.shape[0]) + torch.ops.aten.narrow_copy.default(x, 0, 0, x.shape[0]) shape_env = ShapeEnv() x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) - torch.ops.aten.expand.SymInt(x, [x.shape[0], x.shape[1], x.shape[2]]) + torch.ops.aten.expand.default(x, [x.shape[0], x.shape[1], x.shape[2]]) def test_fx_trace_intlist(self): class CustomModule(torch.nn.Module): diff --git a/test/test_meta.py b/test/test_meta.py index 14fab0c9b2914c..540d54e50ed535 100644 --- a/test/test_meta.py +++ b/test/test_meta.py @@ -727,7 +727,6 @@ def __torch_function__(self, func, types, args=(), kwargs=None): aten.linalg_pinv.atol_rtol_tensor: {f32, f64}, aten.linalg_pinv.atol_rtol_tensor_out: {f32, f64}, aten.empty.memory_format: {b8, bf16, c128, c64, c32, f16, f32, f64, i16, i32, i64, i8, u8}, - 
aten.empty.SymInt: {b8, bf16, c128, c64, c32, f16, f32, f64, i16, i32, i64, i8, u8}, } meta_dispatch_device_expected_failures = defaultdict(dict) diff --git a/test/test_nn.py b/test/test_nn.py index f4f5b2be62e21b..bdb5ba7ef24c34 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -15965,6 +15965,7 @@ def issue_24823_2(): torch.cuda.synchronize() issue_24823_2() + @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/945") @dtypes(torch.float, torch.double) @largeTensorTest(lambda self, device, dtype: # Compute sum of the large tensor sizes: diff --git a/test/test_profiler_tree.py b/test/test_profiler_tree.py index 55491bc86b45c0..90fdd1118f494d 100644 --- a/test/test_profiler_tree.py +++ b/test/test_profiler_tree.py @@ -317,24 +317,21 @@ def test_profiler_experimental_tree_with_record_function(self): ProfilerTree.format(p.profiler, 12), """\ aten::zeros - aten::zeros - aten::empty - aten::zero_ + aten::empty + aten::zero_ Top level Annotation aten::empty aten::zeros - aten::zeros - aten::empty - aten::zero_ + aten::empty + aten::zero_ First Annotation aten::empty aten::ones aten::empty aten::fill_ aten::zeros - aten::zeros - aten::empty - aten::zero_ + aten::empty + aten::zero_ Second Annotation aten::empty aten::add @@ -343,9 +340,8 @@ def test_profiler_experimental_tree_with_record_function(self): aten::empty_strided aten::copy_ aten::zeros - aten::zeros - aten::empty - aten::zero_ + aten::empty + aten::zero_ Third Annotation aten::empty aten::ones_like @@ -716,6 +712,7 @@ def test_profiler_experimental_tree_with_stack_and_torch_dispatch(self): torch/profiler/profiler.py(...): stop ...""") + @unittest.skip("https://github.com/pytorch/pytorch/issues/83606") @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") @ProfilerTree.test def test_profiler_experimental_tree_cuda(self): @@ -813,6 +810,7 @@ def test_profiler_experimental_tree_cuda(self): allow_failure=ALLOW_CUDA_FAILURE, ) + @unittest.skip("https://github.com/pytorch/pytorch/issues/83606") @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") @ProfilerTree.test def test_profiler_experimental_tree_cuda_with_stream(self): diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index f3682560fde008..d19b5907170891 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -762,7 +762,7 @@ def f(a): def forward(self, a_1): sym_size = torch.ops.aten.sym_size(a_1, 0); a_1 = None mul = sym_size * 2; sym_size = None - empty = torch.ops.aten.empty.SymInt([mul], device = device(type='cpu'), pin_memory = False); mul = None + empty = torch.ops.aten.empty.memory_format([mul], device = device(type='cpu'), pin_memory = False); mul = None sym_size_1 = torch.ops.aten.sym_size(empty, 0) return empty""") diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 9c08d05a7ff8a5..dd649bc9c08361 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -622,12 +622,9 @@ self: grad * (result + 1) result: auto_element_wise -- name: expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a) - self: at::sum_to(grad, self.sizes()) - result: auto_linear - -- name: expand.SymInt(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a) - self: at::sum_to(grad, c10::asIntArrayRefSlow(self.sym_sizes())) +# TODO: this derivative is not SymInt safe, need sum_to support +- name: expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a) + self: at::sum_to(grad, self.sym_sizes()) result: auto_linear - name: 
exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!) @@ -1735,16 +1732,11 @@ # linear result1: mean(self_t, dim.value_or(IntArrayRef({})), keepdim) -- name: view(Tensor(a) self, int[] size) -> Tensor(a) +# TODO: this derivative is not SymInt safe, need reshape_symint +- name: view(Tensor(a) self, SymInt[] size) -> Tensor(a) self: grad.reshape(self.sizes()) result: auto_linear -- name: view.SymInt(Tensor(a) self, SymInt[] size) -> Tensor(a) - # TODO: add proper double backward for view.SymInt - # by SymIntizing `reshape` - self: grad.reshape(c10::asIntArrayRefSlow(self.sym_sizes())) - result: auto_linear - - name: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a) output_differentiability: [False] diff --git a/tools/autograd/gen_inplace_or_view_type.py b/tools/autograd/gen_inplace_or_view_type.py index 69f3eecf590cc8..cd81305347a11c 100644 --- a/tools/autograd/gen_inplace_or_view_type.py +++ b/tools/autograd/gen_inplace_or_view_type.py @@ -249,6 +249,7 @@ def unpack_args(f: NativeFunction) -> Tuple[List[str], List[Binding]]: for r in cpp.argument( a, method=False, + symint=True, cpp_no_default_args=set(), faithful=False, has_tensor_options=False, @@ -494,7 +495,7 @@ def gen_formals(f: NativeFunction) -> str: # See Note [Plumbing Keys Through The Dispatcher] for details. ["c10::DispatchKeySet ks"] + [ - f'{cpp.argument_type(a, binds="__placeholder__").cpp_type()} {a.name}' + f'{cpp.argument_type(a, binds="__placeholder__", symint=True).cpp_type()} {a.name}' for a in f.func.schema_order_arguments() ] ) @@ -514,7 +515,7 @@ def inplace_or_view_method_definition( ): return None return METHOD_DEFINITION.substitute( - return_type=cpp.returns_type(f.func.returns).cpp_type(), + return_type=cpp.returns_type(f.func.returns, symint=True).cpp_type(), type_wrapper_name=type_wrapper_name(f), formals=gen_formals(f), type_definition_body=emit_inplace_or_view_body(fn), diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index cf585daba2f071..1c476300ad4275 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -238,6 +238,8 @@ def gen( tags_yaml_path: str, deprecated_yaml_path: str, template_path: str, + *, + symint: bool = True, ) -> None: fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) native_functions = parse_native_yaml( @@ -253,6 +255,7 @@ def gen( None, "python_variable_methods.cpp", method=True, + symint=symint, ) # NOTE: num_shards here must be synced with gatherTorchFunctions in @@ -266,6 +269,7 @@ def gen( "python_torch_functions.cpp", method=False, num_shards=3, + symint=symint, ) create_python_bindings( @@ -275,6 +279,7 @@ def gen( "torch.nn", "python_nn_functions.cpp", method=False, + symint=symint, ) create_python_bindings( @@ -284,6 +289,7 @@ def gen( "torch.fft", "python_fft_functions.cpp", method=False, + symint=symint, ) create_python_bindings( @@ -293,6 +299,7 @@ def gen( "torch.linalg", "python_linalg_functions.cpp", method=False, + symint=symint, ) create_python_bindings( @@ -302,6 +309,7 @@ def gen( "torch.sparse", "python_sparse_functions.cpp", method=False, + symint=symint, ) create_python_bindings( @@ -311,6 +319,7 @@ def gen( "torch.special", "python_special_functions.cpp", method=False, + symint=symint, ) # Currently, we only use `functions` to generate `return_types` bindings. 
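Reviewer note: the codegen hunks above all follow one pattern, a keyword-only symint flag accepted at the top-level gen() entry point and forwarded unchanged into each create_python_bindings* call. The sketch below illustrates only that threading pattern; the function names and the SymInt-to-int collapse when the flag is off are illustrative assumptions, not torchgen's real API.

from typing import Sequence


def signature_str(arg_types: Sequence[str], *, symint: bool = True) -> str:
    # Illustrative assumption: with the flag off, SymInt renders as plain int
    # in the schema string, matching the pre-SymInt bindings.
    rendered = [t if symint else t.replace("SymInt", "int") for t in arg_types]
    return "(" + ", ".join(rendered) + ")"


def method_impl(name: str, overloads: Sequence[Sequence[str]], *, symint: bool = True) -> str:
    # Helpers never decide the flag themselves; they forward what they were given.
    return "\n".join(name + signature_str(args, symint=symint) for args in overloads)


def gen(overloads: Sequence[Sequence[str]], *, symint: bool = True) -> str:
    return method_impl("empty", overloads, symint=symint)


if __name__ == "__main__":
    ov = [["SymInt[] size", "MemoryFormat? memory_format=None"]]
    print(gen(ov, symint=True))   # empty(SymInt[] size, MemoryFormat? memory_format=None)
    print(gen(ov, symint=False))  # empty(int[] size, MemoryFormat? memory_format=None)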
@@ -354,6 +363,7 @@ def create_python_bindings( filename: str, *, method: bool, + symint: bool = True, ) -> None: """Generates Python bindings to ATen functions""" py_methods: List[str] = [] @@ -365,7 +375,9 @@ def create_python_bindings( for name in sorted(grouped.keys(), key=lambda x: str(x)): overloads = grouped[name] - py_methods.append(method_impl(name, module, overloads, method=method)) + py_methods.append( + method_impl(name, module, overloads, method=method, symint=symint) + ) py_method_defs.append(method_def(name, module, overloads, method=method)) py_forwards.extend(forward_decls(name, overloads, method=method)) ops_headers.append(f"#include ") @@ -428,6 +440,7 @@ def create_python_bindings_sharded( *, method: bool, num_shards: int, + symint: bool = True, ) -> None: """Generates Python bindings to ATen functions""" grouped = group_filter_overloads(pairs, pred) @@ -444,7 +457,9 @@ def env_func( return { "ops_headers": [f"#include "], "py_forwards": list(forward_decls(name, fn_pairs, method=method)), - "py_methods": [method_impl(name, module, fn_pairs, method=method)], + "py_methods": [ + method_impl(name, module, fn_pairs, method=method, symint=symint) + ], "py_method_defs": [method_def(name, module, fn_pairs, method=method)], } @@ -773,6 +788,7 @@ def method_impl( overloads: Sequence[PythonSignatureNativeFunctionPair], *, method: bool, + symint: bool = True, ) -> str: """ Generate a python binding for all overloads of an op. @@ -791,14 +807,18 @@ def method_impl( traceable = "true" if all(should_trace(o.function) for o in overloads) else "false" - grouped_overloads: Sequence[PythonSignatureGroup] = group_overloads(overloads) + grouped_overloads: Sequence[PythonSignatureGroup] = group_overloads( + overloads, symint=symint + ) is_singleton = len(grouped_overloads) == 1 signatures: List[str] = [] dispatch: List[str] = [] for overload_index, overload in enumerate(grouped_overloads): - signature = overload.signature.signature_str() + signature = overload.signature.signature_str(symint=symint) signatures.append(f"{cpp_string(str(signature))},") - dispatch_body = emit_dispatch_case(overload, namedtuple_typenames) + dispatch_body = emit_dispatch_case( + overload, namedtuple_typenames, symint=symint + ) dispatch.append( PY_VARIABLE_CASE.substitute( overload_index=overload_index, body=dispatch_body @@ -882,6 +902,8 @@ def gen_has_torch_function_check( def emit_dispatch_case( overload: PythonSignatureGroup, namedtuple_typenames: Dict[str, str], + *, + symint: bool = True, ) -> str: """ Emit dispatch code for a single parsed signature. 
This corresponds to either @@ -894,18 +916,19 @@ def emit_dispatch_case( return PY_VARIABLE_OUT.substitute( out_idx=overload.signature.output_idx(), call_dispatch=emit_single_dispatch( - overload.signature, overload.base, namedtuple_typenames + overload.signature, overload.base, namedtuple_typenames, symint=symint ), call_dispatch_out=emit_single_dispatch( overload.signature, overload.outplace, namedtuple_typenames, + symint=symint, ), ) else: # no-output version only return emit_single_dispatch( - overload.signature, overload.base, namedtuple_typenames + overload.signature, overload.base, namedtuple_typenames, symint=symint ) @@ -987,14 +1010,14 @@ def method_def( def group_overloads( - overloads: Sequence[PythonSignatureNativeFunctionPair], + overloads: Sequence[PythonSignatureNativeFunctionPair], *, symint: bool = True ) -> Sequence[PythonSignatureGroup]: bases: Dict[str, PythonSignatureNativeFunctionPair] = {} outplaces: Dict[str, PythonSignatureNativeFunctionPair] = {} # first group by signature ignoring out arguments for overload in overloads: - sig = overload.signature.signature_str(skip_outputs=True) + sig = overload.signature.signature_str(skip_outputs=True, symint=symint) if overload.function.func.is_out_fn(): if sig in outplaces: raise RuntimeError( @@ -1021,9 +1044,11 @@ def group_overloads( and not overload.signature.deprecated ): candidates.append( - overload.signature.signature_str(skip_outputs=True) + overload.signature.signature_str( + skip_outputs=True, symint=symint + ) ) - out_sig = out.signature.signature_str() + out_sig = out.signature.signature_str(symint=symint) raise RuntimeError( f"While identifying overloads, we found an out schema {out_sig} without a corresponding non-out variant. " f"We expected the non-out variant to have schema: \n- {sig}\nPlease check that you spelled the schema " @@ -1038,7 +1063,7 @@ def group_overloads( ) for sig, base in bases.items() ] - return sort_overloads(grouped) + return sort_overloads(grouped, symint=symint) # This function declares a partial order on declarations, and sorts them according @@ -1087,7 +1112,7 @@ def group_overloads( def sort_overloads( - grouped_overloads: Sequence[PythonSignatureGroup], + grouped_overloads: Sequence[PythonSignatureGroup], *, symint: bool = True ) -> Sequence[PythonSignatureGroup]: # NB: Smaller here means lower priority @@ -1132,7 +1157,7 @@ def is_smaller(s1: PythonSignature, s2: PythonSignature) -> bool: # First sort by signature grouped_overloads = sorted( - grouped_overloads, key=lambda x: x.signature.signature_str() + grouped_overloads, key=lambda x: x.signature.signature_str(symint=symint) ) # Construct the relation graph @@ -1170,7 +1195,11 @@ def is_smaller(s1: PythonSignature, s2: PythonSignature) -> bool: def emit_single_dispatch( - ps: PythonSignature, f: NativeFunction, namedtuple_typenames: Dict[str, str] + ps: PythonSignature, + f: NativeFunction, + namedtuple_typenames: Dict[str, str], + *, + symint: bool = True, ) -> str: """ Emit dispatch code for a single native function. 
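Reviewer note: group_overloads keys base and out= variants by a signature string that skips output arguments, and the only thing symint changes in this hunk is the string used as that key. The toy stand-ins below are not the real PythonSignatureNativeFunctionPair data model; they are a minimal sketch of the grouping step only.

from collections import defaultdict
from typing import Dict, List, NamedTuple, Tuple


class Overload(NamedTuple):
    name: str
    args: Tuple[str, ...]
    is_out: bool


def sig_key(o: Overload, *, symint: bool = True) -> str:
    # Analogue of signature_str(skip_outputs=True, symint=...): drop the out
    # argument so base and out= variants share a key, and optionally render
    # SymInt as int (illustrative assumption).
    args = [a for a in o.args if a != "Tensor(a!) out"]
    if not symint:
        args = [a.replace("SymInt", "int") for a in args]
    return f"{o.name}({', '.join(args)})"


def group_overloads(overloads: List[Overload], *, symint: bool = True) -> Dict[str, List[Overload]]:
    grouped: Dict[str, List[Overload]] = defaultdict(list)
    for o in overloads:
        grouped[sig_key(o, symint=symint)].append(o)
    return grouped


ovs = [
    Overload("view_copy", ("Tensor self", "SymInt[] size"), is_out=False),
    Overload("view_copy", ("Tensor self", "SymInt[] size", "Tensor(a!) out"), is_out=True),
]
for key, group in group_overloads(ovs).items():
    print(key, "->", ["out" if o.is_out else "base" for o in group])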
@@ -1189,7 +1218,10 @@ def go(f: NativeFunction) -> str: # dispatch lambda signature name = cpp.name(f.func) lambda_formals = ", ".join( - map(lambda a: f"{a.type_str} {a.name}", dispatch_lambda_args(ps, f)) + map( + lambda a: f"{a.type_str} {a.name}", + dispatch_lambda_args(ps, f, symint=symint), + ) ) lambda_return = dispatch_lambda_return_str(f) @@ -1198,8 +1230,8 @@ def go(f: NativeFunction) -> str: dispatch_args = ", ".join(cpp_dispatch_exprs(f, python_signature=ps)) # from arg parser outputs to dispatch lambda arguments - parser_outputs = arg_parser_output_exprs(ps, f) - lambda_arg_exprs = dispatch_lambda_exprs(ps, f) + parser_outputs = arg_parser_output_exprs(ps, f, symint=symint) + lambda_arg_exprs = dispatch_lambda_exprs(ps, f, symint=symint) inits = "\n".join(lambda_arg_exprs.inits) lambda_args = ", ".join(lambda_arg_exprs.exprs) diff --git a/tools/autograd/gen_trace_type.py b/tools/autograd/gen_trace_type.py index f7df5efb39c918..21739bb80510f2 100644 --- a/tools/autograd/gen_trace_type.py +++ b/tools/autograd/gen_trace_type.py @@ -383,7 +383,7 @@ def declare_returned_variables(f: NativeFunction) -> str: return "" if len(f.func.returns) == 1: return "" - types = map(cpp.return_type, f.func.returns) + types = [cpp.return_type(r, symint=True) for r in f.func.returns] names = cpp.return_names(f) return "\n".join(f"{type.cpp_type()} {name};" for type, name in zip(types, names)) @@ -483,13 +483,13 @@ def method_definition(f: NativeFunction) -> str: # See Note [Plumbing Keys Through The Dispatcher] for details. ["c10::DispatchKeySet ks"] + [ - f'{cpp.argument_type(a, binds="__placeholder__").cpp_type()} {a.name}' + f'{cpp.argument_type(a, binds="__placeholder__", symint=True).cpp_type()} {a.name}' for a in f.func.schema_order_arguments() ] ) return METHOD_DEFINITION.substitute( - return_type=cpp.returns_type(f.func.returns).cpp_type(), + return_type=cpp.returns_type(f.func.returns, symint=True).cpp_type(), type_wrapper_name=type_wrapper_name(f), formals=formals, type_definition_body=emit_trace_body(f), diff --git a/tools/autograd/gen_variable_factories.py b/tools/autograd/gen_variable_factories.py index 07abc98a8d4c27..88356bd7234de4 100644 --- a/tools/autograd/gen_variable_factories.py +++ b/tools/autograd/gen_variable_factories.py @@ -76,31 +76,39 @@ def process_function(f: NativeFunction) -> Optional[str]: if Variant.function not in f.variants or not is_factory: return None - sig = CppSignatureGroup.from_native_function(f, method=False).signature - formals: List[str] = [] - exprs: List[str] = [] - requires_grad = "false" - for arg in sig.arguments(): - qualified_type = fully_qualified_type(arg.type) - if arg.default: - formals.append(f"{qualified_type} {arg.name} = {arg.default}") - else: - formals.append(f"{qualified_type} {arg.name}") - - if isinstance(arg.argument, TensorOptionsArguments): - # note: we remove the requires_grad setting from the TensorOptions because - # it is ignored anyways (and we actually have an assertion that it isn't set - # which would fail otherwise). We handle requires_grad explicitly here - # instead of passing it through to the kernel. - exprs.append(f"at::TensorOptions({arg.name}).requires_grad(c10::nullopt)") - # Manually set the requires_grad bit on the result tensor. 
- requires_grad = f"{arg.name}.requires_grad()" - else: - exprs.append(arg.name) - - return f"""\ -inline at::Tensor {name}({', '.join(formals)}) {{ + cpp_sigs = CppSignatureGroup.from_native_function(f, method=False) + sigs = [cpp_sigs.signature] + if cpp_sigs.symint_signature is not None: + sigs.append(cpp_sigs.symint_signature) + r = "" + for sig in sigs: + formals: List[str] = [] + exprs: List[str] = [] + requires_grad = "false" + for arg in sig.arguments(): + qualified_type = fully_qualified_type(arg.type) + if arg.default: + formals.append(f"{qualified_type} {arg.name} = {arg.default}") + else: + formals.append(f"{qualified_type} {arg.name}") + + if isinstance(arg.argument, TensorOptionsArguments): + # note: we remove the requires_grad setting from the TensorOptions because + # it is ignored anyways (and we actually have an assertion that it isn't set + # which would fail otherwise). We handle requires_grad explicitly here + # instead of passing it through to the kernel. + exprs.append( + f"at::TensorOptions({arg.name}).requires_grad(c10::nullopt)" + ) + # Manually set the requires_grad bit on the result tensor. + requires_grad = f"{arg.name}.requires_grad()" + else: + exprs.append(arg.name) + + r += f"""\ +inline at::Tensor {sig.name()}({', '.join(formals)}) {{ at::AutoDispatchBelowADInplaceOrView guard; - return autograd::make_variable(at::{name}({', '.join(exprs)}), /*requires_grad=*/{requires_grad}); + return autograd::make_variable(at::{sig.name()}({', '.join(exprs)}), /*requires_grad=*/{requires_grad}); }} """ + return r diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index d9488b6398f989..f9afe838203ded 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -849,7 +849,9 @@ def gen_variable_type_func( if not fn.info: key = "Default" type_definition = METHOD_DEFINITION.substitute( - return_type=cpp.returns_type(f.func.returns).cpp_type(), + return_type=cpp.returns_type( + f.func.returns, symint=True + ).cpp_type(), type_wrapper_name=type_wrapper_name(f, key), type_definition_body=emit_body(fn, key), formals=formals, @@ -860,7 +862,9 @@ def gen_variable_type_func( else: for key, _ in fn.info.items(): type_definition = METHOD_DEFINITION.substitute( - return_type=cpp.returns_type(f.func.returns).cpp_type(), + return_type=cpp.returns_type( + f.func.returns, symint=True + ).cpp_type(), type_wrapper_name=type_wrapper_name(f, key), type_definition_body=emit_body(fn, key), formals=formals, @@ -913,7 +917,7 @@ def gen_differentiable_input( # TODO: `cpp_type` is only to keep it byte-for-byte compatible with the old codegen, should remove. # NB: This is not a clone of cpp.argument() - TensorOptionsArguments / faithful / binds are # not handled properly as they are irrelevant for this codegen. 
- cpp_type = cpp.argument_type(a, binds=a.name).cpp_type() + cpp_type = cpp.argument_type(a, binds=a.name, symint=True).cpp_type() if not is_differentiable(a.name, a.type, info): return None @@ -1204,6 +1208,7 @@ def emit_dispatch_call( api_name=cpp.name( f.func, faithful_name_for_out_overloads=True, + symint_overload=f.func.has_symint(), ), unpacked_args=[dispatch_key_set] + list(unpacked_args), ) @@ -1285,7 +1290,7 @@ def check_tensorimpl_and_storage( for i, (ret, ret_name) in enumerate( zip(f.func.returns, cpp.return_names(f)) ): - noref_cpp_type = cpp.return_type(ret).remove_const_ref() + noref_cpp_type = cpp.return_type(ret, symint=True).remove_const_ref() if noref_cpp_type == BaseCType(tensorT): if aliased_arg_name is not None: assert ( diff --git a/tools/autograd/load_derivatives.py b/tools/autograd/load_derivatives.py index ce6e3d808192c7..5bffe2ba979d30 100644 --- a/tools/autograd/load_derivatives.py +++ b/tools/autograd/load_derivatives.py @@ -184,7 +184,9 @@ def create_derivative( ] return_names = tuple(n if n != "self" else "result" for n in cpp.return_names(f)) - return_types = tuple(cpp.return_type(r).remove_const_ref() for r in f.func.returns) + return_types = tuple( + cpp.return_type(r, symint=True).remove_const_ref() for r in f.func.returns + ) named_returns = [ NamedCType(name, type) for name, type in zip(return_names, return_types) @@ -375,7 +377,7 @@ def repl(m: Any) -> str: new_args.append(arg_name) # TODO we are trolling - if f.func.is_symint_fn(): + if f.func.has_symint(): defn_name += "_symint" # Call into the forward again. We need two cases here to handle both Tensor methods and at:: functions. diff --git a/tools/test/test_codegen.py b/tools/test/test_codegen.py index cbce8b3bc5e66d..781dde46fe7006 100644 --- a/tools/test/test_codegen.py +++ b/tools/test/test_codegen.py @@ -314,6 +314,7 @@ def setUp(self) -> None: dispatch_key=k, use_out_as_primary=True, external=False, + symint=False, device_guard=False, index=backend_indices[k], ) diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 034fa33f12ed51..f6eae115e11312 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -612,7 +612,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): torch._C._remove_meta_from_tls_dispatch_include() if has_symbolic_sizes: - constructors = [aten.empty.SymInt] + constructors = [aten.empty.memory_format] if func not in constructors: raise RuntimeError( f"{func} - couldn't find symbolic meta function/decomposition" diff --git a/torch/csrc/jit/codegen/cuda/interface.cpp b/torch/csrc/jit/codegen/cuda/interface.cpp index ae353220f1144b..582ce64bdb92bf 100644 --- a/torch/csrc/jit/codegen/cuda/interface.cpp +++ b/torch/csrc/jit/codegen/cuda/interface.cpp @@ -843,8 +843,7 @@ RegisterOperators reg_expand_copy({ "alias ops, should be restored after fusion pass!"); IValue self, size, implicit; pop(stack, self, size, implicit); - push( - stack, at::native::expand(self.toTensor(), size.toIntVector())); + push(stack, self.toTensor().expand(size.toIntVector())); }; }, aliasAnalysisFromSchema()), diff --git a/torch/csrc/jit/frontend/schema_type_parser.cpp b/torch/csrc/jit/frontend/schema_type_parser.cpp index a5fdf8bc539efd..a5e23b9de37d19 100644 --- a/torch/csrc/jit/frontend/schema_type_parser.cpp +++ b/torch/csrc/jit/frontend/schema_type_parser.cpp @@ -411,7 +411,8 @@ SchemaTypeParser::parseFakeAndRealType() { real_value = parseBaseType(); if (real_value->kind() == ScalarTypeType::Kind || real_value->kind() == 
MemoryFormatType::Kind || - real_value->kind() == LayoutType::Kind) { + real_value->kind() == LayoutType::Kind || + real_value->kind() == SymIntType::Kind) { fake_value = c10::TypeFactory::get(); } else { fake_value = real_value; diff --git a/torch/csrc/jit/python/pybind_utils.cpp b/torch/csrc/jit/python/pybind_utils.cpp index afe5961ff89975..624e11ac23599e 100644 --- a/torch/csrc/jit/python/pybind_utils.cpp +++ b/torch/csrc/jit/python/pybind_utils.cpp @@ -80,16 +80,13 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional N) { auto c_obj = py::cast>(obj.ptr()); return static_cast>(c_obj); } - case TypeKind::SymIntType: - return py::cast(obj); case TypeKind::IntType: - // NB: Typically, these switches are completely dead, because - // Argument::type() will always report IntType for these types. - // So this is a bit overly permissive: we'll accept a dtype - // passed to an int argument, for example. - case TypeKind::LayoutType: - case TypeKind::ScalarTypeType: - case TypeKind::MemoryFormatType: + // TODO: Properly fake this type + if (THPQScheme_Check(obj.ptr())) { + auto qscheme = reinterpret_cast(obj.ptr()); + return static_cast(qscheme->qscheme); + } + // For backwards compatibility if (THPDtype_Check(obj.ptr())) { auto dtype = reinterpret_cast(obj.ptr()); return static_cast(dtype->scalar_type); @@ -107,6 +104,35 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional N) { return static_cast(memory_format->memory_format); } return py::cast(obj); + case TypeKind::LayoutType: { + if (THPLayout_Check(obj.ptr())) { + auto layout = reinterpret_cast(obj.ptr()); + return static_cast(layout->layout); + } + // For backwards compatibility + return py::cast(obj); + } + case TypeKind::ScalarTypeType: { + if (THPDtype_Check(obj.ptr())) { + auto dtype = reinterpret_cast(obj.ptr()); + return static_cast(dtype->scalar_type); + } + // For backwards compatibility + return py::cast(obj); + } + case TypeKind::MemoryFormatType: { + if (THPMemoryFormat_Check(obj.ptr())) { + auto memory_format = reinterpret_cast(obj.ptr()); + return static_cast(memory_format->memory_format); + } + // For backwards compatibility + return py::cast(obj); + } + case TypeKind::SymIntType: + if (torch::is_symint_node(obj.ptr())) { + return py::cast(obj); + } + return py::cast(obj); case TypeKind::NoneType: if (!obj.is_none()) { throw py::cast_error( diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h index cb339b9bb308bf..11a5c6494a00ac 100644 --- a/torch/csrc/jit/python/pybind_utils.h +++ b/torch/csrc/jit/python/pybind_utils.h @@ -646,7 +646,7 @@ inline IValue argumentToIValue( py::handle object) { const auto& argument = schema.arguments().at(argumentPosition); try { - return toIValue(object, argument.type(), argument.N()); + return toIValue(object, argument.real_type(), argument.N()); } catch (const py::cast_error& error) { throw schema_match_error(c10::str( schema.formatTypeMismatchMsg( diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp index 0439393affa297..96e1869552c0f2 100644 --- a/torch/csrc/jit/runtime/static/ops.cpp +++ b/torch/csrc/jit/runtime/static/ops.cpp @@ -39,6 +39,8 @@ #include #include +#include + C10_DEFINE_bool( static_runtime_enable_fast_math, true, @@ -75,7 +77,7 @@ void repeat_out(at::Tensor& result, const Tensor& self, IntArrayRef repeats) { return; } - Tensor xtensor = at::native::expand(self, padded_size); + Tensor xtensor = at::compositeexplicitautograd::expand(self, padded_size); Tensor urtensor = 
at::native::alias(result); for (const auto i : c10::irange(xtensor.dim())) { // can't unfold with step 0, so make sure step is at least 1 @@ -2526,12 +2528,13 @@ REGISTER_OPERATOR_FUNCTOR(aten::zeros, aten_zeros, [](Node* n) -> SROperator { const auto dtype = p_node->Input(1).toOptional(); const auto layout = p_node->Input(2).toOptional(); if (!hasTensorWithOptions(p_node->Output(0), dtype, layout)) { - p_node->Output(0) = at::native::zeros(size, dtype, layout); + p_node->Output(0) = at::compositeexplicitautograd::zeros( + size, dtype, layout, c10::nullopt, c10::nullopt); return; } auto& out_t = p_node->Output(0).toTensor(); fastResizeToZero(out_t); - at::native::zeros_out(size, out_t); + at::compositeexplicitautograd::zeros_out(out_t, size); }; }); diff --git a/torch/csrc/lazy/core/ir_builder.h b/torch/csrc/lazy/core/ir_builder.h index 0af188f0131a57..22d70c6fcdd8c6 100644 --- a/torch/csrc/lazy/core/ir_builder.h +++ b/torch/csrc/lazy/core/ir_builder.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -266,5 +267,21 @@ static inline NodePtr MakeSizeDiv(const Value& a, const Value& b) { return getIrBuilder()->MakeSizeDiv(a, b); } +inline Value GetSymIntValue(c10::SymInt a) { + return Value( + dynamic_cast(a.toSymIntNodeImpl().get()) + ->node_, + 0); +} + +// TODO: this should return Value +inline std::vector GetSymIntArrayRefValue(c10::SymIntArrayRef arr) { + std::vector r; + for (const auto& a : arr) { + r.emplace_back(a.expect_int()); + } + return r; +} + } // namespace lazy } // namespace torch diff --git a/torch/csrc/lazy/ts_backend/ts_native_functions.cpp b/torch/csrc/lazy/ts_backend/ts_native_functions.cpp index ec32a45b50deca..5787ebc62a4c79 100644 --- a/torch/csrc/lazy/ts_backend/ts_native_functions.cpp +++ b/torch/csrc/lazy/ts_backend/ts_native_functions.cpp @@ -269,30 +269,15 @@ at::Tensor LazyNativeFunctions::_to_copy( } }; -at::Tensor LazyNativeFunctions::empty_symint( - c10::SymIntArrayRef size, - c10::optional dtype, - c10::optional layout, - c10::optional device, - c10::optional pin_memory, - c10::optional memory_format) { - // TODO: support SymIntNodes as well - return empty( - c10::asIntArrayRefSlow(size), - dtype, - layout, - device, - pin_memory, - memory_format); -} - at::Tensor LazyNativeFunctions::empty( - at::IntArrayRef size, + at::SymIntArrayRef sym_size, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, c10::optional memory_format) { + // TODO: support this directly + auto size = c10::asIntArrayRefSlow(sym_size); const auto device_type = torch::lazy::getBackend()->EagerFallbackDeviceType(); at::TensorOptions options = at::TensorOptions() .device(c10::Device(device_type)) @@ -322,7 +307,13 @@ at::Tensor LazyNativeFunctions::empty_strided( c10::optional device, c10::optional pin_memory) { TORCH_LAZY_FN_COUNTER("lazy::"); - at::Tensor t = empty(size, dtype, layout, device, pin_memory, c10::nullopt); + at::Tensor t = empty( + c10::SymIntArrayRef::fromIntArrayRef(size), + dtype, + layout, + device, + pin_memory, + c10::nullopt); return t.as_strided(size, stride, /*storage_offset=*/0); } @@ -418,7 +409,8 @@ at::Tensor LazyNativeFunctions::_unsafe_view( const at::Tensor& self, at::IntArrayRef size) { TORCH_LAZY_FN_COUNTER("lazy::"); - return LazyNativeFunctions::view_copy(self, size); + return LazyNativeFunctions::view_copy( + self, c10::SymIntArrayRef::fromIntArrayRef(size)); } // This is needed by the torch.tensor constructor. 
@@ -460,8 +452,8 @@ at::Tensor LazyNativeFunctions::new_empty_strided( at::Tensor LazyNativeFunctions::narrow_copy( const at::Tensor& self, int64_t dim, - int64_t start, - int64_t length) { + c10::SymInt start, + c10::SymInt length) { return at::functionalization::functionalize_aten_op::call(self, dim, start, length); } diff --git a/torchgen/api/autograd.py b/torchgen/api/autograd.py index a22072b744d1ed..19ffd332b70380 100644 --- a/torchgen/api/autograd.py +++ b/torchgen/api/autograd.py @@ -538,7 +538,9 @@ def gen_differentiable_outputs( info = fn.info[key] if fn.info else None outputs: List[DifferentiableOutput] = [ DifferentiableOutput( - name=name, type=ret.type, cpp_type=cpp.return_type(ret).cpp_type() + name=name, + type=ret.type, + cpp_type=cpp.return_type(ret, symint=True).cpp_type(), ) for name, ret in zip(cpp.return_names(f), f.func.returns) ] diff --git a/torchgen/api/cpp.py b/torchgen/api/cpp.py index 408c123a0952fb..128215f2a26704 100644 --- a/torchgen/api/cpp.py +++ b/torchgen/api/cpp.py @@ -64,9 +64,14 @@ # collisions, but functions are fair game to collide -def name(func: FunctionSchema, *, faithful_name_for_out_overloads: bool = False) -> str: +def name( + func: FunctionSchema, + *, + faithful_name_for_out_overloads: bool = False, + symint_overload: bool = False, +) -> str: name = str(func.name.name) - if func.is_symint_fn(): + if symint_overload: name += "_symint" if func.is_out_fn(): if faithful_name_for_out_overloads: @@ -81,11 +86,20 @@ def name(func: FunctionSchema, *, faithful_name_for_out_overloads: bool = False) # types look the same no matter if they are argument types or return # types. Returns None if the type in question is not a value type. def valuetype_type( - t: Type, *, binds: ArgName, remove_non_owning_ref_types: bool = False + t: Type, + *, + binds: ArgName, + remove_non_owning_ref_types: bool = False, + symint: bool = False, ) -> Optional[NamedCType]: if isinstance(t, BaseType): if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar: return None + elif str(t) == "SymInt": + if symint: + return NamedCType(binds, BaseCType(SymIntT)) + else: + return NamedCType(binds, BaseCType(longT)) if remove_non_owning_ref_types: if t.name == BaseTy.str: raise AssertionError( @@ -94,7 +108,7 @@ def valuetype_type( # All other BaseType currently map directly to BaseCppTypes. return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name])) elif isinstance(t, OptionalType): - elem = valuetype_type(t.elem, binds=binds) + elem = valuetype_type(t.elem, binds=binds, symint=symint) if elem is None: return None return NamedCType(binds, OptionalCType(elem.type)) @@ -113,11 +127,19 @@ def valuetype_type( # For example, we'll return std::vector instead of IntArrayRef. 
# See Note [translation from C++ reference to value types] def argumenttype_type( - t: Type, *, mutable: bool, binds: ArgName, remove_non_owning_ref_types: bool = False + t: Type, + *, + mutable: bool, + binds: ArgName, + remove_non_owning_ref_types: bool = False, + symint: bool = False, ) -> NamedCType: # If it's a value type, do the value type translation r = valuetype_type( - t, binds=binds, remove_non_owning_ref_types=remove_non_owning_ref_types + t, + binds=binds, + symint=symint, + remove_non_owning_ref_types=remove_non_owning_ref_types, ) if r is not None: return r @@ -146,7 +168,7 @@ def argumenttype_type( return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT)))) elif isinstance(t.elem, ListType) and str(t.elem.elem) == "int": return NamedCType(binds, BaseCType(optionalIntArrayRefT)) - elem = argumenttype_type(t.elem, mutable=mutable, binds=binds) + elem = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint) return NamedCType(binds, OptionalCType(elem.type)) elif isinstance(t, ListType): # TODO: remove these special cases, ArrayRef fallthrough works fine @@ -157,10 +179,16 @@ def argumenttype_type( return NamedCType(binds, BaseCType(intArrayRefT)) if str(t.elem) == "SymInt": if remove_non_owning_ref_types: - return NamedCType(binds, VectorCType(BaseCType(SymIntT))) + if symint: + return NamedCType(binds, VectorCType(BaseCType(SymIntT))) + else: + return NamedCType(binds, VectorCType(BaseCType(longT))) else: - return NamedCType(binds, BaseCType(symIntArrayRefT)) - elif str(t.elem) == "Tensor": + if symint: + return NamedCType(binds, BaseCType(symIntArrayRefT)) + else: + return NamedCType(binds, BaseCType(intArrayRefT)) + if str(t.elem) == "Tensor": return NamedCType(binds, BaseCType(tensorListT)) elif str(t.elem) == "Scalar": return NamedCType(binds, ArrayRefCType(BaseCType(scalarT))) @@ -170,15 +198,15 @@ def argumenttype_type( return NamedCType( binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))) ) - elem = argumenttype_type(t.elem, mutable=mutable, binds=binds) + elem = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint) return NamedCType(binds, ArrayRefCType(elem.type)) else: raise AssertionError(f"unrecognized type {repr(t)}") # Translate a JIT argument into its C++ type -def argument_type(a: Argument, *, binds: ArgName) -> NamedCType: - return argumenttype_type(a.type, mutable=a.is_write, binds=binds) +def argument_type(a: Argument, *, binds: ArgName, symint: bool = False) -> NamedCType: + return argumenttype_type(a.type, mutable=a.is_write, symint=symint, binds=binds) # Translation of a (non-multi) return type from JIT to C++ @@ -186,9 +214,9 @@ def argument_type(a: Argument, *, binds: ArgName) -> NamedCType: # This is mostly because of the mismatch between return types and return names. # e.g. a function with a return type of 'void' has 0 return names, # and a function with a return type of 'std::tuple' has >1 return name. -def returntype_type(t: Type, *, mutable: bool) -> CType: +def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType: # placeholder is ignored - r = valuetype_type(t, binds="__placeholder__") + r = valuetype_type(t, binds="__placeholder__", symint=symint) if r is not None: return r.type @@ -211,7 +239,7 @@ def returntype_type(t: Type, *, mutable: bool) -> CType: assert ( not mutable ), "Native functions should never return a mutable tensor list. They should return void." 
- elem = returntype_type(t.elem, mutable=False) + elem = returntype_type(t.elem, mutable=False, symint=symint) assert t.size is None, f"fixed size list returns not supported: {t}" return VectorCType(elem) @@ -219,18 +247,18 @@ def returntype_type(t: Type, *, mutable: bool) -> CType: # Translation of a single return to its C++ type -def return_type(r: Return) -> CType: - return returntype_type(r.type, mutable=r.is_write) +def return_type(r: Return, *, symint: bool = False) -> CType: + return returntype_type(r.type, mutable=r.is_write, symint=symint) # Translation of a full (possibly multi) return from JIT to its C++ type -def returns_type(rs: Sequence[Return]) -> CType: +def returns_type(rs: Sequence[Return], *, symint: bool = False) -> CType: if len(rs) == 0: return BaseCType(voidT) elif len(rs) == 1: - return return_type(rs[0]) + return return_type(rs[0], symint=symint) else: - return TupleCType([return_type(r) for r in rs]) + return TupleCType([return_type(r, symint=symint) for r in rs]) def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]: @@ -325,6 +353,7 @@ def argument( cpp_no_default_args: Set[str], method: bool, faithful: bool, + symint: bool = False, has_tensor_options: bool, ) -> List[Binding]: def sub_argument( @@ -335,6 +364,7 @@ def sub_argument( cpp_no_default_args=cpp_no_default_args, method=method, faithful=faithful, + symint=symint, has_tensor_options=has_tensor_options, ) @@ -349,7 +379,7 @@ def sub_argument( default = default_expr(a.default, a.type) return [ Binding( - nctype=argument_type(a, binds=binds), + nctype=argument_type(a, binds=binds, symint=symint), name=a.name, default=default, argument=a, @@ -390,7 +420,12 @@ def sub_argument( def arguments( - arguments: Arguments, *, faithful: bool, method: bool, cpp_no_default_args: Set[str] + arguments: Arguments, + *, + faithful: bool, + symint: bool = False, + method: bool, + cpp_no_default_args: Set[str], ) -> List[Binding]: args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = [] if faithful: @@ -405,6 +440,7 @@ def arguments( for r in argument( a, faithful=faithful, + symint=symint, method=method, has_tensor_options=arguments.tensor_options is not None, cpp_no_default_args=cpp_no_default_args, diff --git a/torchgen/api/dispatcher.py b/torchgen/api/dispatcher.py index 008e8c5664a472..aaab73ef737819 100644 --- a/torchgen/api/dispatcher.py +++ b/torchgen/api/dispatcher.py @@ -45,6 +45,7 @@ def argumenttype_type( t, mutable=mutable, binds=binds, + symint=True, remove_non_owning_ref_types=remove_non_owning_ref_types, ) @@ -62,7 +63,7 @@ def argument_type( def returns_type(rs: Sequence[Return]) -> CType: # At present, there is no difference. But there could be! - return cpp.returns_type(rs) + return cpp.returns_type(rs, symint=True) def jit_arguments(func: FunctionSchema) -> List[Argument]: diff --git a/torchgen/api/lazy.py b/torchgen/api/lazy.py index 6bce9db92bdb24..840400f0da19f2 100644 --- a/torchgen/api/lazy.py +++ b/torchgen/api/lazy.py @@ -37,6 +37,14 @@ _valueT = None +# A ValueT is an IR type which represents the computation of a Tensor. In other +# words, a PyTorch user will do operations on lazy tensors, and each output lazy +# tensor internally tracks a ValueT representing the IR node that would have +# actually produced the value of this tensor for real. +# +# This is configurable because different lazy tensor backends (LTC vs XLA) will +# have different IR representations. (Though, arguably, after unification they +# shouldn't!) 
def getValueT() -> BaseCppType: global _valueT if not _valueT: @@ -113,12 +121,27 @@ def process_ir_type( elif str(typ.elem) == "Tensor": # this is a TensorList which comes in from GetTensorList as a Value return BaseCType(tensorListValueT) + elif typ.elem == BaseType(BaseTy.SymInt): + # TODO: return a value type. The problem here is analogous to + # the problem with tensorListValueT: if you have SymInt[] you + # cannot conveniently save the list of Value directly, as nodes + # expect to save values as a vector for ALL arguments. So you + # need a separate IR node that represents all of the size nodes + # assembled into a list. I'm not an LTC dev so I don't want to + # figure it out right now. Y'all figure it out... + return VectorCType(BaseCType(longT)) + else: return VectorCType(process_ir_type(typ.elem, properties)) else: raise AssertionError(f"unrecognized type {repr(typ)}") +# TODO: Determining this based off of CType is bad; this should be computed +# from Type directly; then the same logic as process_ir_type can be used +# +# Invariant: passed typ should be an *owning* CType (e.g., we will report +# that ArrayRef is NOT a value type) def isValueType(typ: CType, properties: "Optional[LazyIrProperties]" = None) -> bool: """ Given a type, determine if it is a Value-like type. This is equivalent to @@ -133,6 +156,9 @@ def isValueType(typ: CType, properties: "Optional[LazyIrProperties]" = None) -> or (typ.type == scalarT and not treat_scalars_as_constants) or typ.type == SymIntT ) + elif typ == VectorCType(BaseCType(SymIntT)): + # TODO: report True for this + return False elif isinstance(typ, (OptionalCType, ListCType, VectorCType)): return isValueType(typ.elem, properties) return False @@ -157,6 +183,7 @@ def isWrappedScalarType(typ: Type) -> bool: return False +# TODO: dedupe with Type.is_generator_like def isGeneratorType(typ: Type) -> bool: if isinstance(typ, BaseType): return typ.name == BaseTy.Generator @@ -165,12 +192,15 @@ def isGeneratorType(typ: Type) -> bool: return False +# This class caches a few derived properties computed from an Argument +# and LazyIrProperties class LazyArgument: name: str orig_type: Type lazy_type_: Optional[CType] is_wrapped_scalar: bool is_generator: bool + # TODO: this is lies, it is false for symint list is_symint_or_list: bool # true if this argument is or contains a lazy IR value @@ -192,7 +222,11 @@ def __init__(self, arg: Argument, properties: "LazyIrProperties"): else: self.lazy_type_ = process_ir_type(arg.type, properties) self.is_wrapped_scalar = isWrappedScalarType(arg.type) - self.is_symint_or_list = isSymIntType(arg.type) + self.is_symint_or_list = ( + isSymIntType(arg.type) + # TODO: lists of symints are not currently treated as value types + # or (isinstance(arg.type, ListType) and isSymIntType(arg.type.elem)) + ) self.is_lazy_value = not self.is_generator and isValueType( self.lazy_type, properties @@ -268,6 +302,8 @@ def __setattr__(self, key: str, value: Any) -> Any: # Unlike a FunctionSchema, it has no round-trippable string form (relating to the YAML), # but carries type information from a native FunctionSchema modified for use with IR nodes, # and preserving original argument names. +# +# TODO: This is not idiomatic with how other torchgen APIs transform on schema. class LazyIrSchema: # The name of the operator this function schema describes. 
name: "OperatorName" diff --git a/torchgen/api/native.py b/torchgen/api/native.py index 16814e34867c01..b197a2a02983a9 100644 --- a/torchgen/api/native.py +++ b/torchgen/api/native.py @@ -37,6 +37,9 @@ # native:: kernels. The intention is to make native API and dispatcher API # line up as closely as possible, since this results in the least overhead # (no translation is needed from dispatcher API to native API). +# +# NB: this is symint aware, you will get the non-SymInt variant for some +# dispatch entries and SymInt for others. def name(func: FunctionSchema) -> str: @@ -49,7 +52,9 @@ def name(func: FunctionSchema) -> str: return name -def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType: +def argumenttype_type( + t: Type, *, mutable: bool, binds: ArgName, symint: bool +) -> NamedCType: if str(t) == "Tensor?": tensor_type: OptionalCType = OptionalCType(BaseCType(tensorT)) if mutable and not local.use_const_ref_for_mutable_tensors(): @@ -64,19 +69,22 @@ def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType: return NamedCType(binds, ConstRefCType(BaseCType(scalarT))) elif str(t) == "Scalar?": return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT)))) - return cpp.argumenttype_type(t, mutable=mutable, binds=binds) + return cpp.argumenttype_type(t, mutable=mutable, binds=binds, symint=symint) -def returns_type(rs: Sequence[Return]) -> CType: - return cpp.returns_type(rs) +def returns_type(rs: Sequence[Return], *, symint: bool) -> CType: + return cpp.returns_type(rs, symint=symint) -def argument_type(a: Argument, *, binds: ArgName) -> NamedCType: - return argumenttype_type(a.type, mutable=a.is_write, binds=binds) +def argument_type(a: Argument, *, binds: ArgName, symint: bool) -> NamedCType: + return argumenttype_type(a.type, mutable=a.is_write, binds=binds, symint=symint) def argument( - a: Union[Argument, SelfArgument, TensorOptionsArguments], *, is_out: bool + a: Union[Argument, SelfArgument, TensorOptionsArguments], + *, + is_out: bool, + symint: bool, ) -> List[Binding]: # Ideally, we NEVER default native functions. However, there are a number # of functions that call native:: directly and rely on the defaulting @@ -90,7 +98,7 @@ def argument( default = cpp.default_expr(a.default, a.type) return [ Binding( - nctype=argument_type(a, binds=a.name), + nctype=argument_type(a, binds=a.name, symint=symint), name=a.name, default=default, argument=a, @@ -98,7 +106,7 @@ def argument( ] elif isinstance(a, SelfArgument): # Erase SelfArgument from the distinction - return argument(a.argument, is_out=is_out) + return argument(a.argument, is_out=is_out, symint=symint) elif isinstance(a, TensorOptionsArguments): default = None if should_default: @@ -136,8 +144,10 @@ def argument( assert_never(a) -def arguments(func: FunctionSchema) -> List[Binding]: +def arguments(func: FunctionSchema, *, symint: bool) -> List[Binding]: args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = [] args.extend(func.arguments.non_out) args.extend(func.arguments.out) - return [r for arg in args for r in argument(arg, is_out=func.is_out_fn())] + return [ + r for arg in args for r in argument(arg, symint=symint, is_out=func.is_out_fn()) + ] diff --git a/torchgen/api/python.py b/torchgen/api/python.py index 0ff67f1365edca..aa00b2b7fca68e 100644 --- a/torchgen/api/python.py +++ b/torchgen/api/python.py @@ -216,8 +216,12 @@ class PythonArgument: # Compute argument formal for python argument parsing. 
# Needs to be consistent with torch/csrc/utils/python_arg_parser.h. - def argument_str(self, *, method: bool = False) -> str: - type_str = argument_type_str(self.type).replace("const ", "").replace(" &", "") + def argument_str(self, *, method: bool = False, symint: bool = True) -> str: + type_str = ( + argument_type_str(self.type, symint=symint) + .replace("const ", "") + .replace(" &", "") + ) name = self.name # s/self/input/ outside method bindings @@ -384,10 +388,10 @@ def output_idx(self) -> int: # # For a translation to mypy-valid type signatures, see # signature_str_pyi(). - def signature_str(self, *, skip_outputs: bool = False) -> str: + def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str: args = self.arguments(skip_outputs=skip_outputs) schema_formals: List[str] = list( - map(lambda a: a.argument_str(method=self.method), args) + map(lambda a: a.argument_str(method=self.method, symint=symint), args) ) positional_argc = len(self.input_args) if len(schema_formals) > positional_argc: @@ -426,7 +430,7 @@ def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> Optional[st vararg_type = args[0].type if ( isinstance(vararg_type, ListType) - and str(vararg_type.elem) == "int" + and str(vararg_type.elem) in ["int", "SymInt"] and num_positionalargs == 1 ): have_vararg_version = True @@ -464,9 +468,11 @@ class PythonSignatureDeprecated(PythonSignature): def deprecated(self) -> bool: return True - def signature_str(self, *, skip_outputs: bool = False) -> str: + def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str: return ( - PythonSignature.signature_str(self, skip_outputs=skip_outputs) + PythonSignature.signature_str( + self, skip_outputs=skip_outputs, symint=symint + ) + "|deprecated" ) @@ -633,7 +639,9 @@ def has_tensor_options(f: NativeFunction) -> bool: # 'simple_type' was introduced by the old codegen, which is slightly # different from the python schema type, e.g.: doesn't have '?' suffix # for optional Tensor/TensorList; doesn't have '[size]' suffix for list type. -def argument_type_str(t: Type, *, simple_type: bool = False) -> str: +def argument_type_str( + t: Type, *, simple_type: bool = False, symint: bool = True +) -> str: if isinstance(t, BaseType): if t.name == BaseTy.Tensor: return "Tensor" @@ -665,7 +673,7 @@ def argument_type_str(t: Type, *, simple_type: bool = False) -> str: if str(t.elem) == "Tensor": # Is it desired to keep '?' for simple_type with new style dispatcher? return "Tensor?" - elem = argument_type_str(t.elem, simple_type=simple_type) + elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint) return f"{elem}?" 
elif isinstance(t, ListType): size = t.size if not simple_type else None @@ -675,7 +683,12 @@ def argument_type_str(t: Type, *, simple_type: bool = False) -> str: elif str(t.elem) == "int": return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef" elif str(t.elem) == "SymInt": - return f"SymIntArrayRef[{size}]" if size is not None else "SymIntArrayRef" + if symint: + return ( + f"SymIntArrayRef[{size}]" if size is not None else "SymIntArrayRef" + ) + else: + return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef" elif str(t.elem) == "Tensor": return f"TensorList[{size}]" if size is not None else "TensorList" elif str(t.elem) == "Scalar": @@ -687,7 +700,7 @@ def argument_type_str(t: Type, *, simple_type: bool = False) -> str: return "const c10::List> &" elif str(t.elem) == "Dimname": return f"DimnameList[{size}]" if size is not None else "DimnameList" - elem = argument_type_str(t.elem, simple_type=simple_type) + elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint) return f"ArrayRef<{elem}>" raise RuntimeError(f"unrecognized type {repr(t)}") @@ -898,7 +911,7 @@ def argument_type_str_pyi(t: Type) -> str: if t.name == BaseTy.int: ret = "_int" if t.name == BaseTy.SymInt: - ret = "SymInt" + ret = "Union[_int, SymInt]" elif t.name == BaseTy.float: ret = "_float" elif t.name == BaseTy.str: @@ -1040,7 +1053,7 @@ def returns_str_pyi(signature: PythonSignature) -> str: def dispatch_lambda_args( - ps: PythonSignature, f: NativeFunction + ps: PythonSignature, f: NativeFunction, symint: bool = True ) -> Tuple[DispatchLambdaArgument, ...]: if isinstance(ps, PythonSignatureDeprecated): schema = ps.deprecated_schema @@ -1051,6 +1064,7 @@ def dispatch_lambda_args( cpp_args = cpp.arguments( arguments=schema.arguments, faithful=False, + symint=symint, method=False, cpp_no_default_args=f.cpp_no_default_args, ) @@ -1133,14 +1147,15 @@ def dispatch_lambda_return_str(f: NativeFunction) -> str: returns_without_annotation = tuple( map(lambda r: Return(r.name, r.type, None), f.func.returns) ) - return_str = cpp.returns_type(returns_without_annotation).cpp_type() + return_str = cpp.returns_type(returns_without_annotation, symint=True).cpp_type() if return_str not in SUPPORTED_RETURN_TYPES: raise RuntimeError(f"{f.func.name} returns unsupported type {return_str}") return return_str def cpp_dispatch_target(f: NativeFunction) -> str: - name = cpp.name(f.func) + symint = f.func.has_symint() + name = cpp.name(f.func, symint_overload=symint) if Variant.method in f.variants: return f"self.{name}" if Variant.function in f.variants: @@ -1192,7 +1207,7 @@ def cpp_dispatch_exprs( # For certain cases it is intentionally more restrictive than necessary, # e.g.: it doesn't accepts doublelist with definite size. 
def arg_parser_unpack_method( - t: Type, default: Optional[str], default_init: Optional[str] + t: Type, default: Optional[str], default_init: Optional[str], *, symint: bool = True ) -> str: has_default_init = default_init is not None if has_default_init and str(t) not in ( @@ -1224,7 +1239,10 @@ def arg_parser_unpack_method( elif t.name == BaseTy.int: return "toInt64" elif t.name == BaseTy.SymInt: - return "toSymInt" + if symint: + return "toSymInt" + else: + return "toInt64" elif t.name == BaseTy.bool: return "toBoolWithDefault" if has_default_init else "toBool" elif t.name == BaseTy.float: @@ -1245,10 +1263,14 @@ def arg_parser_unpack_method( return "toDimnameListOptional" elif not has_default_init and default in (None, "None", "c10::nullopt"): # If default is None: append 'Optional' to elem's unpacking method - return arg_parser_unpack_method(t.elem, None, None) + "Optional" + return ( + arg_parser_unpack_method(t.elem, None, None, symint=symint) + "Optional" + ) else: # Otherwise, load as underlying type with default - return arg_parser_unpack_method(t.elem, default, default_init) + return arg_parser_unpack_method( + t.elem, default, default_init, symint=symint + ) elif isinstance(t, ListType): if str(t.elem) == "Tensor": @@ -1269,7 +1291,10 @@ def arg_parser_unpack_method( return "doublelist" elif str(t.elem) == "SymInt": # accept definite size - return "symintlist" + if symint: + return "symintlist" + else: + return "intlist" elif str(t) == "Scalar[]": return "scalarlist" raise RuntimeError(f"type '{t}' is not supported by PythonArgParser") @@ -1278,11 +1303,11 @@ def arg_parser_unpack_method( # Return RHS expression for python argument using PythonArgParser output. # e.g. for arg name 'foo', arg type 'bool', arg_index = 2, returns '_r.toBool(2)' def arg_parser_output_expr( - arg_index: int, a: PythonArgument + arg_index: int, a: PythonArgument, *, symint: bool = True ) -> PythonArgParserOutputExpr: has_default = a.default_init is not None unpack_method = arg_parser_unpack_method( - t=a.type, default=a.default, default_init=a.default_init + t=a.type, default=a.default, default_init=a.default_init, symint=symint ) default = f", {a.default_init}" if has_default else "" expr = f"_r.{unpack_method}({arg_index}{default})" @@ -1297,12 +1322,12 @@ def arg_parser_output_expr( # Returns a map with key = arg_name and value = PythonArgParserOutputExpr. def arg_parser_output_exprs( - ps: PythonSignature, f: NativeFunction + ps: PythonSignature, f: NativeFunction, *, symint: bool = True ) -> Dict[str, PythonArgParserOutputExpr]: return { e.name: e for i, a in enumerate(ps.arguments()) - for e in (arg_parser_output_expr(i, a),) + for e in (arg_parser_output_expr(i, a, symint=symint),) } @@ -1317,13 +1342,13 @@ def arg_parser_output_exprs( # bind arg parser outputs (python args) with dispatch lambda arguments (c++ args). def dispatch_lambda_exprs( - ps: PythonSignature, f: NativeFunction + ps: PythonSignature, f: NativeFunction, *, symint: bool = True ) -> DispatchLambdaArgumentExprs: # This method is to bind 'arg_parser_outputs' and 'lambda_args' by producing # 'inits' and 'lambda_args_exprs' for each lambda argument using arg parser # outputs. 
- arg_parser_outputs = arg_parser_output_exprs(ps, f) - lambda_args = dispatch_lambda_args(ps, f) + arg_parser_outputs = arg_parser_output_exprs(ps, f, symint=symint) + lambda_args = dispatch_lambda_args(ps, f, symint=symint) inits: List[str] = [] lambda_args_exprs: Dict[str, str] = {} diff --git a/torchgen/api/structured.py b/torchgen/api/structured.py index 4787adccae6b30..a5ab3f6e54320d 100644 --- a/torchgen/api/structured.py +++ b/torchgen/api/structured.py @@ -42,7 +42,12 @@ # some more nominal types def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType: # If it's a value type, do the value type translation - r = cpp.valuetype_type(t, binds=binds) + # NB: structured kernels ALWAYS have symint off, since they involve actual + # kernels that require real ints. The one exception is the + # CompositeExplicitAutograd and the meta function (which could + # hypothetically be SymInt), but for simplicity we plan for these to just + # be handled in Python + r = cpp.valuetype_type(t, symint=False, binds=binds) if r is not None: return r diff --git a/torchgen/api/translate.py b/torchgen/api/translate.py index bee33b473dc945..de3e43a288964c 100644 --- a/torchgen/api/translate.py +++ b/torchgen/api/translate.py @@ -337,10 +337,16 @@ def direct_solve(goal: NamedCType) -> str: ) return f"c10::asIntArrayRefSlow({symIntArrayRef_type})" elif goal.type == BaseCType(symIntArrayRefT): - return direct_solve(NamedCType(goal.name, longSymVec_ctype)) + try: + r = direct_solve(NamedCType(goal.name, BaseCType(intArrayRefT))) + return f"c10::SymIntArrayRef::fromIntArrayRef({r})" + except UnsatError: + return direct_solve(NamedCType(goal.name, longSymVec_ctype)) + elif goal.type == BaseCType(SymIntT): + return direct_solve(NamedCType(goal.name, BaseCType(longT))) elif goal.type == BaseCType(longT): symInt_type = direct_solve(NamedCType(goal.name, BaseCType(SymIntT))) - return f"{symInt_type}.expectInt()" + return f"{symInt_type}.expect_int()" elif goal.type == BaseCType(optionalIntArrayRefT): return direct_solve(NamedCType(goal.name, optionalLongVec_ctype)) elif goal.type == BaseCType(optionalScalarRefT): diff --git a/torchgen/api/types.py b/torchgen/api/types.py index a3f5e44e461c63..9eacacf2fd9d32 100644 --- a/torchgen/api/types.py +++ b/torchgen/api/types.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from enum import Enum -from typing import Dict, Iterator, List, Optional, Sequence, Set, TypeVar, Union +from typing import Dict, Iterator, List, Optional, Sequence, Set, Tuple, TypeVar, Union from torchgen.model import ( Argument, @@ -417,6 +417,11 @@ class CppSignature: # (i.e. with a potential TensorOptions argument and out arguments in the front) faithful: bool + # Is this a symint C++ signature. 
For BC reasons, functions that take + # SymInts still present as int64_t in C++, and the SymInt variant is + # offered at a different overload name + symint: bool + # The set of C++ arguments which should not have defaults applied to them cpp_no_default_args: Set[str] @@ -433,12 +438,17 @@ def arguments(self) -> Sequence[Binding]: return cpp.arguments( self.func.arguments, faithful=self.faithful, + symint=self.symint, method=self.method, cpp_no_default_args=self.cpp_no_default_args, ) def name(self) -> str: - n = cpp.name(self.func, faithful_name_for_out_overloads=self.faithful) + n = cpp.name( + self.func, + faithful_name_for_out_overloads=self.faithful, + symint_overload=self.symint, + ) if self.fallback_binding: n = f"__dispatch_{n}" return n @@ -451,7 +461,9 @@ def decl( prefix: str = "", is_redispatching_fn: bool = False, ) -> str: - returns_type = cpp.returns_type(self.func.returns).cpp_type() + returns_type = cpp.returns_type( + self.func.returns, symint=self.symint + ).cpp_type() cpp_args = [a.decl() for a in self.arguments()] if is_redispatching_fn: cpp_args = ["c10::DispatchKeySet dispatchKeySet"] + cpp_args @@ -469,7 +481,9 @@ def defn( prefix: str = "", is_redispatching_fn: bool = False, ) -> str: - returns_type = cpp.returns_type(self.func.returns).cpp_type() + returns_type = cpp.returns_type( + self.func.returns, symint=self.symint + ).cpp_type() cpp_args = [a.defn() for a in self.arguments()] if is_redispatching_fn: cpp_args = ["c10::DispatchKeySet dispatchKeySet"] + cpp_args @@ -480,12 +494,12 @@ def defn( def ptr_type(self) -> str: args_types_str = ", ".join(a.type for a in self.arguments()) - return f"{cpp.returns_type(self.func.returns).cpp_type()} (*)({args_types_str})" + return f"{cpp.returns_type(self.func.returns, symint=self.symint).cpp_type()} (*)({args_types_str})" # Return the C++ function type, e.g., something like int(bool) def type(self) -> str: args_types_str = ", ".join(a.type for a in self.arguments()) - return f"{cpp.returns_type(self.func.returns).cpp_type()} ({args_types_str})" + return f"{cpp.returns_type(self.func.returns, symint=self.symint).cpp_type()} ({args_types_str})" # Represents group of all CppSignatures associated with a @@ -497,6 +511,8 @@ class CppSignatureGroup: func: FunctionSchema signature: CppSignature faithful_signature: Optional[CppSignature] + symint_signature: Optional[CppSignature] + symint_faithful_signature: Optional[CppSignature] def most_faithful_signature(self) -> CppSignature: if self.faithful_signature: @@ -508,6 +524,10 @@ def signatures(self) -> Iterator[CppSignature]: yield self.signature if self.faithful_signature: yield self.faithful_signature + if self.symint_signature: + yield self.symint_signature + if self.symint_faithful_signature: + yield self.symint_faithful_signature @staticmethod def from_native_function( @@ -515,23 +535,35 @@ def from_native_function( ) -> "CppSignatureGroup": func = f.func - def make_sig(*, faithful: bool) -> CppSignature: + def make_sig(*, faithful: bool, symint: bool) -> CppSignature: return CppSignature( func=func, faithful=faithful, + symint=symint, method=method, fallback_binding=fallback_binding, cpp_no_default_args=f.cpp_no_default_args, ) - faithful_signature: Optional[CppSignature] = None - if func.arguments.tensor_options is not None or len(func.arguments.out) > 0: - faithful_signature = make_sig(faithful=True) - signature = make_sig(faithful=False) + def make_sigs(*, symint: bool) -> Tuple[CppSignature, Optional[CppSignature]]: + faithful_signature: Optional[CppSignature] = None + 
if func.arguments.tensor_options is not None or len(func.arguments.out) > 0: + faithful_signature = make_sig(faithful=True, symint=symint) + signature = make_sig(faithful=False, symint=symint) + return signature, faithful_signature + + signature, faithful_signature = make_sigs(symint=False) + symint_signature: Optional[CppSignature] = None + symint_faithful_signature: Optional[CppSignature] = None + if func.has_symint(): + symint_signature, symint_faithful_signature = make_sigs(symint=True) + return CppSignatureGroup( func=func, signature=signature, faithful_signature=faithful_signature, + symint_signature=symint_signature, + symint_faithful_signature=symint_faithful_signature, ) @@ -593,6 +625,8 @@ class NativeSignature: # The schema this signature is derived from func: FunctionSchema + symint: bool + prefix: str = "" def name(self) -> str: @@ -602,24 +636,24 @@ def decl(self, name: Optional[str] = None) -> str: args_str = ", ".join(a.decl() for a in self.arguments()) if name is None: name = self.name() - return f"{native.returns_type(self.func.returns).cpp_type()} {name}({args_str})" + return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} {name}({args_str})" def defn(self, name: Optional[str] = None) -> str: args_str = ", ".join(a.defn() for a in self.arguments()) if name is None: name = self.name() - return f"{native.returns_type(self.func.returns).cpp_type()} {name}({args_str})" + return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} {name}({args_str})" def ptr_type(self) -> str: # don't include defaults in type signature! args_str = ", ".join(a.defn() for a in self.arguments()) - return f"{native.returns_type(self.func.returns).cpp_type()} (*)({args_str})" + return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} (*)({args_str})" def arguments(self) -> List[Binding]: - return native.arguments(self.func) + return native.arguments(self.func, symint=self.symint) def returns_type(self) -> CType: - return native.returns_type(self.func.returns) + return native.returns_type(self.func.returns, symint=self.symint) def dispatcher_exprs(self) -> List[Expr]: return translate.translate( @@ -745,9 +779,14 @@ def kernel_signature( # With external backends, we'd like to enforce that they write their kernels with schemas # that match the Dispatcher API directly, if they can. if backend_index.external: + # Dispatcher signature faithfully does SymInt, which is good for XLA, + # not so good for more conventional backends but we don't have any of + # those. 
If we do, that's time to add a new Signature that is a cross + # between DispatcherSignature and NativeSignature + assert backend_index.symint return DispatcherSignature.from_schema(f.func, prefix=prefix) else: - return NativeSignature(f.func, prefix) + return NativeSignature(f.func, prefix=prefix, symint=backend_index.symint) # Functions only, no types diff --git a/torchgen/api/ufunc.py b/torchgen/api/ufunc.py index 34384ce340d535..7f044706068cf9 100644 --- a/torchgen/api/ufunc.py +++ b/torchgen/api/ufunc.py @@ -40,7 +40,8 @@ def kernel_name(g: NativeFunctionsGroup, dispatch_key: DispatchKey) -> str: # # NB: used for CPU only def dispatchstub_type(t: Type, *, binds: ArgName) -> Optional[NamedCType]: - r = cpp.valuetype_type(t, binds=binds) + # Dispatch stubs are always plain ints + r = cpp.valuetype_type(t, binds=binds, symint=False) if r is not None: return r @@ -64,7 +65,7 @@ def opmath_type(scalar_t: BaseCppType) -> BaseCppType: # # NB: CUDA only def ufunctor_ctor_type(t: Type, *, binds: ArgName, scalar_t: BaseCppType) -> NamedCType: - r = cpp.valuetype_type(t, binds=binds) + r = cpp.valuetype_type(t, binds=binds, symint=False) if r is not None: return r @@ -93,7 +94,7 @@ def ufunctor_apply_type( # is done in the computation type. compute_t is opmath_t in CUDA and scalar_t # in CPU def ufunc_type(t: Type, *, binds: ArgName, compute_t: CType) -> NamedCType: - r = cpp.valuetype_type(t, binds=binds) + r = cpp.valuetype_type(t, binds=binds, symint=False) if r is not None: return r diff --git a/torchgen/api/unboxing.py b/torchgen/api/unboxing.py index b5afdc099fa9d4..a1bca50538647e 100644 --- a/torchgen/api/unboxing.py +++ b/torchgen/api/unboxing.py @@ -136,7 +136,10 @@ def convert_arguments(f: NativeFunction) -> Tuple[List[Binding], List[str]]: def argumenttype_ivalue_convert( t: Type, arg_name: str, *, mutable: bool = False ) -> Tuple[str, CType, List[str], List[str]]: - ctype = cpp.argumenttype_type(t=t, mutable=mutable, binds=arg_name).type + # Unboxing is for mobile, which doesn't care about SymInts + ctype = cpp.argumenttype_type( + t=t, mutable=mutable, binds=arg_name, symint=False + ).type if isinstance(t, BaseType): out_name = f"{arg_name}_base" diff --git a/torchgen/dest/lazy_ir.py b/torchgen/dest/lazy_ir.py index c3c86822670448..7dfb166d8e9352 100644 --- a/torchgen/dest/lazy_ir.py +++ b/torchgen/dest/lazy_ir.py @@ -27,7 +27,10 @@ from torchgen.model import ( Argument, BackendIndex, + BaseTy, + BaseType, FunctionSchema, + ListType, NativeFunction, NativeFunctionsGroup, ) @@ -40,6 +43,7 @@ def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str: a lazy Node constructor. 
""" + # TODO: Matching on CType seems wrong; should be matching on Type if isValueType(arg.lazy_type): if isinstance(arg.lazy_type, BaseCType): if arg.is_wrapped_scalar: @@ -48,7 +52,7 @@ def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str: return f"lazy_{arg.name}_tensorlist" elif arg.is_symint_or_list: cpp_type = arg.lazy_type.cpp_type() - return f"{cpp_type}(dynamic_cast({arg.name}.toSymIntNodeImpl().get())->node_, 0)" + return f"GetSymIntValue({arg.name})" return f"lazy_{arg.name}->GetIrValue()" elif isinstance(arg.lazy_type, OptionalCType): if arg.is_wrapped_scalar: @@ -63,7 +67,15 @@ def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str: f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})" ) else: - if isinstance(arg.lazy_type, VectorCType) and isinstance( + # NB: this is here because right now we aren't treating SymInt[] as a + # value type; when we do this needs to move above + # NB: we cannot test arg.lazy_type as we've already specified it is an + # int64_t and so we cannot distinguish between SymInt and int64_t + if isinstance(arg.orig_type, ListType) and arg.orig_type.elem == BaseType( + BaseTy.SymInt + ): + return f"GetSymIntArrayRefValue({arg.name})" + elif isinstance(arg.lazy_type, VectorCType) and isinstance( arg.lazy_type.elem, BaseCType ): return f"std::vector<{arg.lazy_type.elem.type}>({arg.name}.begin(), {arg.name}.end())" @@ -500,9 +512,13 @@ def this_shape(i: int) -> str: dispatch_ns = "compositeexplicitautogradnonfunctional" else: dispatch_ns = "meta" + aten_name = schema.aten_name + # TODO: this is trolling + if func.func.has_symint(): + aten_name += "_symint" shape_str = f"""\ {meta_conversion_str} - auto out_meta = at::{dispatch_ns}::{schema.aten_name}({', '.join(meta_call_args)}); + auto out_meta = at::{dispatch_ns}::{aten_name}({', '.join(meta_call_args)}); {meta_out}""" else: shape_sig = ComputeShapeSignature(metadata.kernel, func) diff --git a/torchgen/dest/register_dispatch_key.py b/torchgen/dest/register_dispatch_key.py index f7a3ef7bb64482..b9083a9db5a9a2 100644 --- a/torchgen/dest/register_dispatch_key.py +++ b/torchgen/dest/register_dispatch_key.py @@ -287,8 +287,8 @@ def wrapper_kernel_sig( self, f: NativeFunction ) -> Union[NativeSignature, DispatcherSignature]: # The prefix is just to ensure uniqueness. The Dispatcher API doesn't guarantee unique kernel names. 
- return kernel_signature( - f, self.backend_index, prefix=f"wrapper_{f.func.name.overload_name}_" + return DispatcherSignature.from_schema( + f.func, prefix=f"wrapper_{f.func.name.overload_name}_" ) def gen_out_inplace_wrapper( @@ -407,10 +407,11 @@ def gen_unstructured( f, method=False, fallback_binding=False ) + # TODO: dedupe this with the structured codegen if self.target is Target.NAMESPACED_DECLARATION: - result = f"TORCH_API {cpp_sig_group.signature.decl()};\n" - if cpp_sig_group.faithful_signature is not None: - result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n" + result = "" + for cpp_sig in cpp_sig_group.signatures(): + result += f"TORCH_API {cpp_sig.decl()};\n" return result elif self.target is Target.NAMESPACED_DEFINITION: @@ -421,10 +422,11 @@ def generate_defn(cpp_sig: CppSignature) -> str: }} """ - result = generate_defn(cpp_sig_group.signature) - if cpp_sig_group.faithful_signature is not None: - result += generate_defn(cpp_sig_group.faithful_signature) + result = "" + for cpp_sig in cpp_sig_group.signatures(): + result += generate_defn(cpp_sig) return result + elif self.target is Target.ANONYMOUS_DEFINITION: # short circuit for inplace_meta if inplace_meta: @@ -451,7 +453,14 @@ def generate_defn(cpp_sig: CppSignature) -> str: else: impl_name = f"{metadata.cpp_namespace}::{self.class_method_name}::{metadata.kernel}" - args_exprs_str = ", ".join(a.name for a in args) + kernel_sig = kernel_signature(f, self.backend_index) + + args_exprs_str = ", ".join( + e.expr + for e in translate( + sig.arguments(), kernel_sig.arguments(), method=False + ) + ) device_check = " // No device check\n" # Backends that require device guards presumably also require device checks. @@ -741,12 +750,14 @@ def gen_one(self, f: NativeFunction) -> Optional[str]: ) # Signature of the wrapper function we'll register to the dispatcher - sig = NativeSignature(f.func, prefix="wrapper_") + sig = NativeSignature( + f.func, prefix="wrapper_", symint=self.backend_index.symint + ) if self.target is Target.NAMESPACED_DECLARATION: - result = f"TORCH_API {cpp_sig_group.signature.decl()};\n" - if cpp_sig_group.faithful_signature is not None: - result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n" + result = "" + for cpp_sig in cpp_sig_group.signatures(): + result += f"TORCH_API {cpp_sig.decl()};\n" return result elif self.target is Target.NAMESPACED_DEFINITION: @@ -758,9 +769,9 @@ def generate_defn(cpp_sig: CppSignature) -> str: }} """ - result = generate_defn(cpp_sig_group.signature) - if cpp_sig_group.faithful_signature is not None: - result += generate_defn(cpp_sig_group.faithful_signature) + result = "" + for cpp_sig in cpp_sig_group.signatures(): + result += generate_defn(cpp_sig) return result elif self.target is Target.ANONYMOUS_DEFINITION: diff --git a/torchgen/gen.py b/torchgen/gen.py index c10f62932aff4f..660e54dfdf34fd 100644 --- a/torchgen/gen.py +++ b/torchgen/gen.py @@ -37,7 +37,6 @@ gen_functionalization_definition, gen_functionalization_registration, gen_functionalization_view_inverse_declaration, - gen_symint_view_copy_kernel, ) from torchgen.gen_vmap_plumbing import gen_all_vmap_plumbing @@ -160,6 +159,9 @@ def parse_native_yaml_struct( use_out_as_primary=True, external=False, device_guard=False, + # I'm actually not sure about this; undefined could be hit on + # empty TensorList, hypothetically that could have sizes in it + symint=False, index={}, ) ) @@ -174,6 +176,16 @@ def parse_native_yaml_struct( # Only cuda-like devices in tree require device guards 
device_guard=is_cuda_dispatch_key(k), index=v, + # Which dispatch keys natively support symint + # Note: DispatchKey.CompositeExplicitAutograd has to match out + # composites; I think there's some factoring problem here + symint=k + in [ + DispatchKey.Meta, + DispatchKey.CompositeImplicitAutograd, + DispatchKey.CompositeExplicitAutograd, + DispatchKey.CompositeExplicitAutogradNonFunctional, + ], ) return ParsedYaml(rs, indices) @@ -862,7 +874,8 @@ def __call__(self, f: NativeFunction) -> Optional[str]: return None name = native.name(f.func) - native_sig = NativeSignature(f.func) + # BackendSelect can go to Meta, so it must preserve symints + native_sig = NativeSignature(f.func, symint=True) native_tensor_args = [ a @@ -966,7 +979,10 @@ def dynamic_type(t: Type) -> str: # also include Tensor[] if str(t) == "Tensor": return "at::Tensor" - return cpp.argumenttype_type(t, mutable=False, binds="__placeholder__").cpp_type() + # This is a legacy concept, so never report SymInt + return cpp.argumenttype_type( + t, mutable=False, binds="__placeholder__", symint=False + ).cpp_type() def compute_method_of_yaml(variants: Set[Variant]) -> List[str]: @@ -1031,7 +1047,8 @@ def compute_returns_yaml( ret = { "dynamic_type": dynamic_type(r.type), "name": name, - "type": cpp.return_type(r).cpp_type(), + # legacy, report ints + "type": cpp.return_type(r, symint=False).cpp_type(), } if r.name: @@ -1091,7 +1108,8 @@ def compute_argument_yaml( "dynamic_type": dynamic_type(a.type), "is_nullable": a.type.is_nullable(), "name": a.name, - "type": cpp.argument_type(a, binds="__placeholder__").cpp_type(), + # legacy, report ints + "type": cpp.argument_type(a, binds="__placeholder__", symint=False).cpp_type(), } if a.default is not None: arg["default"] = pythonify_default(cpp.default_expr(a.default, a.type)) @@ -1157,11 +1175,13 @@ def compute_declaration_yaml(f: NativeFunction) -> object: method=False, cpp_no_default_args=set(), faithful=False, + symint=False, has_tensor_options=False, ) ] - cpp_returns = cpp.returns_type(f.func.returns).cpp_type() + # legacy, report ints + cpp_returns = cpp.returns_type(f.func.returns, symint=False).cpp_type() schema_order_cpp_signature = f"{cpp_returns} ({', '.join(cpp_schema_order_types)})" is_factory_method = ( @@ -2390,29 +2410,6 @@ def gen_op_headers( ) }, ) - view_copy_with_symint_pairs: List[Tuple[NativeFunction, NativeFunction]] = [] - for g1 in view_groups: - for g2 in view_groups: - if g1.view_copy is None or g2.view_copy is None: - continue - # TODO: make this more first class in the data model - g1_base_name = str(g1.view_copy.func.name.name) - g2_base_name = str(g2.view_copy.func.name.name) - - same_base_op = ( - g1_base_name == g2_base_name - and g1.view_copy.func.arguments.symints_to_ints() - == g2.view_copy.func.arguments.symints_to_ints() - ) - op1_not_symint = "SymInt" not in str(g1.view_copy.func.name.overload_name) - op2_symint = "SymInt" in str(g2.view_copy.func.name.overload_name) - if same_base_op and op1_not_symint and op2_symint: - view_copy_with_symint_pairs.append( - ( - g1.view_copy, - g2.view_copy, - ) - ) # Note [view_copy NativeFunctions] # Every view operator in native_functions.yaml that is not CompositeImplicitAutograd @@ -2453,12 +2450,6 @@ def gen_op_headers( "CompositeViewCopyKernel_Definitions": list( mapMaybe(gen_composite_view_copy_kernel, view_groups) ), - "SymIntViewCopyKernel_Definitions": list( - mapMaybe( - lambda pair: gen_symint_view_copy_kernel(pair[0], pair[1]), - view_copy_with_symint_pairs, - ) - ), 
"GeneratedCompositeFunctional_Definitions": list( mapMaybe( gen_composite_functional_kernel, diff --git a/torchgen/gen_backend_stubs.py b/torchgen/gen_backend_stubs.py index 37b4048146c002..a8108a51411839 100644 --- a/torchgen/gen_backend_stubs.py +++ b/torchgen/gen_backend_stubs.py @@ -140,6 +140,7 @@ def create_backend_index( dispatch_key=dispatch_key, use_out_as_primary=use_out_as_primary, external=True, + symint=True, # TODO: make this configurable device_guard=use_device_guard, index=metadata, ) diff --git a/torchgen/gen_functionalization_type.py b/torchgen/gen_functionalization_type.py index d1e26a0f13b6d0..107d5737c3f78a 100644 --- a/torchgen/gen_functionalization_type.py +++ b/torchgen/gen_functionalization_type.py @@ -78,21 +78,18 @@ def gen_composite_view_copy_kernel(g: NativeFunctionsViewGroup) -> Optional[str] if g.view_copy is None: return None - # For view_copy.SymInt overloads, - # See gen_symint_view_copy_kernel. - if g.view_copy.func.name.overload_name == "SymInt": - return None - # We can make view_copy work in more cases by using reshape() # when a normal view call would ordinarily fail. # This also makes LTC more efficient, because they don't need to include # clone() calls in their graph (which is normally needed by reshape). if str(g.view_copy.func.name) == "view_copy": return """\ -at::Tensor view_copy(const at::Tensor & self, at::IntArrayRef size) { - DimVector shape = infer_size_dv(size, self.numel()); +at::Tensor view_copy(const at::Tensor & self, at::SymIntArrayRef size) { + // TODO: don't cast to int array ref + auto int_size = c10::asIntArrayRefSlow(size); + DimVector shape = infer_size_dv(int_size, self.numel()); if (!at::detail::computeStride(self.sizes(), self.strides(), shape).has_value()) { - return self.reshape(size); + return self.reshape(int_size); } else { auto output = at::_ops::view::call(self, size); return output.clone(); @@ -100,7 +97,8 @@ def gen_composite_view_copy_kernel(g: NativeFunctionsViewGroup) -> Optional[str] } """ # view_copy is a native signature, since we're generating an at::native:: kernel - view_copy_sig = NativeSignature(g.view_copy.func) + # Functionalization always operates on symints though + view_copy_sig = NativeSignature(g.view_copy.func, symint=True) # view is a dispatcher signature, since we're calling into the at::_ops API view_sig = DispatcherSignature(g.view.func) @@ -138,34 +136,6 @@ def gen_composite_view_copy_kernel(g: NativeFunctionsViewGroup) -> Optional[str] """ -# For symint view copy kernels, we want to generate them to call into -# their concrete view_copy counterparts. 
-@with_native_function_and -def gen_symint_view_copy_kernel( - view_copy: NativeFunction, view_copy_symint: NativeFunction -) -> str: - # view_copy.symint is a native signature, since we're generating an at::native:: kernel - view_copy_symint_sig = NativeSignature(view_copy_symint.func) - - # view_copy is a dispatcher signature, since we're calling into the at::_ops API - view_copy_sig = DispatcherSignature(view_copy.func) - - exprs = ", ".join( - [ - e.expr - for e in translate( - view_copy_symint_sig.arguments(), view_copy_sig.arguments() - ) - ] - ) - - return f""" -{view_copy_symint_sig.defn()} {{ - return at::_ops::{view_copy.func.name.unambiguous_name()}::call({exprs}); -}} -""" - - def return_str(rets: Tuple[Return, ...], names: List[str]) -> str: assert len(rets) == len(names) if len(rets) == 0: diff --git a/torchgen/model.py b/torchgen/model.py index 93bbad44e8f2cc..ee8f48afdaa1f5 100644 --- a/torchgen/model.py +++ b/torchgen/model.py @@ -816,9 +816,6 @@ def from_yaml( backend_metadata, ) - def symints_to_ints(self) -> "NativeFunction": - return dataclasses.replace(self, func=self.func.symints_to_ints()) - def validate_unstructured(self) -> None: # TODO: probably better to accumulate these errors and report them all # at once @@ -881,8 +878,6 @@ def __post_init__(self) -> None: "foreach kernels fall back to slow path when tensor are on different devices, " "device_check not allowed to be enabled" ) - named_symint = "SymInt" in self.func.name.overload_name - assert named_symint == self.func.has_symint() # NB: if your function accidentally has rand/dropout/... in its name # but is not actually random, feel free to amend this to special case @@ -1117,6 +1112,8 @@ class BackendIndex: external: bool # Other backend-specific information that is on a per-operator basis index: Dict["OperatorName", BackendMetadata] + # Whether or not this backend handles symbolic ints or not + symint: bool @staticmethod def grow_index( @@ -1235,9 +1232,6 @@ def schema_order_arguments(self) -> Iterator["Argument"]: decl_re = re.compile(r"(?P[^\(]+)\((?P.*)\) -> (?P.*)") - def symints_to_ints(self) -> "FunctionSchema": - return dataclasses.replace(self, arguments=self.arguments.symints_to_ints()) - @staticmethod def parse(func: str) -> "FunctionSchema": # We should probably get a proper parser here @@ -1360,10 +1354,6 @@ def __post_init__(self) -> None: def is_functional_fn(self) -> bool: return "functional" in self.name.overload_name - def is_symint_fn(self) -> bool: - # TODO: make this more robust - return "SymInt" in self.name.overload_name - def is_out_fn(self) -> bool: # Note [is_out_fn] # @@ -1704,9 +1694,6 @@ def is_nullable(self) -> bool: def is_list_like(self) -> Optional["ListType"]: raise NotImplementedError - def symint_to_int(self) -> "Type": - raise NotImplementedError - # Base types are simple, atomic types with no further structure BaseTy = Enum( @@ -1747,14 +1734,12 @@ def is_base_ty_like(self, base_ty: BaseTy) -> bool: def is_nullable(self) -> bool: return False - def symint_to_int(self) -> "BaseType": - if self.name == BaseTy.SymInt: - return BaseType(BaseTy.int) - return self - def is_list_like(self) -> Optional["ListType"]: return None + def is_symint_like(self) -> bool: + return self.name == BaseTy.SymInt + # Optional types may be specified, or may also be validly given None @dataclass(frozen=True) @@ -1767,12 +1752,12 @@ def __str__(self) -> str: def is_base_ty_like(self, base_ty: BaseTy) -> bool: return self.elem.is_base_ty_like(base_ty) + def is_symint_like(self) -> bool: + return 
self.elem.is_symint_like() + def is_nullable(self) -> bool: return True - def symint_to_int(self) -> "Type": - return dataclasses.replace(self, elem=self.elem.symint_to_int()) - def is_list_like(self) -> Optional["ListType"]: return self.elem.is_list_like() @@ -1791,15 +1776,15 @@ def __str__(self) -> str: def is_base_ty_like(self, base_ty: BaseTy) -> bool: return False + def is_symint_like(self) -> bool: + return False + def is_nullable(self) -> bool: """ Assume a custom class is not nullable. """ return False - def symint_to_int(self) -> "Type": - return self - def is_list_like(self) -> Optional["ListType"]: return None @@ -1823,12 +1808,12 @@ def __str__(self) -> str: def is_base_ty_like(self, base_ty: BaseTy) -> bool: return self.elem.is_base_ty_like(base_ty) + def is_symint_like(self) -> bool: + return self.elem.is_symint_like() + def is_nullable(self) -> bool: return self.elem.is_nullable() - def symint_to_int(self) -> "ListType": - return ListType(self.elem.symint_to_int(), self.size) - def is_list_like(self) -> Optional["ListType"]: return self @@ -1902,9 +1887,6 @@ def parse(arg: str) -> "Argument": def is_write(self) -> bool: return self.annotation is not None and self.annotation.is_write - def symint_to_int(self) -> "Argument": - return dataclasses.replace(self, type=self.type.symint_to_int()) - def __str__(self) -> str: type = f"{self.type}" if self.annotation: @@ -2094,37 +2076,6 @@ def mutable_arg_names(self) -> List[str]: if a.annotation is not None and a.annotation.is_write ] - def symints_to_ints(self) -> "Arguments": - arguments = self - - if arguments.self_arg: - arguments = dataclasses.replace( - arguments, - pre_self_positional=tuple( - x.symint_to_int() for x in arguments.pre_self_positional - ), - ) - - if self.tensor_options: - arguments = dataclasses.replace( - arguments, - post_tensor_options_kwarg_only=tuple( - x.symint_to_int() for x in arguments.post_tensor_options_kwarg_only - ), - ) - - arguments = dataclasses.replace( - arguments, - post_self_positional=tuple( - x.symint_to_int() for x in arguments.post_self_positional - ), - pre_tensor_options_kwarg_only=tuple( - x.symint_to_int() for x in arguments.pre_tensor_options_kwarg_only - ), - ) - - return arguments - def has_tensor_arg(self) -> bool: return any(a.type.is_tensor_like() for a in self.flat_non_out) diff --git a/torchgen/static_runtime/generator.py b/torchgen/static_runtime/generator.py index 22bf259f640bbd..1939a2dd9a3e1e 100644 --- a/torchgen/static_runtime/generator.py +++ b/torchgen/static_runtime/generator.py @@ -98,7 +98,9 @@ def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bo return False if isinstance(g, NativeFunctionsViewGroup): - if "at::Tensor" != cpp.returns_type(func.returns).cpp_type(): + # TODO: stop doing type tests by converting to C++ and then testing + # the string, just test the dang thing directly + if "at::Tensor" != cpp.returns_type(func.returns, symint=False).cpp_type(): # Returns a non-Tensor value. logger.info(f"NON-TENSOR RET TYPE: {str(func)}") return False @@ -122,7 +124,8 @@ def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bo or not str(func.name).endswith(".out") ): return False - if "at::Tensor &" != cpp.returns_type(func.returns).cpp_type(): + # TODO: stop type testing by converting to C++ + if "at::Tensor &" != cpp.returns_type(func.returns, symint=False).cpp_type(): logger.info(f"NON_TENSOR RET TYPE: {str(func)}") return False if has_alias(func.arguments.non_out):
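To make the effect of the new symint flag concrete, here is a small usage sketch against the argument_type_str change in torchgen/api/python.py shown earlier in this diff. The imports mirror the modules touched above and the expected strings follow directly from that hunk; it is an illustrative sketch, not a test taken from the patch.

    # Sketch: the same SymInt[] schema argument renders differently depending
    # on whether the caller asks for the SymInt-aware signature.
    from torchgen.model import BaseTy, BaseType, ListType
    from torchgen.api.python import argument_type_str

    sym_list = ListType(elem=BaseType(BaseTy.SymInt), size=None)
    print(argument_type_str(sym_list, symint=True))   # SymIntArrayRef
    print(argument_type_str(sym_list, symint=False))  # IntArrayRef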