This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

test ci
KexinFeng committed Jul 8, 2022
1 parent e36c9f0 commit 251a644
Showing 4 changed files with 156 additions and 18 deletions.
6 changes: 5 additions & 1 deletion python/mxnet/_ctypes/cached_op.py
@@ -77,6 +77,7 @@ def __call__(self, *args, **kwargs):
if not default_device:
default_device = kwargs.pop('default_ctx', None)
out = kwargs.pop('out', None)
nleaf_vars = [container.data() for container in kwargs.pop('_nleaf_vars', [])]
if kwargs:
raise TypeError(
"CachedOp.__call__ got unexpected keyword argument(s): " + \
@@ -93,7 +94,10 @@ def __call__(self, *args, **kwargs):
*args,
type_id,
device_id,
*out_arg
len(out_arg),
*out_arg,
len(nleaf_vars),
*nleaf_vars
)
if out is not None:
return out
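The change above makes CachedOp.__call__ flatten two variable-length groups (the explicit output arrays and the marked non-leaf variables) into one positional call, each group prefixed by its length so the receiving side can split the flat argument list apart again. A minimal pure-Python sketch of that length-prefix convention; pack_call_args and unpack_call_args are hypothetical helper names used for illustration only, not part of the MXNet API:

def pack_call_args(inputs, type_id, device_id, out_args, nleaf_vars):
    # Flatten everything into one positional tuple; each variable-length
    # group is preceded by its length so the receiver knows where it ends.
    return (*inputs, type_id, device_id,
            len(out_args), *out_args,
            len(nleaf_vars), *nleaf_vars)

def unpack_call_args(args, num_inputs):
    # Inverse of pack_call_args: read the fixed fields first, then use the
    # length prefixes to split off the two variable-length groups.
    inputs = list(args[:num_inputs])
    type_id, device_id = args[num_inputs], args[num_inputs + 1]
    pos = num_inputs + 2
    n_out = args[pos]
    out_args = list(args[pos + 1:pos + 1 + n_out])
    pos += 1 + n_out
    n_leaf = args[pos]
    nleaf_vars = list(args[pos + 1:pos + 1 + n_leaf])
    return inputs, type_id, device_id, out_args, nleaf_vars

# Round trip: the unpacked groups match what was packed.
packed = pack_call_args(['a', 'b'], 1, 0, ['out0'], ['h1', 'h2'])
assert unpack_call_args(packed, 2) == (['a', 'b'], 1, 0, ['out0'], ['h1', 'h2'])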
94 changes: 91 additions & 3 deletions python/mxnet/gluon/block.py
@@ -33,13 +33,14 @@
import json
import numpy as np

from ..base import mx_real_t, MXNetError, NDArrayHandle, SymbolHandle, py_str, check_call, _LIB
from ..base import mx_real_t, MXNetError, NDArrayHandle, SymbolHandle, py_str, check_call, _LIB, \
_as_list
from .. import symbol, ndarray, initializer, autograd, _deferred_compute as dc, name as _name, \
profiler as _profiler, device as _device
from ..symbol.numpy import _symbol as np_symbol
from ..symbol import Symbol, fromjson
from ..ndarray import NDArray, get_dtype_name
from .parameter import Parameter, DeferredInitializationError
from .parameter import Parameter, DeferredInitializationError, Intermediate
from .utils import _indent, _brief_print_list, HookHandle, shape_is_known
from .utils import _check_same_symbol_type, _check_all_np_ndarrays, _check_block_input_np_ndarrays
from .. import numpy_extension as _mx_npx
@@ -1091,6 +1092,7 @@ def __init__(self):
self._backend_opts = {}
self._partition_if_dynamic = True
self._first_forward = True
self._nleaf_vars = OrderedDict()

def __setattr__(self, name, value):
"""Registers parameters."""
@@ -1302,7 +1304,7 @@ def _call_cached_op(self, *args):
args_without_none = [ele for ele in args if ele is not None]
cargs = [args_without_none[i] if is_arg else i.data()
for is_arg, name, i in self._cached_op_args]
out = self._cached_op(*cargs)
out = self._cached_op(*cargs, _nleaf_vars=self._nleaf_vars.values())
if isinstance(out, NDArray):
out = [out]
return _regroup(out, self._out_format)
@@ -1678,6 +1680,92 @@ def reset_ctx(self, ctx):
self.reset_device(ctx)


def intermediate(self, names, var_arrays_inp, grad_req='write'):
"""Mark the intermediate variables.
Parameters
----------
names : str or tuple[str], name(s) under which the intermediate variable(s) are registered
var_arrays_inp : ndarray or tuple[ndarray], the output array(s) of the expression to be marked
grad_req : str, default 'write', gradient request for the marked variable(s)
"""
if not self._active:
var_arrays = _as_list(var_arrays_inp)
names = _as_list(names)
self._nleaf_vars.update(
{name : Intermediate(name, array, grad_req) for name, array in zip(names, var_arrays)})
else:
prev_val = dc.set_deferred_compute(False)
var_arrays = _as_list(var_arrays_inp)
names = _as_list(names)
# Prepare ctypes array types
import ctypes
var_handles_type = ctypes.c_void_p * len(var_arrays)
# Convert handles
var_handles = var_handles_type(*[arr.handle for arr in var_arrays])
check_call(_LIB.MXNDArrayMarkDCVariables(var_handles, len(var_arrays), len(self._nleaf_vars)))
self._nleaf_vars.update(
{name : Intermediate(name, array, grad_req) for name, array in zip(names, var_arrays)})
dc.set_deferred_compute(prev_val)
return var_arrays_inp

def attach_grad_intermediate(self):
"""Attach gradient to all the intermediate variables.
"""
for val in self._nleaf_vars.values():
val.data().attach_grad(grad_req=val.grad_req)

def get_intermediate(self, names):
"""Get the intermediate variables by names
"""
if isinstance(names, list):
return [self._nleaf_vars[n] for n in names]
else:
return self._nleaf_vars[names]


class SymbolBlock(HybridBlock):
"""Construct block from symbol. This is useful for using pre-trained models
as feature extractors. For example, you may want to extract the output
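Taken together, the new Block.intermediate, attach_grad_intermediate, and get_intermediate methods form a three-step workflow: mark arrays inside forward, attach gradient buffers after the recorded call, then retrieve the marked arrays and their gradients after backward. A condensed usage sketch, modeled on the updated test_retain_grad_drop_grad_gluon later in this commit; the Prod block and the array values are illustrative only:

import mxnet as mx

class Prod(mx.gluon.HybridBlock):
    def forward(self, a, b):
        u = self.intermediate('u', a * b)        # mark an intermediate result
        return self.intermediate('out', u * a)   # the final output can be marked too

x = mx.np.array([1., 2., 3., 4.])
y = mx.np.array([5., 6., 7., 8.])
x.attach_grad()

net = Prod()
net.initialize()
net.hybridize()

with mx.autograd.record():
    net(x, y)

net.attach_grad_intermediate()              # attach grad buffers to 'u' and 'out'
u = net.get_intermediate('u').data()
out = net.get_intermediate('out').data()
out.backward()

# out = (x * y) * x, so d(out)/du = x and d(out)/dx = 2 * x * y
assert (u.grad == x).all()
assert (x.grad == 2 * x * y).all()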
37 changes: 37 additions & 0 deletions python/mxnet/gluon/parameter.py
@@ -773,3 +773,40 @@ def grad_req(self, req):
warnings.warn('Constant parameter "{}" does not support '
'grad_req other than "null", and new value "{}" '
'is ignored.'.format(self.name, req))

class Intermediate:
"""A Container holding marked intermediate variables of Blocks.
Parameters
----------
name : str.
Name of this parameter. It be used to retrieve the marked variables.
grad_req : {'write', 'add', 'null'}, default 'write'
Specifies how to update gradient to grad arrays.
- ``'write'`` means everytime gradient is written to grad :py:class:`NDArray`.
- ``'add'`` means everytime gradient is added to the grad :py:class:`NDArray`. You need
to manually call ``zero_grad()`` to clear the gradient buffer before each
iteration when using this option.
- 'null' means gradient is not requested for this parameter. gradient arrays
will not be allocated.
"""
def __init__(self, name, data=None, grad_req='write'):
self._name = name
self._data = data
self._grad_req = grad_req

def __repr__(self):
s = 'Intermediate name={name}'
return s.format(name=self._name)

def data(self):
return self._data

@property
def name(self):
return self._name

@property
def grad_req(self):
return self._grad_req
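
The Intermediate container added above is deliberately thin: it records a name, the marked array, and the grad_req that Block.attach_grad_intermediate() applies later; data() hands back the wrapped array itself rather than a copy. A small illustrative snippet, assuming a build that includes this commit (the name 'hidden' and the array values are made up):

import mxnet as mx
from mxnet.gluon.parameter import Intermediate

arr = mx.np.array([1., 2., 3.])
inter = Intermediate('hidden', data=arr, grad_req='add')

print(inter)                # Intermediate name=hidden
print(inter.name)           # hidden
print(inter.grad_req)       # add
assert inter.data() is arr  # the container stores a reference, not a copy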
37 changes: 23 additions & 14 deletions tests/python/unittest/test_autograd.py
@@ -533,7 +533,7 @@ def test_retain_grad_drop_grad():
z.attach_grad()
out_grad = nd.array([10, 10, 10, 10])
z.backward(out_grad, retain_graph=True)

assert (u.grad == out_grad * x).asnumpy().all()
assert (z.grad == out_grad).asnumpy().all()
assert (x.grad == out_grad * 2 * x * y).asnumpy().all()
@@ -548,39 +548,48 @@ def test_retain_grad_drop_grad():
assert u.grad is None and z.grad is None and y.grad is None
assert (x.grad == out_grad * 2 * x * y).asnumpy().all()

def test_retain_grad_drop_grad_gluon():
class CompBlock(mx.gluon.HybridBlock):
@pytest.mark.parametrize("hybridize", [True, False])
def test_retain_grad_drop_grad_gluon(hybridize):
class CompBlock(mx.HybridBlock):
def __init__(self):
super().__init__()
self.marked_var = None
def forward(self, a, b):
out1 = a*b
out2 = out1 * a
self.marked_var = out1

def forward(self, a, b, c):
out1 = self.intermediate(('out1_0', 'out1_1'), ((a+b)*c, a*b), grad_req='write')
out2 = self.intermediate('out2', out1[1] * a)
return out2

x = mx.np.array([1,2,3,4])
y = mx.np.array([5,6,7,8])
w = mx.np.array([0.1, 0.1, 0.1, 0.1])
x.attach_grad()
y.attach_grad()
w.attach_grad()
block2 = CompBlock()
block2.initialize()
# block2.hybridize()
if hybridize:
block2.hybridize()
with mx.autograd.record():
z = block2(x, y)
u = block2.marked_var
u.attach_grad()
z.attach_grad()
z = block2(x, y, w)

block2.attach_grad_intermediate()
u0 = block2.get_intermediate('out1_0').data()
u = block2.get_intermediate('out1_1').data()
z = block2.get_intermediate('out2').data()
z.backward(retain_graph=True)

assert (u.grad == x).all()
assert (u0.grad == mx.np.array([0, 0, 0, 0])).all()
assert (z.grad == mx.np.array([1,1,1,1])).all()
assert (x.grad == 2 * x * y).all()
assert (y.grad == x*x).all()

u.drop_grad()
u0.drop_grad()
z.drop_grad()
y.drop_grad()
z.backward()

assert u.grad is None and z.grad is None and y.grad is None
assert u.grad is None and u0.grad is None and y.grad is None and z.grad is None
assert (x.grad == 2 * x * y).all()
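Both the hybridized and non-hybridized paths of the updated test can be exercised locally against a build that includes this commit, for example with pytest tests/python/unittest/test_autograd.py -k retain_grad_drop_grad_gluon -v (invocation shown for illustration).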
