forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpython_function.h
131 lines (109 loc) · 4.32 KB
/
python_function.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#pragma once
#include <torch/csrc/python_headers.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/custom_function.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/saved_variable.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/utils/object_ptr.h>
#include <c10/core/DeviceGuard.h>
#include <c10/util/Optional.h>
#include <memory>
#include <utility>
#include <vector>
// Forward declaration of torch::jit::Graph, which lets the type be named
// without pulling in its full definition. Nothing in the visible part of this
// header refers to it.
namespace torch {
namespace jit {
struct Graph;
}
} // namespace torch
namespace torch {
namespace autograd {
// Autograd Node backed by a Python object (i.e., a THPFunction). The
// overrides below are only declared here; 'apply' forwards to the Python
// method implementation.
struct PyNode : public Node {
  // Takes over the reference held by `obj`; from here on the node owns it.
  PyNode(THPObjectPtr obj) : obj(obj.release()) {}

  variable_list apply(variable_list&& inputs) override;
  void release_variables() override;
  std::string name() const override;
  bool is_traceable() override;

  // The wrapped THPFunction. This is an OWNING raw pointer: a THPObjectPtr
  // member cannot be used because its destructor would not acquire the GIL
  // before dropping the reference.
  PyObject* obj;

  ~PyNode() override {
    // The reference is released by hand so that the GIL is held while doing
    // so (TestAutograd.test_inplace_view_python catches it when this is
    // forgotten).
    if (!Py_IsInitialized()) {
      // Python is already dead: leak the wrapped object on purpose.
      return;
    }
    pybind11::gil_scoped_acquire gil;
    Py_DECREF(obj);
  }
};
/**
 * Make sure `obj` holds a tuple. An object that is already a tuple is left
 * untouched; anything else is moved into a freshly created 1-tuple, which
 * then replaces it in `obj`.
 *
 * @param obj owning pointer to inspect and, if needed, rewrap in place.
 * @return true if the original object was not a tuple (i.e. it was wrapped),
 *         false if it already was one.
 * @throws python_error if the 1-tuple cannot be allocated.
 */
inline bool ensure_tuple(THPObjectPtr& obj) {
  const bool needs_wrap = !PyTuple_Check(obj.get());
  if (needs_wrap) {
    PyObject* wrapped = PyTuple_New(1);
    if (wrapped == nullptr) {
      throw python_error();
    }
    // PyTuple_SET_ITEM steals the reference given up by obj.release().
    PyTuple_SET_ITEM(wrapped, 0, obj.release());
    obj = wrapped;
  }
  return needs_wrap;
}
} // namespace autograd
} // namespace torch
// Python object layout (it starts with PyObject_HEAD) for a Python-implemented
// autograd function. It carries the bookkeeping that Python sets through
// 'save_for_backward', 'mark_non_differentiable', 'mark_dirty' and
// 'set_materialize_grads', plus a weak link to the PyNode that wraps it.
// Field order is part of the object layout; do not reorder.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct THPFunction {
PyObject_HEAD
// NOTE(review): not documented here; presumably a Python tuple of bools, one
// per forward() input, telling whether that input needs a gradient - confirm
// against the .cpp.
PyObject* needs_input_grad;
// Python tuple of tensors whose variables we should save. Set
// by Python with 'save_for_backward'. If nullptr, no tensors were
// saved.
PyObject* to_save;
// Python tuple of tensors which are not differentiable. Set by
// Python with 'mark_non_differentiable'. If nullptr, no tensors were
// non-differentiable.
PyObject* non_differentiable;
// Python tuple of tensors which had inplace updates in the forward()
// pass. Set by Python with 'mark_dirty'. If nullptr, no tensors were
// modified inplace.
PyObject* dirty_tensors;
// boolean indicating whether to materialize undefined output grad tensors
// into tensors full of zeros. Set by Python with 'set_materialize_grads'.
// Default is true.
bool materialize_grads;
// boolean indicating whether to materialize output grad tensors
// corresponding to non-differentiable outputs. Normally, someone would
// already get this behavior by switching off materialize_grads,
// but there are certain use cases where that is not feasible:
// https://github.com/pytorch/pytorch/pull/98659#pullrequestreview-1376822560
bool materialize_non_diff_grads;
// NOTE(review): presumably per-output and per-input metadata recorded during
// forward() - confirm where these are filled in.
std::vector<torch::autograd::VariableInfo> output_info;
std::vector<torch::autograd::VariableInfo> input_info;
// NOTE(review): presumably the SavedVariable wrappers built from 'to_save' -
// confirm.
std::vector<torch::autograd::SavedVariable> saved_variables;
// For each input, true if the input is a THPVariable
std::vector<bool> is_variable_input;
// NOTE(review): presumably a flag (stored as char) that is set once the saved
// buffers have been released - confirm.
char has_freed_buffers;
// NOTE(review): presumably the Python-side counterpart of 'to_save' for
// forward-mode use (name suggests 'save_for_forward') - confirm.
PyObject* saved_for_forward;
// The actual PyNode (in the autograd graph) that this data was
// saved for. This field may be NULL (because a user can construct
// a THPFunction directly from Python), but when this field is non-NULL,
// it is guaranteed that cdata.lock()->obj == this
//
// In most ordinary use, this field should always be non-NULL; e.g.,
// when we allocate a THPFunction because we are running Node.apply,
// after constructing a THPFunction, we immediately allocate a PyNode
// for it. We can't enforce this directly in the constructor of
// THPFunction though, because there's no way to keep it live long enough
// to save an owning reference to PyNode into the grad_fn of a Variable.
//
// Held as a weak_ptr, so this struct does not keep the PyNode alive.
std::weak_ptr<torch::autograd::PyNode> cdata;
};
// Module initialisation hook for the THPFunction binding; defined outside
// this header. NOTE(review): presumably registers THPFunctionType on `module`
// and returns false on failure - confirm against the definition.
bool THPFunction_initModule(PyObject* module);
// Type object that THPFunction_Check (below) tests instances against.
// Defined outside this header.
extern PyTypeObject THPFunctionType;
// NOTE(review): presumably the Python-level class object associated with
// THPFunction - confirm where it is assigned. Defined outside this header.
extern PyObject* THPFunctionClass;
/**
 * Returns true when `obj` is an instance of THPFunctionType (or of a
 * subclass), false when it is not.
 *
 * PyObject_IsInstance is tri-state: 1 (instance), 0 (not an instance) and -1
 * (the check itself failed and a Python exception is set). Converting that
 * int straight to bool would report the error case as "is an instance", so
 * the failure is surfaced as python_error instead, matching how ensure_tuple
 * in this header reports C-API failures.
 *
 * @param obj object to test; must be a valid (non-null) Python object.
 * @throws python_error if the isinstance check raised a Python exception.
 */
inline bool THPFunction_Check(PyObject* obj) {
  const int is_instance =
      PyObject_IsInstance(obj, reinterpret_cast<PyObject*>(&THPFunctionType));
  if (is_instance < 0) {
    throw python_error();
  }
  return is_instance != 0;
}