Skip to content

Commit

Permalink
Merge pull request apache#700 from piiswrong/master
Browse files Browse the repository at this point in the history
ndarray op interface fix
  • Loading branch information
piiswrong committed Nov 26, 2015
2 parents 3897d54 + d6ab7f4 commit 23311b8
Show file tree
Hide file tree
Showing 13 changed files with 264 additions and 20 deletions.
32 changes: 32 additions & 0 deletions example/python-howto/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# pylint: skip-file
""" data iterator for mnist """
import sys
import os
# code to automatically download dataset
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.append(os.path.join(curr_path, "../../tests/python/common"))
import get_data
import mxnet as mx

def mnist_iterator(batch_size, input_shape):
    """Return (train, val) MNIST data iterators.

    Parameters
    ----------
    batch_size : int
        Number of examples per batch.
    input_shape : tuple
        Shape of a single example: ``(784,)`` for flat input, or a
        3-tuple (e.g. ``(1, 28, 28)``) for image-layout input.

    Returns
    -------
    (train_dataiter, val_dataiter) : tuple of mx.io.MNISTIter
        Iterators over the training and test sets respectively.
    """
    # download the dataset if it is not already present
    get_data.GetMNIST_ubyte()
    # a 3-d shape means image layout; anything else is served flattened
    flat = len(input_shape) != 3

    train_dataiter = mx.io.MNISTIter(
        image="data/train-images-idx3-ubyte",
        label="data/train-labels-idx1-ubyte",
        input_shape=input_shape,
        batch_size=batch_size,
        shuffle=True,
        flat=flat)

    val_dataiter = mx.io.MNISTIter(
        image="data/t10k-images-idx3-ubyte",
        label="data/t10k-labels-idx1-ubyte",
        input_shape=input_shape,
        batch_size=batch_size,
        flat=flat)

    return (train_dataiter, val_dataiter)
32 changes: 32 additions & 0 deletions example/python-howto/monitor_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# pylint: skip-file
"""Example: monitor weight statistics while training a 3-layer MLP on MNIST.

Demonstrates ``mx.mon.Monitor``: a callback is installed on the training
executor via ``model.fit(..., monitor=mon)`` and a norm-based statistic of
the arrays is reported every 100 batches.
"""
from data import mnist_iterator
import mxnet as mx
import numpy as np
import logging

# network definition: 784 -> 128 -> 64 -> 10 MLP with ReLU activations
data = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)
act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=10)
mlp = mx.symbol.SoftmaxOutput(data = fc3, name = 'softmax')

# data

train, val = mnist_iterator(batch_size=100, input_shape = (784,))

# train

logging.basicConfig(level=logging.DEBUG)

model = mx.model.FeedForward(
    ctx = mx.cpu(), symbol = mlp, num_epoch = 20,
    learning_rate = 0.1, momentum = 0.9, wd = 0.00001)

def norm_stat(d):
    # statistic reported by the monitor: norm(d)/sqrt(size(d)) (an RMS value)
    return mx.nd.norm(d)/np.sqrt(d.size)
# collect and print the statistic every 100 batches
mon = mx.mon.Monitor(100, norm_stat)
model.fit(X=train, eval_data=val, monitor=mon,
          batch_end_callback = mx.callback.Speedometer(100, 100))

16 changes: 9 additions & 7 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,8 @@ typedef void *RecordIOHandle;
/*! \brief handle to MXRtc*/
typedef void *RtcHandle;

namespace mxnet {
class NDArray;
} // namespace mxnet
MXNET_EXTERN_C typedef void (*ExcecutorMonitorCallback)(const char*,
NDArrayHandle);

