From 37b986e09f838152afb8c9f94f531028c478034c Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Wed, 2 Sep 2015 16:59:04 -0600 Subject: [PATCH] add narray scalar op --- dmlc-core | 2 +- include/mxnet/narray.h | 79 ++++++++++++++++- python/mxnet/narray.py | 14 +++ src/narray/narray.cc | 146 ++++++++++++++++++++++++++++--- src/narray/narray_function-inl.h | 31 +++++-- src/narray/narray_function.h | 3 + tests/python/test_narray.py | 11 +++ 7 files changed, 266 insertions(+), 20 deletions(-) mode change 100644 => 100755 include/mxnet/narray.h mode change 100644 => 100755 python/mxnet/narray.py mode change 100644 => 100755 src/narray/narray.cc mode change 100644 => 100755 src/narray/narray_function-inl.h diff --git a/dmlc-core b/dmlc-core index 7d3c78428819..75f1950d386d 160000 --- a/dmlc-core +++ b/dmlc-core @@ -1 +1 @@ -Subproject commit 7d3c78428819dc84c4da8ae1f302ba6c6a235a5d +Subproject commit 75f1950d386d033b0b64919017515d27e698962a diff --git a/include/mxnet/narray.h b/include/mxnet/narray.h old mode 100644 new mode 100755 index fc72b0f91e18..5d445261f4c1 --- a/include/mxnet/narray.h +++ b/include/mxnet/narray.h @@ -106,6 +106,13 @@ class NArray { * \return reference of self */ NArray &operator+=(const NArray &src); + /*! + * \brief elementwise add to current space + * this mutate the current NArray + * \param src the data to add + * \return reference of self + */ + NArray &operator+=(const real_t &src); /*! * \brief elementwise subtract from current narray * this mutate the current NArray @@ -113,6 +120,13 @@ class NArray { * \return reference of self */ NArray &operator-=(const NArray &src); + /*! + * \brief elementwise subtract from current narray + * this mutate the current NArray + * \param src the data to substract + * \return reference of self + */ + NArray &operator-=(const real_t &src); /*! 
* \brief elementwise multiplication to current narray * this mutate the current NArray @@ -120,6 +134,13 @@ * \return reference of self */ NArray &operator*=(const NArray &src); + /*! + * \brief elementwise multiplication to current narray + * this mutate the current NArray + * \param src the data to multiply + * \return reference of self + */ + NArray &operator*=(const real_t &src); /*! * \brief elementwise division from current narray * this mutate the current NArray @@ -127,6 +148,13 @@ * \return reference of self */ NArray &operator/=(const NArray &src); + /*! + * \brief elementwise division from current narray + * this mutate the current NArray + * \param src the data to divide by + * \return reference of self + */ + NArray &operator/=(const real_t &src); /*! * \brief return transpose of current NArray * \return a new transposed NArray @@ -241,6 +269,8 @@ class NArray { friend void BinaryOp(const NArray &lhs, const NArray &rhs, NArray *out); template friend void UnaryOp(const NArray &lhs, const NArray &rhs, NArray *out); + template + friend void ScalarOp(const NArray &lhs, const real_t &rhs, NArray *out); }; /*! @@ -262,6 +292,13 @@ void CopyFromTo(const NArray &from, NArray *to); * \return a new result narray */ NArray operator+(const NArray &lhs, const NArray &rhs); +/*! + * \brief elementwise add + * \param lhs left operand + * \param rhs right operand + * \return a new result narray + */ +NArray operator+(const NArray &lhs, const real_t &rhs); /*! * \brief elementwise substraction * \param lhs left operand * \param rhs right operand * \return a new result narray */ NArray operator-(const NArray &lhs, const NArray &rhs); +/*! + * \brief elementwise subtraction + * \param lhs left operand + * \param rhs right operand + * \return a new result narray + */ +NArray operator-(const NArray &lhs, const real_t &rhs); /*! 
* \brief elementwise multiplication * \param lhs left operand * \param rhs right operand * \return a new result narray */ -NArray operator*(const NArray &lhs, const NArray &rhs); +NArray operator*(const NArray &lhs, const NArray &rhs); +/*! + * \brief elementwise multiplication + * \param lhs left operand + * \param rhs right operand + * \return a new result narray + */ +NArray operator*(const NArray &lhs, const real_t &rhs); /*! * \brief elementwise division * \param lhs left operand @@ -283,6 +334,13 @@ NArray operator*(const NArray &lhs, const NArray &rhs); * \return a new result narray */ NArray operator/(const NArray &lhs, const NArray &rhs); +/*! + * \brief elementwise division + * \param lhs left operand + * \param rhs right operand + * \return a new result narray + */ +NArray operator/(const NArray &lhs, const real_t &rhs); //-------------------------------------------------------------- // The following part are API Registration of NArray functions. @@ -346,6 +404,25 @@ struct NArrayFunctionReg this->add_argument("rhs", "NArray", "Right operand to the function."); return *this; } + /*! + * \brief set the function body to a scalar NArray function + * this will also auto set the parameters correctly + * \param fscalar function body to set + * \return ref to the registered entry, used to set properties + */ + inline NArrayFunctionReg &set_function(void fscalar(const NArray &lhs, + const real_t &rhs, + NArray *out)) { + body = [fscalar] (NArray **used_vars, + real_t *s, NArray **mutate_vars) { + fscalar(*used_vars[0], s[0], mutate_vars[0]); + }; + num_use_vars = 1; num_mutate_vars = 1; num_scalars = 1; + type_mask = kNArrayArgBeforeScalar | kAcceptEmptyMutateTarget; + this->add_argument("lhs", "NArray", "Left operand to the function."); + this->add_argument("rhs", "real_t", "Right operand to the function."); + return *this; + } /*! 
* \brief set the function body to a unary NArray function * this will also auto set the parameters correctly diff --git a/python/mxnet/narray.py b/python/mxnet/narray.py old mode 100644 new mode 100755 index c4c267c87e11..227b06399328 --- a/python/mxnet/narray.py +++ b/python/mxnet/narray.py @@ -66,6 +66,8 @@ def __del__(self): def __add__(self, other): if isinstance(other, NArray): return NArray._plus(self, other) + elif isinstance(other, float) or isinstance(other, int): + return NArray._plus_scalar(self, float(other)) else: raise TypeError('type %s not supported' % str(type(other))) @@ -75,12 +77,16 @@ def __radd__(self, other): def __sub__(self, other): if isinstance(other, NArray): return NArray._minus(self, other) + elif isinstance(other, float) or isinstance(other, int): + return NArray._minus_scalar(self, float(other)) else: raise TypeError('type %s not supported' % str(type(other))) def __mul__(self, other): if isinstance(other, NArray): return NArray._mul(self, other) + elif isinstance(other, float) or isinstance(other, int): + return NArray._mul_scalar(self, float(other)) else: raise TypeError('type %s not supported' % str(type(other))) @@ -90,9 +96,17 @@ def __rmul__(self, other): def __div__(self, other): if isinstance(other, NArray): return NArray._div(self, other) + elif isinstance(other, float) or isinstance(other, int): + return NArray._div_scalar(self, float(other)) else: raise TypeError('type %s not supported' % str(type(other))) + def __idiv__(self, other): + return self.__div__(other) + + def __truediv__(self, other): + return self.__div__(other) + def __getstate__(self): this = self.__dict__.copy() handle = this['handle'] diff --git a/src/narray/narray.cc b/src/narray/narray.cc old mode 100644 new mode 100755 index 57b25f536978..a9f7ebde678d --- a/src/narray/narray.cc +++ b/src/narray/narray.cc @@ -38,21 +38,95 @@ inline void BinaryOp(const NArray &lhs, NArray ret = *out; // redirect everything to mshadow operations switch (lhs.ctx().dev_mask) 
{ - case cpu::kDevMask: - DAGEngine::Get()->Push([lhs, rhs, ret](RunContext ctx) { - ret.ptr_->CheckAndAlloc(); - TBlob tmp = ret.data(); - narray::Eval(lhs.data(), rhs.data(), &tmp, ctx); - }, lhs.ctx(), {lhs.ptr_->var, rhs.ptr_->var}, {ret.ptr_->var}); + case cpu::kDevMask: { + auto func = [lhs, rhs, ret](RunContext ctx) { + ret.ptr_->CheckAndAlloc(); + TBlob tmp = ret.data(); + narray::Eval(lhs.data(), rhs.data(), &tmp, ctx); + }; + if (lhs.ptr_->var == ret.ptr_->var && rhs.ptr_->var == ret.ptr_->var) { + DAGEngine::Get()->Push(func, lhs.ctx(), {}, {ret.ptr_->var}); + } else if (lhs.ptr_->var == ret.ptr_->var) { + DAGEngine::Get()->Push(func, lhs.ctx(), {rhs.ptr_->var}, {ret.ptr_->var}); + } else if (rhs.ptr_->var == ret.ptr_->var) { + DAGEngine::Get()->Push(func, lhs.ctx(), {lhs.ptr_->var}, {ret.ptr_->var}); + } else { + DAGEngine::Get()->Push(func, lhs.ctx(), {lhs.ptr_->var, rhs.ptr_->var}, {ret.ptr_->var}); + } break; + } #if MXNET_USE_CUDA - case gpu::kDevMask: - DAGEngine::Get()->Push([lhs, rhs, ret](RunContext ctx) { - ret.ptr_->CheckAndAlloc(); - TBlob tmp = ret.data(); - narray::Eval(lhs.data(), rhs.data(), &tmp, ctx); - }, lhs.ctx(), {lhs.ptr_->var, rhs.ptr_->var}, {ret.ptr_->var}); + case gpu::kDevMask: { + auto func = [lhs, rhs, ret](RunContext ctx) { + ret.ptr_->CheckAndAlloc(); + TBlob tmp = ret.data(); + narray::Eval(lhs.data(), rhs.data(), &tmp, ctx); + }; + if (lhs.ptr_->var == ret.ptr_->var && rhs.ptr_->var == ret.ptr_->var) { + DAGEngine::Get()->Push(func, lhs.ctx(), {}, {ret.ptr_->var}); + } else if (lhs.ptr_->var == ret.ptr_->var) { + DAGEngine::Get()->Push(func, lhs.ctx(), {rhs.ptr_->var}, {ret.ptr_->var}); + } else if (rhs.ptr_->var == ret.ptr_->var) { + DAGEngine::Get()->Push(func, lhs.ctx(), {lhs.ptr_->var}, {ret.ptr_->var}); + } else { + DAGEngine::Get()->Push(func, lhs.ctx(), {lhs.ptr_->var, rhs.ptr_->var}, {ret.ptr_->var}); + } break; + } +#endif + default: LOG(FATAL) << "GPU is not enabled"; + } +} + +/*! 
+ * \brief run a scalar operation + * \param lhs left operand + * \param rhs right operand + * \param out the output narray + * \tparam OP the operation to perform + */ +template +inline void ScalarOp(const NArray &lhs, + const real_t &rhs, + NArray *out) { + if (out->is_none()) { + *out = NArray(OP::GetShape(lhs.shape(), lhs.shape()), lhs.ctx(), true); + } else { + CHECK(out->ctx() == lhs.ctx()) << "target context mismatch"; + CHECK(out->shape() == OP::GetShape(lhs.shape(), lhs.shape())) + << "target shape mismatch"; + } + // important: callback must always capture by value + NArray ret = *out; + // redirect everything to mshadow operations + switch (lhs.ctx().dev_mask) { + case cpu::kDevMask: { + auto func = [lhs, rhs, ret](RunContext ctx) { + ret.ptr_->CheckAndAlloc(); + TBlob tmp = ret.data(); + narray::Eval(lhs.data(), rhs, &tmp, ctx); + }; + if (lhs.ptr_->var == ret.ptr_->var) { + DAGEngine::Get()->Push(func, lhs.ctx(), {}, {ret.ptr_->var}); + } else { + DAGEngine::Get()->Push(func, lhs.ctx(), {lhs.ptr_->var}, {ret.ptr_->var}); + } + break; + } +#if MXNET_USE_CUDA + case gpu::kDevMask: { + auto func = [lhs, rhs, ret](RunContext ctx) { + ret.ptr_->CheckAndAlloc(); + TBlob tmp = ret.data(); + narray::Eval(lhs.data(), rhs, &tmp, ctx); + }; + if (lhs.ptr_->var == ret.ptr_->var) { + DAGEngine::Get()->Push(func, lhs.ctx(), {}, {ret.ptr_->var}); + } else { + DAGEngine::Get()->Push(func, lhs.ctx(), {lhs.ptr_->var}, {ret.ptr_->var}); + } + break; + } +#endif + default: LOG(FATAL) << "GPU is not enabled"; + } +} @@ -120,6 +194,14 @@ inline NArray BinaryOpRet(const NArray &lhs, return ret; } +template +inline NArray ScalarOpRet(const NArray &lhs, + const real_t &rhs) { + NArray ret; + ScalarOp(lhs, rhs, &ret); + return 
*dst; +} +// Binary NArray operator+(const NArray &lhs, const NArray &rhs) { return BinaryOpRet(lhs, rhs); } @@ -139,7 +228,20 @@ NArray operator*(const NArray &lhs, const NArray &rhs) { NArray operator/(const NArray &lhs, const NArray &rhs) { return BinaryOpRet(lhs, rhs); } - +// Scalar +NArray operator+(const NArray &lhs, const real_t &rhs) { + return ScalarOpRet(lhs, rhs); +} +NArray operator-(const NArray &lhs, const real_t &rhs) { + return ScalarOpRet(lhs, rhs); +} +NArray operator*(const NArray &lhs, const real_t &rhs) { + return ScalarOpRet(lhs, rhs); +} +NArray operator/(const NArray &lhs, const real_t &rhs) { + return ScalarOpRet(lhs, rhs); +} +// Binary NArray &NArray::operator+=(const NArray &src) { return BinaryOpApply(this, src); } @@ -152,6 +254,19 @@ NArray &NArray::operator*=(const NArray &src) { NArray &NArray::operator/=(const NArray &src) { return BinaryOpApply(this, src); } +// Scalar +NArray &NArray::operator+=(const real_t &src) { + return ScalarOpApply(this, src); +} +NArray &NArray::operator-=(const real_t &src) { + return ScalarOpApply(this, src); +} +NArray &NArray::operator*=(const real_t &src) { + return ScalarOpApply(this, src); +} +NArray &NArray::operator/=(const real_t &src) { + return ScalarOpApply(this, src); +} void NArray::Save(dmlc::Stream *strm) const { // save shape @@ -223,6 +338,11 @@ MXNET_REGISTER_NARRAY_FUN(_minus).set_function(BinaryOp); MXNET_REGISTER_NARRAY_FUN(_mul).set_function(BinaryOp); MXNET_REGISTER_NARRAY_FUN(_div).set_function(BinaryOp); +/////// +MXNET_REGISTER_NARRAY_FUN(_plus_scalar).set_function(ScalarOp); +MXNET_REGISTER_NARRAY_FUN(_minus_scalar).set_function(ScalarOp); +MXNET_REGISTER_NARRAY_FUN(_mul_scalar).set_function(ScalarOp); +MXNET_REGISTER_NARRAY_FUN(_div_scalar).set_function(ScalarOp); // copy function is special // that we need to remove kAcceptEmptyMutateTarget from it MXNET_REGISTER_NARRAY_FUN(_copyto) diff --git a/src/narray/narray_function-inl.h b/src/narray/narray_function-inl.h old mode 
100644 new mode 100755 index 6488652ffe80..6c79d74e3f5a --- a/src/narray/narray_function-inl.h +++ b/src/narray/narray_function-inl.h @@ -16,6 +16,14 @@ } #endif +#ifndef DECL_SCALAR +#define DECL_SCALAR(XPU, OP, FUN) \ + template<> \ + void Eval(const TBlob &lhs, const real_t &rhs, TBlob *ret, RunContext ctx) { \ + FUN(lhs, rhs, ret, ctx); \ + } +#endif + #if defined(__CUDACC__) #define DEVICE gpu #else @@ -26,7 +34,7 @@ namespace mxnet { namespace narray { // true implementation template -inline void Eval_(const TBlob &lhs, const TBlob &rhs, +inline void EvalBinary_(const TBlob &lhs, const TBlob &rhs, TBlob *ret, RunContext ctx) { using namespace mshadow::expr; mshadow::Stream *s = static_cast*>(ctx.stream); @@ -34,11 +42,24 @@ inline void Eval_(const TBlob &lhs, const TBlob &rhs, = F(lhs.FlatTo2D(s), rhs.FlatTo2D(s)); } + +template +inline void EvalScalar_(const TBlob &lhs, const real_t &rhs, + TBlob *ret, RunContext ctx) { + using namespace mshadow::expr; + mshadow::Stream *s = static_cast*>(ctx.stream); + ret->FlatTo2D(s) + = F(lhs.FlatTo2D(s), rhs); +} // declarations -DECL_BINARY(DEVICE, Plus, Eval_) -DECL_BINARY(DEVICE, Minus, Eval_) -DECL_BINARY(DEVICE, Mul, Eval_) -DECL_BINARY(DEVICE, Div, Eval_) +DECL_BINARY(DEVICE, Plus, EvalBinary_) +DECL_BINARY(DEVICE, Minus, EvalBinary_) +DECL_BINARY(DEVICE, Mul, EvalBinary_) +DECL_BINARY(DEVICE, Div, EvalBinary_) +DECL_SCALAR(DEVICE, Plus, EvalScalar_) +DECL_SCALAR(DEVICE, Minus, EvalScalar_) +DECL_SCALAR(DEVICE, Mul, EvalScalar_) +DECL_SCALAR(DEVICE, Div, EvalScalar_) } // namespace narray } // namespace mxnet diff --git a/src/narray/narray_function.h b/src/narray/narray_function.h index 50e86aeed9ed..4ea556883ede 100644 --- a/src/narray/narray_function.h +++ b/src/narray/narray_function.h @@ -36,6 +36,9 @@ struct Div : public BinaryBase { template void Eval(const TBlob &lhs, const TBlob &rhs, TBlob *ret, RunContext ctx); +template +void Eval(const TBlob &lhs, const real_t &rhs, TBlob *ret, RunContext ctx); + // 
copy function when only cpu is involved template void Copy(const TBlob &from, TBlob *to, diff --git a/tests/python/test_narray.py b/tests/python/test_narray.py index e4e1698c3799..6e3466fbe473 100644 --- a/tests/python/test_narray.py +++ b/tests/python/test_narray.py @@ -53,6 +53,16 @@ def test_narray_copy(): assert np.sum(np.abs(c.numpy != d.numpy)) == 0.0 +def test_narray_scalar(): + c = mx.narray.create((10,10)) + d = mx.narray.create((10,10)) + c.numpy[:] = 0.5 + d.numpy[:] = 1.0 + d -= c * 2 / 3 * 6.0 + c += 0.5 + assert(np.sum(c.numpy) == 100) + assert(np.sum(d.numpy) == -100) + def test_narray_pickle(): np.random.seed(0) maxdim = 5 @@ -97,3 +107,4 @@ def test_narray_saveload(): test_narray_saveload() test_narray_copy() test_narray_elementwise() + test_narray_scalar()