
Commit

Add attach_grad to hybridized block
format

deduplicate MarkDCVariables

rebase error in block.py

format manual

format

error fix

Add attach_grad gluon
KexinFeng committed Jul 8, 2022
1 parent e36c9f0 commit 4170b23
Showing 15 changed files with 228 additions and 24 deletions.
7 changes: 7 additions & 0 deletions include/mxnet/c_api.h
@@ -1276,6 +1276,13 @@ MXNET_DLL int MXAutogradMarkVariables(uint32_t num_var,
NDArrayHandle* var_handles,
uint32_t* reqs_array,
NDArrayHandle* grad_handles);
/*!
* \brief Mark nonleaf NDArrays as variables during deferred computation.
* \param nleaf_handles handles of the nonleaf NDArrays to be marked
* \param num_nleafs number of nonleaf NDArrays
* \param cnt_var count of existing marked nonleaf variables
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayMarkDCVariables(NDArrayHandle* nleaf_handles, int num_nleafs, int cnt_var);
/*!
* \brief unmark nonleaf NDArrays to free the memory
* \param num_var number of variable NDArrays
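For reference, the new C entry point is reachable from Python through ctypes in the same way the block.py change below calls it. A minimal sketch under that assumption (the helper name mark_dc_variables is made up for illustration and is not part of this commit):

import ctypes
from mxnet.base import _LIB, check_call

def mark_dc_variables(arrays, cnt_var):
    # arrays: nonleaf NDArrays recorded under deferred compute
    # cnt_var: number of nonleaf variables that are already marked
    handles = (ctypes.c_void_p * len(arrays))(*[arr.handle for arr in arrays])
    check_call(_LIB.MXNDArrayMarkDCVariables(handles, len(arrays), cnt_var))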
2 changes: 2 additions & 0 deletions include/mxnet/imperative.h
@@ -290,6 +290,8 @@ class Imperative {
void MarkVariables(const std::vector<NDArray*>& variables,
const std::vector<uint32_t>& grad_reqs,
const std::vector<NDArray*>& gradients);
/*! \brief mark nonleaf variables during DC for computing gradients. */
void MarkDCVariables(const std::vector<NDArray*>& nleafs, int cnt_vars);
/*! \brief unmark nonleaf variables to free the memory. */
void DropGrads(const std::vector<NDArray*>& variables);
/*! \brief compute the gradient of outputs w.r.t variables. */
2 changes: 2 additions & 0 deletions include/mxnet/ndarray.h
@@ -351,6 +351,8 @@ class NDArray {
bool fresh_out_grad() const;
/*! \return updated grad state in autograd_entry_ */
void set_fresh_out_grad(bool state) const;
/*! \brief copy the autograd_entry_ from src NDArray */
void copy_autograd_entry_(const NDArray* src);
/*! \brief Returns true if a sparse ndarray's aux_data and storage are initialized
* Throws an exception if the indices array shape is inconsistent
* Returns false if the indices array is empty(nnz = 0) for csr/row_sparse
6 changes: 5 additions & 1 deletion python/mxnet/_ctypes/cached_op.py
@@ -77,6 +77,7 @@ def __call__(self, *args, **kwargs):
if not default_device:
default_device = kwargs.pop('default_ctx', None)
out = kwargs.pop('out', None)
nleaf_vars = [container.data() for container in kwargs.pop('_nleaf_vars', [])]
if kwargs:
raise TypeError(
"CachedOp.__call__ got unexpected keyword argument(s): " + \
@@ -93,7 +94,10 @@ def __call__(self, *args, **kwargs):
*args,
type_id,
device_id,
*out_arg
len(out_arg),
*out_arg,
len(nleaf_vars),
*nleaf_vars
)
if out is not None:
return out
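The new keyword argument is optional and is popped before the unexpected-kwargs check, so existing callers are unchanged. A hedged sketch of the calling contract (cached_op stands for an already constructed CachedOp and block for a hybridized HybridBlock; each entry passed in _nleaf_vars must expose a data() method returning an NDArray, as the Intermediate container introduced below does):

# Equivalent to cached_op(*inputs) when no intermediate variables were marked.
outputs = cached_op(*inputs, _nleaf_vars=block._nleaf_vars.values())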
94 changes: 91 additions & 3 deletions python/mxnet/gluon/block.py
@@ -33,13 +33,14 @@
import json
import numpy as np

from ..base import mx_real_t, MXNetError, NDArrayHandle, SymbolHandle, py_str, check_call, _LIB
from ..base import mx_real_t, MXNetError, NDArrayHandle, SymbolHandle, py_str, check_call, _LIB, \
_as_list
from .. import symbol, ndarray, initializer, autograd, _deferred_compute as dc, name as _name, \
profiler as _profiler, device as _device
from ..symbol.numpy import _symbol as np_symbol
from ..symbol import Symbol, fromjson
from ..ndarray import NDArray, get_dtype_name
from .parameter import Parameter, DeferredInitializationError
from .parameter import Parameter, DeferredInitializationError, Intermediate
from .utils import _indent, _brief_print_list, HookHandle, shape_is_known
from .utils import _check_same_symbol_type, _check_all_np_ndarrays, _check_block_input_np_ndarrays
from .. import numpy_extension as _mx_npx
@@ -1091,6 +1092,7 @@ def __init__(self):
self._backend_opts = {}
self._partition_if_dynamic = True
self._first_forward = True
self._nleaf_vars = OrderedDict()

def __setattr__(self, name, value):
"""Registers parameters."""
@@ -1302,7 +1304,7 @@ def _call_cached_op(self, *args):
args_without_none = [ele for ele in args if ele is not None]
cargs = [args_without_none[i] if is_arg else i.data()
for is_arg, name, i in self._cached_op_args]
out = self._cached_op(*cargs)
out = self._cached_op(*cargs, _nleaf_vars=self._nleaf_vars.values())
if isinstance(out, NDArray):
out = [out]
return _regroup(out, self._out_format)
@@ -1678,6 +1680,92 @@ def reset_ctx(self, ctx):
self.reset_device(ctx)


def intermediate(self, names, var_arrays_inp, grad_req='write'):
"""Mark the intermediate variables.
Parameters
----------
name : str or tuple[str], name of the registered intermediate variable
var_arrays_inp : ndarray or tuple[ndarray], the output of the expression
grad_req : str, gradient request
"""
if not self._active:
var_arrays = _as_list(var_arrays_inp)
names = _as_list(names)
self._nleaf_vars.update(
{name : Intermediate(name, array, grad_req) for name, array in zip(names, var_arrays)})
else:
prev_val = dc.set_deferred_compute(False)
var_arrays = _as_list(var_arrays_inp)
names = _as_list(names)
# Prepare ctypes array types
import ctypes
var_handles_type = ctypes.c_void_p * len(var_arrays)
# Convert handles
var_handles = var_handles_type(*[arr.handle for arr in var_arrays])
check_call(_LIB.MXNDArrayMarkDCVariables(var_handles, len(var_arrays), len(self._nleaf_vars)))
self._nleaf_vars.update(
{name : Intermediate(name, array, grad_req) for name, array in zip(names, var_arrays)})
dc.set_deferred_compute(prev_val)
return var_arrays_inp

def attach_grad_intermediate(self):
"""Attach gradient to all the intermediate variables.
"""
for val in self._nleaf_vars.values():
val.data().attach_grad(grad_req=val.grad_req)

def get_intermediate(self, names):
"""Get the intermediate variables by names
"""
if isinstance(names, list):
return [self._nleaf_vars[n] for n in names]
else:
return self._nleaf_vars[names]


class SymbolBlock(HybridBlock):
"""Construct block from symbol. This is useful for using pre-trained models
as feature extractors. For example, you may want to extract the output
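Taken together, the block.py additions give HybridBlock a small user-facing workflow: mark a nonleaf array inside forward(), attach gradient buffers to the marked variables, and read their gradients back after backward(). A rough usage sketch under assumed call order; the network, the input shape, and the name 'proj' are made up for illustration and are not part of this commit:

from mxnet import autograd, np, npx
from mxnet.gluon import nn
npx.set_np()

class Net(nn.HybridBlock):
    def __init__(self):
        super().__init__()
        self.dense0 = nn.Dense(4)
        self.dense1 = nn.Dense(2)

    def forward(self, x):
        h = self.dense0(x)
        h = self.intermediate('proj', h)   # register the nonleaf output under the name 'proj'
        return self.dense1(h)

net = Net()
net.initialize()
net.hybridize()
x = np.ones((3, 8))
net(x)                           # first call traces the graph and marks 'proj'
net.attach_grad_intermediate()   # attach grad buffers to all marked variables
with autograd.record():
    out = net(x)
out.backward()
print(net.get_intermediate('proj').data().grad)   # gradient of the marked intermediate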
37 changes: 37 additions & 0 deletions python/mxnet/gluon/parameter.py
@@ -773,3 +773,40 @@ def grad_req(self, req):
warnings.warn('Constant parameter "{}" does not support '
'grad_req other than "null", and new value "{}" '
'is ignored.'.format(self.name, req))

class Intermediate:
"""A Container holding marked intermediate variables of Blocks.
Parameters
----------
name : str.
Name of this parameter. It be used to retrieve the marked variables.
grad_req : {'write', 'add', 'null'}, default 'write'
Specifies how to update gradient to grad arrays.
- ``'write'`` means everytime gradient is written to grad :py:class:`NDArray`.
- ``'add'`` means everytime gradient is added to the grad :py:class:`NDArray`. You need
to manually call ``zero_grad()`` to clear the gradient buffer before each
iteration when using this option.
- 'null' means gradient is not requested for this parameter. gradient arrays
will not be allocated.
"""
def __init__(self, name, data=None, grad_req='write'):
self._name = name
self._data = data
self._grad_req = grad_req

def __repr__(self):
s = 'Intermediate name={name}'
return s.format(name=self._name)

def data(self):
return self._data

@property
def name(self):
return self._name

@property
def grad_req(self):
return self._grad_req
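The container itself is deliberately thin: it only pairs a name with the marked array and the requested gradient mode, and the Block keeps these objects in its _nleaf_vars dictionary. A small illustration (values assumed, not part of the commit):

from mxnet import np
from mxnet.gluon.parameter import Intermediate

h = np.zeros((2, 2))
inter = Intermediate('proj', data=h, grad_req='add')
print(inter)           # Intermediate name=proj
print(inter.name)      # 'proj'
print(inter.grad_req)  # 'add'
assert inter.data() is h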
15 changes: 12 additions & 3 deletions src/api/cached_op_api.cc
@@ -44,16 +44,18 @@ MXNET_REGISTER_GLOBAL("cached_op.invoke")
ndinputs.push_back(static_cast<mxnet::NDArray*>(args[i]));
}

int num_outputs = args[num_inputs + 4];
int num_nleafs = args[num_inputs + num_outputs + 5];
std::vector<NDArray*> ndoutputs;
ndoutputs.reserve(op->num_outputs());
if (args[num_inputs + 4].type_code() == kNull) {
if (args[num_inputs + 5].type_code() == kNull) {
for (int i = 0; i < op->num_outputs(); ++i)
ndoutputs.push_back(new NDArray());
} else {
int array_size = args_size - num_inputs - 4;
int array_size = args_size - num_inputs - num_nleafs - 6;
CHECK_EQ(array_size, op->num_outputs()) << "CachedOp expects " << op->num_outputs()
<< " outputs, but " << array_size << " was given.";
for (int i = num_inputs + 4; i < array_size; ++i) {
for (int i = num_inputs + 5; i < num_inputs + num_outputs + 5; ++i) {
ndoutputs.push_back(args[i].operator mxnet::NDArray*());
}
}
@@ -69,6 +71,13 @@ MXNET_REGISTER_GLOBAL("cached_op.invoke")
default_dev_id = ctx.dev_id;
}

std::vector<NDArray*> nleafs;
nleafs.reserve(num_nleafs);
for (int i = 0; i < num_nleafs; ++i) {
nleafs.push_back(static_cast<mxnet::NDArray*>(args[i + num_inputs + num_outputs + 6]));
}
op->set_nleafs(nleafs);

// construct default context
Context ctx =
Context::Create(static_cast<Context::DeviceType>(default_dev_type), default_dev_id);
12 changes: 12 additions & 0 deletions src/c_api/c_api_ndarray.cc
@@ -495,3 +495,15 @@ int MXNDArrayGetDeferredComputeSymbol(NDArrayHandle* output_handles,
*out = s;
API_END_HANDLE_ERROR(delete s;);
}

int MXNDArrayMarkDCVariables(NDArrayHandle* nleaf_handles, int num_nleafs, int cnt_var) {
API_BEGIN();
std::vector<NDArray*> nleafs;
nleafs.reserve(num_nleafs);
for (int i = 0; i < num_nleafs; ++i) {
NDArray* array = reinterpret_cast<NDArray*>(nleaf_handles[i]);
nleafs.emplace_back(array);
}
Imperative::Get()->MarkDCVariables(nleafs, cnt_var);
API_END();
}
5 changes: 4 additions & 1 deletion src/imperative/cached_op.cc
@@ -801,7 +801,8 @@ OpStatePtr CachedOp::DynamicForward(const Context& default_ctx,
recording && inlining_,
nullptr,
monitor_callback_,
monitor_all_);
monitor_all_,
nleafs_);
} else {
mxnet::ShapeVector shapes = g.GetAttr<mxnet::ShapeVector>("shape");
NaiveRunGraph(false,
@@ -1063,6 +1064,7 @@ void CachedOp::StaticBackward(const bool retain_graph,
if (!idx.exist(entry.node.get()))
continue;
auto eid = idx.entry_id(entry);
state.array_reqs[eid] = reqs[iter->second];
// An input and an output may share the same array.
INIT_DETACHED(outputs[iter->second], arrays[eid]);
arrays[eid] = outputs[iter->second];
@@ -1073,6 +1075,7 @@
if (!idx.exist(entry.node.get()))
continue;
auto eid = idx.entry_id(entry);
state.array_reqs[eid] = reqs[i];
// An input and an output may share the same array.
INIT_DETACHED(outputs[i], arrays[eid]);
arrays[eid] = outputs[i];
4 changes: 4 additions & 0 deletions src/imperative/cached_op.h
@@ -491,6 +491,9 @@ class CachedOp {
const std::unordered_set<uint32_t>& mutable_input_nodes() const {
return fwd_graph_.indexed_graph().mutable_input_nodes();
}
void set_nleafs(const std::vector<NDArray*>& nleafs) {
nleafs_ = nleafs;
}
virtual std::vector<nnvm::NodeEntry> Gradient(const nnvm::ObjectPtr& node,
const std::vector<nnvm::NodeEntry>& ograds) const;
virtual OpStatePtr Forward(const std::shared_ptr<CachedOp>& op_ptr,
Expand Down Expand Up @@ -649,6 +652,7 @@ class CachedOp {
std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
std::vector<bool> save_inputs_, save_outputs_;
std::vector<OpReqType> bwd_output_reqs_;
std::vector<NDArray*> nleafs_;

std::function<void(const char*, const char*, NDArrayHandle)> monitor_callback_{nullptr};
bool monitor_all_{false};
12 changes: 12 additions & 0 deletions src/imperative/imperative.cc
@@ -171,6 +171,18 @@ void Imperative::MarkVariables(const std::vector<NDArray*>& variables,
}
}

void Imperative::MarkDCVariables(const std::vector<NDArray*>& nleafs, int cnt_vars) {
for (NDArray* nleaf : nleafs) {
if (Imperative::DCInfo::IsNone(*nleaf)) {
LOG(WARNING) << "The marked node doesn't have deferred compute history.";
} else {
nnvm::ObjectPtr node = nleaf->deferredcompute_entry_.node;
node->attrs.dict["mark_id"] = std::to_string(cnt_vars);
}
cnt_vars++;
}
}

// Unmark the variables to free the memory.
void Imperative::DropGrads(const std::vector<NDArray*>& variables) {
for (auto variable : variables) {
12 changes: 11 additions & 1 deletion src/imperative/imperative_utils.cc
@@ -138,7 +138,8 @@ void RunGraph(const bool retain_graph,
bool recording,
mxnet::ShapeVector* shapes,
const imperative::CachedOpMonCallback& callback,
const bool monitor_all) {
const bool monitor_all,
const std::vector<NDArray*>& nleafs) {
CHECK(shapes == nullptr);
for (size_t i = node_start; i < node_end; ++i) {
const nnvm::IndexedGraph::Node& node = idx[i];
@@ -166,6 +167,15 @@
if (callback) {
mxnet::common::ExecuteMonOutputCallback(idx, arrays, i, callback);
}
// set the autograd_entry_ in marked nleafs
if (nleafs.size()) {
auto it = node.source->attrs.dict.find("mark_id");
if (it != node.source->attrs.dict.end()) {
int mark_id = std::stoi(it->second);
CHECK_LT(mark_id, nleafs.size()) << "Mark_id exceeds the nonleaf list size.";
nleafs[mark_id]->copy_autograd_entry_(ndoutputs[0]);
}
}
}
}

3 changes: 2 additions & 1 deletion src/imperative/imperative_utils.h
@@ -1386,7 +1386,8 @@ void RunGraph(const bool retain_graph,
bool recording,
mxnet::ShapeVector* shapes = nullptr,
const CachedOpMonCallback& callback = nullptr,
const bool monitor_all_ = false);
const bool monitor_all_ = false,
const std::vector<NDArray*>& nleafs = std::vector<NDArray*>());

void NaiveRunGraph(const bool retain_graph,
const Context& default_ctx,
4 changes: 4 additions & 0 deletions src/ndarray/ndarray.cc
@@ -513,6 +513,10 @@ void NDArray::set_fresh_out_grad(bool state) const {
info.fresh_out_grad = state;
}

void NDArray::copy_autograd_entry_(const NDArray* src) {
autograd_entry_ = nnvm::NodeEntry{src->autograd_entry_.node, 0, 0};
}

#if MXNET_USE_ONEDNN == 1

bool NDArray::Chunk::IsDNNL() const {