MXNET_EXTERN_C {
struct NativeOpInfo {
Expand All @@ -71,8 +70,8 @@ struct NativeOpInfo {
};

struct NDArrayOpInfo {
bool (*forward)(int, mxnet::NDArray**, int*, void*);
bool (*backward)(int, mxnet::NDArray**, int*, void*);
bool (*forward)(int, void**, int*, void*);
bool (*backward)(int, void**, int*, void*);
bool (*infer_shape)(int, int*, unsigned**, void*);
bool (*list_outputs)(char***, void*);
bool (*list_arguments)(char***, void*);
Expand Down Expand Up @@ -688,7 +687,6 @@ MXNET_DLL int MXExecutorBind(SymbolHandle symbol_handle,
mx_uint aux_states_len,
NDArrayHandle *aux_states,
ExecutorHandle *out);

/*!
* \brief Generate Executor from symbol,
* This is advanced function, allow specify group2ctx map.
Expand Down Expand Up @@ -724,7 +722,11 @@ MXNET_DLL int MXExecutorBindX(SymbolHandle symbol_handle,
mx_uint aux_states_len,
NDArrayHandle *aux_states,
ExecutorHandle *out);

/*!
* \brief set a call back to notify the completion of operation
*/
MXNET_DLL int MXExecutorSetMonitorCallback(ExecutorHandle handle,
ExcecutorMonitorCallback callback);
//--------------------------------------------
// Part 5: IO Interface
//--------------------------------------------
Expand Down
5 changes: 5 additions & 0 deletions include/mxnet/symbolic.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <string>
#include <utility>
#include "./base.h"
#include "./c_api.h"
#include "./ndarray.h"
#include "./operator.h"

Expand Down Expand Up @@ -310,6 +311,10 @@ class Executor {
const std::vector<NDArray> &arg_grad_store,
const std::vector<OpReqType> &grad_req_type,
const std::vector<NDArray> &aux_states);
/*!
 * \brief Install a callback to notify the completion of operation.
 * \param callback invoked with a name (const char*) and an NDArrayHandle
 *        (see ExcecutorMonitorCallback in c_api.h). The base-class default
 *        is a no-op; concrete executors may override to enable monitoring.
 */
virtual void SetMonitorCallback(ExcecutorMonitorCallback callback) {}
}; // class operator
} // namespace mxnet
#endif // MXNET_SYMBOLIC_H_
2 changes: 2 additions & 0 deletions python/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,7 @@
# Attribute scope to add attributes to symbolic graphs
from .attribute import AttrScope

from . import monitor
from . import monitor as mon

__version__ = base.__version__
15 changes: 15 additions & 0 deletions python/mxnet/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(self, handle, symbol):
self._arg_dict = None
self._grad_dict = None
self._aux_dict = None
self._monitor_callback = None

def __del__(self):
check_call(_LIB.MXExecutorFree(self.handle))
Expand Down Expand Up @@ -117,6 +118,20 @@ def backward(self, out_grads=None):
mx_uint(len(out_grads)),
ndarray))

def set_monitor_callback(self, callback):
    """Install a callback fired on completion of each operation.

    Parameters
    ----------
    callback : function
        Takes a name string and an NDArrayHandle.
    """
    prototype = ctypes.CFUNCTYPE(None, ctypes.c_char_p, NDArrayHandle)
    # Keep a reference on self so the ctypes trampoline is not
    # garbage-collected while the C side still holds it.
    self._monitor_callback = prototype(callback)
    check_call(_LIB.MXExecutorSetMonitorCallback(self.handle, self._monitor_callback))

@property
def arg_dict(self):
"""Get dictionary representation of argument arrrays.
Expand Down
13 changes: 10 additions & 3 deletions python/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names,
kvstore, update_on_kvstore,
train_data, eval_data=None, eval_metric=None,
epoch_end_callback=None, batch_end_callback=None,
logger=None, work_load_list=None):
logger=None, work_load_list=None, monitor=None):
"""Internal training function on multiple devices.
This function will also work for single device as well.
Parameters
Expand Down Expand Up @@ -231,6 +231,8 @@ def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names,
data_shapes = {k: tuple([slices[i].stop-slices[i].start] + list(v[1:]))
for k, v in train_data.provide_data}
train_exec = symbol.simple_bind(ctx[i], 'write', **data_shapes)
if monitor:
monitor.install(train_exec)
train_execs.append(train_exec)

# data structure
Expand Down Expand Up @@ -287,6 +289,8 @@ def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names,
_load_data(data_batch, data_arrays)
_load_label(data_batch, label_arrays)

if monitor is not None:
monitor.tic()
# forward backward pass
for texec, islice in zip(train_execs, slices):
texec.forward(is_train=True)
Expand Down Expand Up @@ -319,6 +323,9 @@ def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names,
w, g = p
updater(index*num_device+k, g, w)

if monitor is not None:
monitor.toc_print()

nbatch += 1
# batch callback (for print purpose)
if batch_end_callback != None:
Expand Down Expand Up @@ -661,7 +668,7 @@ def predict(self, X):

