
Commit ad44670

Back out "Revert D38984222: Don't introduce new overload for SymInt (pytorch#83628)" (pytorch#84173)

Also Back out "Revert D39075159: [acc_tensor] Use SymIntArrayRef for overloaded empty.memory_format's signature"

Original commit changeset: dab4a9dba4fa
Original commit changeset: dcaf16c037a9

Original Phabricator Diff: D38984222
Original Phabricator Diff: D39075159

Also update Metal registrations for C++ registration changes.

Also update NNPI registration to account for tightened schema checking.

Differential Revision: [D39084762](https://our.internmc.facebook.com/intern/diff/D39084762/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D39084762/)!
Pull Request resolved: pytorch#84173
Approved by: https://github.com/Krovatkin
ezyang authored and pytorchmergebot committed Aug 29, 2022
1 parent cfd18e1 commit ad44670
Showing 89 changed files with 864 additions and 749 deletions.
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/xla.txt
@@ -1 +1 @@
-9b2f7929c2dae841888a836449c25b04c8cf4045
+95eedc33fb48c2ba72f5efa45daa4941cb069864
16 changes: 8 additions & 8 deletions aten/src/ATen/BatchingRegistrations.cpp
@@ -186,7 +186,8 @@ Tensor expand_batching_rule(const Tensor& self, IntArrayRef size, bool implicit)
 }
 
 Tensor expand_symint_batching_rule(const Tensor& self, SymIntArrayRef psize, bool implicit) {
-  return self.expand(asIntArrayRefSlow(psize), implicit);
+  // TODO: properly support this
+  return expand_batching_rule(self, asIntArrayRefSlow(psize), implicit);
 }
 
 std::vector<Tensor> chunk_batching_rule(const Tensor& self, int64_t chunks, int64_t dim) {
@@ -469,7 +470,8 @@ Tensor view_batching_rule(const Tensor& self, IntArrayRef size) {
 }
 
 Tensor view_symint_batching_rule(const Tensor& self, c10::SymIntArrayRef size) {
-  return self.view(asIntArrayRefSlow(size));
+  // TODO: properly support this
+  return view_batching_rule(self, asIntArrayRefSlow(size));
 }
 
 Tensor view_as_complex_batching_rule(const Tensor& self) {
@@ -1009,6 +1011,7 @@ Tensor new_empty_symint_batching_rule(
     c10::optional<Layout> layout,
     c10::optional<Device> device,
     c10::optional<bool> pin_memory) {
+  // TODO: properly support this
   return new_empty_batching_rule(self, asIntArrayRefSlow(size), dtype, layout, device, pin_memory);
 }
 
@@ -1109,8 +1112,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
   m.impl("tensor_split.sections", tensor_split_sections_batching_rule);
   m.impl("tensor_split.indices", tensor_split_indices_batching_rule);
   m.impl("diagonal", diagonal_batching_rule);
-  m.impl("expand", expand_batching_rule);
-  m.impl("expand.SymInt", expand_symint_batching_rule);
+  m.impl("expand", expand_symint_batching_rule);
   m.impl("expand_as", native::expand_as); // composite wrt autograd
   m.impl("movedim.intlist", movedim_batching_rule);
   m.impl("movedim.int", static_cast<Tensor(*)(const Tensor&,int64_t,int64_t)>(native::movedim)); // composite wrt autograd
@@ -1138,8 +1140,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
   m.impl("unbind.int", unbind_batching_rule);
   m.impl("unfold", unfold_batching_rule);
   m.impl("unsqueeze", unsqueeze_batching_rule);
-  m.impl("view", view_batching_rule);
-  m.impl("view.SymInt", view_symint_batching_rule);
+  m.impl("view", view_symint_batching_rule);
   m.impl("view_as", native::view_as); // composite wrt autograd
 
   // clamp operations
@@ -1277,8 +1278,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
   m.impl("diagonal_backward", diagonal_backward_batching_rule);
 
   // Tensor.new_* operators
-  m.impl("new_empty", new_empty_batching_rule);
-  m.impl("new_empty.SymInt", new_empty_symint_batching_rule);
+  m.impl("new_empty", new_empty_symint_batching_rule);
   m.impl("new_empty_strided", new_empty_strided_batching_rule);
   m.impl("new_zeros", new_zeros_batching_rule);
 
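Every symint batching rule above funnels through asIntArrayRefSlow before delegating to the plain integer rule. As a rough sketch of what such a conversion has to do (an assumption-level illustration, not the actual c10 implementation), it materializes a concrete integer out of each SymInt, which only works while none of the sizes are actually symbolic:

#include <c10/core/SymInt.h>
#include <c10/core/SymIntArrayRef.h>
#include <vector>

// Sketch only: the exact SymInt accessor is assumed here; expect_int()
// is taken to assert that the value is not symbolic and return it.
std::vector<int64_t> as_concrete_ints(c10::SymIntArrayRef sizes) {
  std::vector<int64_t> out;
  out.reserve(sizes.size());
  for (const c10::SymInt& s : sizes) {
    out.push_back(s.expect_int());
  }
  return out;
}
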
19 changes: 4 additions & 15 deletions aten/src/ATen/FunctionalInverses.cpp
@@ -137,12 +137,8 @@ Tensor FunctionalInverses::diagonal_copy_inverse(const Tensor& base, const Tenso
   return base.diagonal_scatter(mutated_view, offset, dim1, dim2);
 }
 
-Tensor FunctionalInverses::expand_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::IntArrayRef size, bool implicit) {
-  return at::sum_to(mutated_view, base.sizes(),/*always_return_non_view=*/!reapply_views);
-}
-
-Tensor FunctionalInverses::expand_copy_SymInt_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, c10::SymIntArrayRef size, bool implicit) {
-  return at::sum_to(mutated_view, c10::asIntArrayRefSlow(base.sym_sizes()),/*always_return_non_view=*/!reapply_views);
+Tensor FunctionalInverses::expand_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::SymIntArrayRef size, bool implicit) {
+  return at::sum_to(mutated_view, base.sym_sizes(),/*always_return_non_view=*/!reapply_views);
 }
 
 Tensor FunctionalInverses::permute_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::IntArrayRef dims) {
@@ -291,22 +287,15 @@ Tensor FunctionalInverses::unbind_copy_int_inverse(const Tensor& base, const Ten
   return base.select_scatter(mutated_view, dim, mutated_view_idx);
 }
 
-Tensor FunctionalInverses::view_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::IntArrayRef size) {
-  if (reapply_views) {
-    return mutated_view.view(base.sizes());
-  } else {
-    return at::view_copy(mutated_view, base.sizes());
-  }
-}
-
-Tensor FunctionalInverses::view_copy_SymInt_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, c10::SymIntArrayRef size) {
+Tensor FunctionalInverses::view_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::SymIntArrayRef size) {
   if (reapply_views) {
     return mutated_view.view_symint(base.sym_sizes());
   } else {
     return at::view_copy_symint(mutated_view, base.sym_sizes());
   }
 }
 
-
 Tensor FunctionalInverses::view_copy_dtype_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::ScalarType dtype) {
   if (reapply_views) {
     return mutated_view.view(base.scalar_type());
1 change: 0 additions & 1 deletion aten/src/ATen/core/NamedRegistrations.cpp
@@ -179,7 +179,6 @@ TORCH_LIBRARY_IMPL(aten, Named, m) {
   m.impl("exp.out", CppFunction::makeFallthrough());
   m.impl("exp_", CppFunction::makeFallthrough());
   m.impl("expand", CppFunction::makeFallthrough());
-  m.impl("expand.SymInt", CppFunction::makeFallthrough());
   m.impl("expm1", CppFunction::makeFallthrough());
   m.impl("expm1.out", CppFunction::makeFallthrough());
   m.impl("expm1_", CppFunction::makeFallthrough());
aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h
@@ -353,7 +353,14 @@ namespace impl {
 template<bool AllowDeprecatedTypes>
 struct ivalue_to_arg<c10::SymIntArrayRef, AllowDeprecatedTypes> final {
   static std::vector<c10::SymInt> call(IValue& v) {
-    return ivalue_to_arg<std::vector<c10::SymInt>, AllowDeprecatedTypes>::call(v);
+    if (v.isIntList()) {
+      std::vector<c10::SymInt> r;
+      auto src = v.toIntList();
+      std::transform(src.begin(), src.end(), std::back_inserter(r), [](int64_t i) { return c10::SymInt(i); });
+      return r;
+    } else {
+      return ivalue_to_arg<std::vector<c10::SymInt>, AllowDeprecatedTypes>::call(v);
+    }
   }
 };
 template<class T, bool AllowDeprecatedTypes>
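The new isIntList() branch lets a plain int list arriving through the boxed calling convention feed a SymIntArrayRef parameter. A hedged illustration of the conversion it performs, outside the dispatcher:

#include <ATen/core/ivalue.h>
#include <c10/core/SymInt.h>
#include <vector>

void int_list_to_symints() {
  c10::IValue v(std::vector<int64_t>{2, 3, 4});  // an ordinary int list
  std::vector<c10::SymInt> syms;
  for (int64_t i : v.toIntVector()) {
    syms.emplace_back(i);  // wrap each concrete int in a SymInt
  }
  // syms can now back the c10::SymIntArrayRef the kernel expects.
}
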
1 change: 1 addition & 0 deletions aten/src/ATen/core/custom_class.cpp
@@ -143,6 +143,7 @@ c10::FunctionSchema class_base::withNewArguments(
     new_args.emplace_back(
         default_arg.name_,
         old_arg.type(),
+        old_arg.real_type(),
         old_arg.N(),
         default_arg.value_);
   }
6 changes: 5 additions & 1 deletion aten/src/ATen/core/dispatch/OperatorEntry.cpp
@@ -35,7 +35,11 @@ OperatorEntry::OperatorEntry(OperatorName&& operator_name)
 
 namespace {
   void checkSchema(const OperatorName& name, const FunctionSchema& from_def, const std::string& from_def_debug, const FunctionSchema& inferred, const std::string& inferred_debug) {
-    c10::optional<std::string> schema_difference = findSchemaDifferences(from_def, inferred);
+    // TODO: figure out if we can just directly save real schema at def time
+    c10::optional<std::string> schema_difference = findSchemaDifferences(
+        from_def.cloneWithRealTypes(),
+        inferred.cloneWithRealTypes()
+    );
     if (schema_difference.has_value()) {
       TORCH_CHECK(false,
         "Inferred operator schema for a C++ kernel function doesn't match the expected function schema.\n"
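Since the comparison now runs on real types, a kernel whose C++ signature disagrees with the declared SymInt type can be rejected at registration time; this is the "tightened schema checking" the commit message mentions for NNPI. A hypothetical sketch of the kind of mismatch that would now be flagged (myns::expand_like is made up for illustration):

#include <torch/library.h>

TORCH_LIBRARY(myns, m) {
  m.def("expand_like(Tensor self, SymInt[] size) -> Tensor");
}

TORCH_LIBRARY_IMPL(myns, CPU, m) {
  // The schema inferred from this lambda says int[] (IntArrayRef), while
  // the declared real type is SymInt[]; comparing with cloneWithRealTypes()
  // surfaces that difference instead of silently accepting it.
  m.impl("expand_like", [](const at::Tensor& self, at::IntArrayRef size) {
    return self.expand(size);
  });
}
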
4 changes: 0 additions & 4 deletions aten/src/ATen/core/dynamic_type.cpp
@@ -231,8 +231,6 @@ TypePtr DynamicType::fallback() const {
       return BoolType::get();
     case Tag::Int:
       return IntType::get();
-    case Tag::SymInt:
-      return SymIntType::get();
     case Tag::Float:
       return FloatType::get();
     case Tag::Complex:
@@ -326,8 +324,6 @@ DynamicType::Ptr IValue::TagType<c10::DynamicType>::get(const c10::IValue& v) {
       return DynamicTypeTrait<ComplexType>::getBaseType();
     case Tag::Int:
       return DynamicTypeTrait<IntType>::getBaseType();
-    case Tag::SymInt:
-      return DynamicTypeTrait<SymIntType>::getBaseType();
     case Tag::Bool:
       return DynamicTypeTrait<BoolType>::getBaseType();
     case Tag::String:
3 changes: 1 addition & 2 deletions aten/src/ATen/core/dynamic_type.h
@@ -16,7 +16,6 @@ constexpr DynamicTypeBits kDynamicAnyTypeBit = DYNAMIC_TYPE_BIT(30);
 
 constexpr DynamicTypeBits kDynamicNoneTypeBit = DYNAMIC_TYPE_BIT(1);
 constexpr DynamicTypeBits kDynamicIntTypeBit = DYNAMIC_TYPE_BIT(3);
-constexpr DynamicTypeBits kDynamicSymIntTypeBit = DYNAMIC_TYPE_BIT(23);
 constexpr DynamicTypeBits kDynamicFloatTypeBit = DYNAMIC_TYPE_BIT(4);
 constexpr DynamicTypeBits kDynamicComplexTypeBit = DYNAMIC_TYPE_BIT(5);
 constexpr DynamicTypeBits kDynamicListTypeBit = DYNAMIC_TYPE_BIT(7);
@@ -29,7 +28,6 @@ constexpr DynamicTypeBits kDynamicClassTypeBit = DYNAMIC_TYPE_BIT(10);
   _(Bool, DYNAMIC_TYPE_BIT(2), 1) \
   _(Int, kDynamicIntTypeBit, 1) \
   _(Float, kDynamicFloatTypeBit, 1) \
-  _(SymInt, kDynamicSymIntTypeBit, 1) \
   _(Complex, kDynamicComplexTypeBit, 1) \
   _(Number, \
     (kDynamicIntTypeBit | kDynamicFloatTypeBit | kDynamicComplexTypeBit), \
@@ -63,6 +61,7 @@ constexpr DynamicTypeBits kDynamicClassTypeBit = DYNAMIC_TYPE_BIT(10);
 #define FORALL_DYNAMIC_TYPES_FAKE(_) \
   _(ScalarType, kDynamicIntTypeBit, 1) \
   _(Layout, kDynamicIntTypeBit, 1) \
+  _(SymInt, kDynamicIntTypeBit, 1) \
   _(MemoryFormat, kDynamicIntTypeBit, 1)
 
 #define FORWARD_DECL_TYPE(NAME, _, __) struct NAME ## Type;
16 changes: 16 additions & 0 deletions aten/src/ATen/core/function_schema.cpp
@@ -17,6 +17,22 @@ const std::vector<Argument>& FunctionSchema::getCorrectList(SchemaArgType type)
   }
 }
 
+FunctionSchema FunctionSchema::cloneWithRealTypes() const {
+  auto cloneWithRealTypes = [](const Argument& a) {
+    return a.cloneWithType(a.real_type());
+  };
+  std::vector<Argument> new_arguments, new_returns;
+  std::transform(arguments().begin(), arguments().end(), std::back_inserter(new_arguments), cloneWithRealTypes);
+  std::transform(returns().begin(), returns().end(), std::back_inserter(new_returns), cloneWithRealTypes);
+  return FunctionSchema(
+      name(),
+      overload_name(),
+      std::move(new_arguments),
+      std::move(new_returns),
+      is_vararg(),
+      is_varret());
+}
+
 bool FunctionSchema::canAliasTypeSetsAlias(const c10::optional<AliasTypeSet> &lhs, const c10::optional<AliasTypeSet> &rhs) const {
   if (!lhs || !rhs) {
     return false;
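A hedged usage sketch of the new helper, assuming torch::jit::parseSchema is available to build the schema and that, for a SymInt declaration, type() carries the int[] "fake" type while real_type() carries SymInt[]:

#include <torch/csrc/jit/frontend/function_schema_parser.h>

void clone_with_real_types_demo() {
  c10::FunctionSchema s = torch::jit::parseSchema(
      "aten::view(Tensor(a) self, SymInt[] size) -> Tensor(a)");
  // Swap every argument's type() for its real_type(), so comparisons such
  // as findSchemaDifferences() see SymInt[] rather than int[].
  c10::FunctionSchema real = s.cloneWithRealTypes();
}
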
6 changes: 5 additions & 1 deletion aten/src/ATen/core/function_schema.h
@@ -44,7 +44,7 @@ struct Argument {
       c10::optional<AliasInfo> alias_info = c10::nullopt)
       : name_(std::move(name)),
         type_(fake_type ? std::move(fake_type) : TensorType::get()),
-        real_type_(real_type ? std::move(real_type) : TensorType::get()),
+        real_type_(real_type ? std::move(real_type) : type_),
         N_(std::move(N)),
         default_value_(std::move(default_value)),
         alias_info_(alias_info ? std::make_unique<AliasInfo>(std::move(*alias_info)) : nullptr),
@@ -88,6 +88,8 @@ struct Argument {
   const TypePtr& type() const {
     return type_;
   }
+  // if type() is non-null, this is guaranteed to be non-null (if no real
+  // type was provided, this takes on type()'s value)
   const TypePtr& real_type() const {
     return real_type_;
   }
@@ -472,6 +474,8 @@ struct TORCH_API FunctionSchema {
   FunctionSchema cloneWithRemappedTypes(
       const std::function<TypePtr(TypePtr)> type_map) const;
 
+  FunctionSchema cloneWithRealTypes() const;
+
   // Check that inputs have the correct types and appends any missing default
   // values.
   template <typename T = c10::PlatformType>
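A small sketch of the defaulting behavior documented above: when no real type is supplied, real_type() simply takes on type()'s value:

#include <ATen/core/function_schema.h>

void real_type_fallback_demo() {
  c10::Argument arg("dim", c10::IntType::get());  // no real type supplied
  // real_type_ was initialized from type_, so the two pointers agree.
  TORCH_INTERNAL_ASSERT(arg.real_type() == arg.type());
}
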
(Diff truncated: the remaining changed files are not shown.)
