forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathFunctionalizeFallbackKernel.cpp
67 lines (63 loc) · 2.82 KB
/
FunctionalizeFallbackKernel.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <torch/library.h>
#include <c10/util/irange.h>
namespace {
// Boxed fallback kernel registered for the Functionalize dispatch key.
//
// It only handles operators whose schema carries no alias annotations (this is
// asserted up front; per the assert message, mutating and aliasing ops are
// expected to have code-generated kernels instead). The kernel:
//   1. Walks the arguments on the stack. Every Tensor / TensorList argument
//      that is a functional tensor is sync()'d and its stack slot is replaced
//      with the result of from_functional_tensor().
//   2. Re-dispatches the operator while an AutoDispatchSkipFunctionalize guard
//      is in scope.
//   3. Replaces Tensor / TensorList returns with to_functional_tensor() of
//      themselves when at least one input was functional, or when the operator
//      took no tensor inputs at all (factory function).
//
// Parameters:
//   op             - handle of the operator being called; its schema provides
//                    the argument and return counts.
//   dispatchKeySet - not used by this kernel.
//   stack          - boxed call stack. On entry the operator's arguments are
//                    its last `num_arguments` entries; after callBoxed() the
//                    operator's returns are its last `num_returns` entries.
void functionalizeFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatchKeySet, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  TORCH_INTERNAL_ASSERT(!schema.hasAnyAliasInfo(), "mutating and aliasing ops should all have codegen'd kernels");

  const auto num_arguments = schema.arguments().size();
  // Guard the unsigned subtraction below against wrapping around.
  TORCH_INTERNAL_ASSERT(stack->size() >= num_arguments, "stack holds fewer values than the schema has arguments");
  const auto arguments_begin = stack->size() - num_arguments;
  auto arguments = torch::jit::last(stack, num_arguments);

  auto any_functional_inputs = false;
  auto any_tensor_inputs = false;
  for (const auto idx : c10::irange(num_arguments)) {
    const auto& ivalue = arguments[idx];
    if (ivalue.isTensor()) {
      any_tensor_inputs = true;
      // Bind by reference: the tensor is only read before its slot is overwritten.
      const auto& t = ivalue.toTensor();
      // Undefined tensors are never functional wrappers; skip them explicitly.
      if (t.defined() && at::functionalization::impl::isFunctionalTensor(t)) {
        any_functional_inputs = true;
        at::functionalization::impl::sync(t);
        auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(t));
        (*stack)[arguments_begin + idx] = t_new;
      }
    } else if (ivalue.isTensorList()) {
      any_tensor_inputs = true;
      auto tensors = ivalue.toTensorList();
      if (at::functionalization::impl::isFunctionalTensor(tensors)) {
        any_functional_inputs = true;
        at::functionalization::impl::sync(tensors);
        auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(tensors));
        (*stack)[arguments_begin + idx] = t_new;
      }
    }
  }

  // we should wrap the output if any inputs were wrapped,
  // OR if we're hitting a factory function (with no tensor inputs)
  const auto should_wrap_outputs = !any_tensor_inputs || any_functional_inputs;

  {
    // The guard is scoped to the redispatch only.
    at::AutoDispatchSkipFunctionalize guard;
    op.callBoxed(stack);
  }

  // Nothing to rewrite on the stack when outputs stay as they are.
  if (!should_wrap_outputs) {
    return;
  }

  const auto num_returns = schema.returns().size();
  // Guard the unsigned subtraction below against wrapping around.
  TORCH_INTERNAL_ASSERT(stack->size() >= num_returns, "stack holds fewer values than the schema has returns");
  const auto returns_begin = stack->size() - num_returns;
  auto returns = torch::jit::last(stack, num_returns);
  for (const auto idx : c10::irange(num_returns)) {
    const auto& ivalue = returns[idx];
    if (ivalue.isTensor()) {
      const auto& t = ivalue.toTensor();
      // An undefined tensor has nothing to wrap; leave it on the stack as-is.
      if (!t.defined()) {
        continue;
      }
      auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(t));
      (*stack)[returns_begin + idx] = t_new;
    } else if (ivalue.isTensorList()) {
      auto tensors = ivalue.toTensorList();
      auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(tensors));
      (*stack)[returns_begin + idx] = t_new;
    }
  }
}
}
// Registers functionalizeFallback as the boxed fallback kernel for the
// Functionalize dispatch key. The `_` namespace argument is presumably the
// dispatcher's "all namespaces" wildcard used for fallbacks - confirm against
// torch/library.h.
TORCH_LIBRARY_IMPL(_, Functionalize, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&functionalizeFallback>());
}