forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request apache#700 from piiswrong/master
ndarray op interface fix
- Loading branch information
Showing
13 changed files
with
264 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# pylint: skip-file | ||
""" data iterator for mnist """ | ||
import sys | ||
import os | ||
# code to automatically download dataset | ||
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) | ||
sys.path.append(os.path.join(curr_path, "../../tests/python/common")) | ||
import get_data | ||
import mxnet as mx | ||
|
||
def mnist_iterator(batch_size, input_shape): | ||
"""return train and val iterators for mnist""" | ||
# download data | ||
get_data.GetMNIST_ubyte() | ||
flat = False if len(input_shape) == 3 else True | ||
|
||
train_dataiter = mx.io.MNISTIter( | ||
image="data/train-images-idx3-ubyte", | ||
label="data/train-labels-idx1-ubyte", | ||
input_shape=input_shape, | ||
batch_size=batch_size, | ||
shuffle=True, | ||
flat=flat) | ||
|
||
val_dataiter = mx.io.MNISTIter( | ||
image="data/t10k-images-idx3-ubyte", | ||
label="data/t10k-labels-idx1-ubyte", | ||
input_shape=input_shape, | ||
batch_size=batch_size, | ||
flat=flat) | ||
|
||
return (train_dataiter, val_dataiter) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# pylint: skip-file | ||
from data import mnist_iterator | ||
import mxnet as mx | ||
import numpy as np | ||
import logging | ||
|
||
data = mx.symbol.Variable('data') | ||
fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128) | ||
act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu") | ||
fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64) | ||
act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu") | ||
fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=10) | ||
mlp = mx.symbol.SoftmaxOutput(data = fc3, name = 'softmax') | ||
|
||
# data | ||
|
||
train, val = mnist_iterator(batch_size=100, input_shape = (784,)) | ||
|
||
# train | ||
|
||
logging.basicConfig(level=logging.DEBUG) | ||
|
||
model = mx.model.FeedForward( | ||
ctx = mx.cpu(), symbol = mlp, num_epoch = 20, | ||
learning_rate = 0.1, momentum = 0.9, wd = 0.00001) | ||
|
||
def norm_stat(d): | ||
return mx.nd.norm(d)/np.sqrt(d.size) | ||
mon = mx.mon.Monitor(100, norm_stat) | ||
model.fit(X=train, eval_data=val, monitor=mon, | ||
batch_end_callback = mx.callback.Speedometer(100, 100)) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
# coding: utf-8 | ||
# pylint: disable=protected-access, logging-format-interpolation, invalid-name, no-member | ||
"""Monitor outputs, weights, and gradients for debugging.""" | ||
import ctypes | ||
from .ndarray import NDArray | ||
from .base import NDArrayHandle | ||
from . import ndarray | ||
import logging | ||
from math import sqrt | ||
|
||
|
||
class Monitor(object): | ||
"""Monitor outputs, weights, and gradients for debugging. | ||
Parameters | ||
---------- | ||
interval : int | ||
Number of batches between printing. | ||
stat_func : function | ||
a function that computes statistics of tensors. | ||
Takes a NDArray and returns a NDArray. defaults to mean | ||
absolute value |x|/size(x). | ||
""" | ||
def __init__(self, interval, stat_func=None): | ||
if stat_func is None: | ||
def asum_stat(x): | ||
"""returns |x|/size(x), async execution.""" | ||
return ndarray.norm(x)/sqrt(x.size) | ||
stat_func = asum_stat | ||
self.stat_func = stat_func | ||
self.interval = interval | ||
self.activated = False | ||
self.queue = [] | ||
self.step = 0 | ||
self.exes = [] | ||
def stat_helper(name, array): | ||
"""wrapper for executor callback""" | ||
if not self.activated: | ||
return | ||
array = ctypes.cast(array, NDArrayHandle) | ||
array = NDArray(array, writable=False) | ||
self.queue.append((self.step, name, self.stat_func(array))) | ||
self.stat_helper = stat_helper | ||
|
||
def install(self, exe): | ||
"""install callback to executor. | ||
Supports installing to multiple exes | ||
Parameters | ||
---------- | ||
exe : mx.executor.Executor | ||
the Executor (returned by symbol.bind) to install to. | ||
""" | ||
exe.set_monitor_callback(self.stat_helper) | ||
self.exes.append(exe) | ||
|
||
def tic(self): | ||
"""start collecting stats for current batch. | ||
Call before forward""" | ||
if self.step % self.interval == 0: | ||
for exe in self.exes: | ||
for array in exe.arg_arrays: | ||
array.wait_to_read() | ||
self.queue = [] | ||
self.activated = True | ||
self.step += 1 | ||
|
||
|
||
def toc(self): | ||
"""End collecting for current batch and return results. | ||
Call after computation of current batch. | ||
Returns | ||
------- | ||
res : list of """ | ||
if self.activated: | ||
for exe in self.exes: | ||
for array in exe.arg_arrays: | ||
array.wait_to_read() | ||
for exe in self.exes: | ||
for name, array in zip(exe._symbol.list_arguments(), exe.arg_arrays): | ||
self.queue.append((self.step, name, self.stat_func(array))) | ||
else: | ||
return [] | ||
self.activated = False | ||
res = [] | ||
for n, k, v in self.queue: | ||
assert isinstance(v, NDArray) | ||
if v.shape == (1,): | ||
res.append((n, k, v.asscalar())) | ||
else: | ||
res.append((n, k, v.asnumpy())) | ||
self.queue = [] | ||
return res | ||
|
||
def toc_print(self): | ||
"""End collecting and print results""" | ||
res = self.toc() | ||
for n, k, v in res: | ||
logging.info('Batch: {:7d} {:30s} {:f}'.format(n, k, v)) | ||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.