def fit(self, X, y=None, eval_data=None, eval_metric='acc',
epoch_end_callback=None, batch_end_callback=None, kvstore='local', logger=None,
work_load_list=None):
work_load_list=None, monitor=None):
"""Fit the model.
Parameters
----------
Expand Down Expand Up @@ -736,7 +743,7 @@ def fit(self, X, y=None, eval_data=None, eval_metric='acc',
epoch_end_callback=epoch_end_callback,
batch_end_callback=batch_end_callback,
kvstore=kvstore, update_on_kvstore=update_on_kvstore,
logger=logger, work_load_list=work_load_list)
logger=logger, work_load_list=work_load_list, monitor=monitor)


def save(self, prefix, epoch=None):
Expand Down
104 changes: 104 additions & 0 deletions python/mxnet/monitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# coding: utf-8
# pylint: disable=protected-access, logging-format-interpolation, invalid-name, no-member
"""Monitor outputs, weights, and gradients for debugging."""
import ctypes
from .ndarray import NDArray
from .base import NDArrayHandle
from . import ndarray
import logging
from math import sqrt


class Monitor(object):
    """Monitor outputs, weights, and gradients for debugging.

    Parameters
    ----------
    interval : int
        Number of batches between stat collections.
    stat_func : function, optional
        A function computing a statistic of a tensor. Takes an NDArray
        and returns an NDArray. Defaults to the RMS-style value
        ``norm(x)/sqrt(x.size)``.
    """
    def __init__(self, interval, stat_func=None):
        if stat_func is None:
            def asum_stat(x):
                """returns norm(x)/sqrt(x.size), async execution."""
                return ndarray.norm(x)/sqrt(x.size)
            stat_func = asum_stat
        self.stat_func = stat_func
        self.interval = interval
        self.activated = False  # collecting stats for the current batch?
        self.queue = []         # pending (step, name, stat NDArray) triples
        self.step = 0           # number of tic() calls so far
        self.exes = []          # executors this monitor is installed on

        def stat_helper(name, array):
            """Executor callback: record a stat for each produced array."""
            if not self.activated:
                return
            # `array` arrives as a raw handle from the C API; wrap read-only.
            array = ctypes.cast(array, NDArrayHandle)
            array = NDArray(array, writable=False)
            self.queue.append((self.step, name, self.stat_func(array)))
        self.stat_helper = stat_helper

    def install(self, exe):
        """Install the callback on an executor.

        Supports installing to multiple exes.

        Parameters
        ----------
        exe : mx.executor.Executor
            the Executor (returned by symbol.bind) to install to.
        """
        exe.set_monitor_callback(self.stat_helper)
        self.exes.append(exe)

    def tic(self):
        """Start collecting stats for the current batch.

        Call before forward. Collection only activates every
        ``interval`` batches; ``step`` advances on every call.
        """
        if self.step % self.interval == 0:
            # Ensure pending writes to the arguments have finished before
            # this batch's stats are gathered.
            for exe in self.exes:
                for array in exe.arg_arrays:
                    array.wait_to_read()
            self.queue = []
            self.activated = True
        self.step += 1

    def toc(self):
        """End collecting for the current batch and return results.

        Call after computation of the current batch.

        Returns
        -------
        res : list of (step, name, value)
            ``value`` is a Python scalar for shape-(1,) stats, otherwise
            a numpy array. Empty if collection was not activated.
        """
        if not self.activated:
            return []
        # Wait for async stat computations triggered by the callback.
        for exe in self.exes:
            for array in exe.arg_arrays:
                array.wait_to_read()
        # Also record stats of the argument (weight) arrays themselves.
        for exe in self.exes:
            for name, array in zip(exe._symbol.list_arguments(), exe.arg_arrays):
                self.queue.append((self.step, name, self.stat_func(array)))
        self.activated = False
        res = []
        for step, name, value in self.queue:
            assert isinstance(value, NDArray)
            if value.shape == (1,):
                res.append((step, name, value.asscalar()))
            else:
                res.append((step, name, value.asnumpy()))
        self.queue = []
        return res

    def toc_print(self):
        """End collecting and log the results via ``logging.info``."""
        for step, name, value in self.toc():
            # `value` may be a scalar or a numpy array; '%s' renders both
            # (the previous '{:f}' format crashed on the array branch).
            logging.info('Batch: %7d %30s %s', step, name, str(value))




10 changes: 10 additions & 0 deletions python/mxnet/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,16 @@ def shape(self):
self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
return tuple(pdata[:ndim.value])

@property
def size(self):
    """Get the number of elements in the current NDArray.

    Returns
    -------
    int
        Product of the dimensions in ``self.shape``; 1 for an empty
        shape. Computed with plain int arithmetic so the result is a
        true Python int (``np.prod(())`` would yield a float ``1.0``,
        and ``np.prod`` otherwise returns a numpy scalar).
    """
    size = 1
    for dim in self.shape:
        size *= dim
    return size

@property
def context(self):
"""Get context of current NDArray.
Expand Down
8 changes: 8 additions & 0 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,14 @@ int MXExecutorBindX(SymbolHandle symbol_handle,
API_END();
}

/*!
 * \brief C API entry point: install a monitor callback on an executor.
 *  Forwards to Executor::SetMonitorCallback (declared in symbolic.h).
 */
int MXExecutorSetMonitorCallback(ExecutorHandle handle,
ExcecutorMonitorCallback callback) {
API_BEGIN();
// ExecutorHandle is an opaque void*; recover the concrete Executor.
Executor *exec = static_cast<Executor*>(handle);
exec->SetMonitorCallback(callback);
// NOTE(review): API_BEGIN/API_END presumably translate C++ exceptions into
// the C API's error-code return convention -- confirm against c_api.h.
API_END();
}

//--------------------------------------------
// Part 5: IO Interface
//--------------------------------------------
Expand Down
Loading

0 comments on commit 23311b8

Please sign in to comment.