From 997d042e6da8ebf1c21f640a956d71eda9702fe3 Mon Sep 17 00:00:00 2001 From: sxjscience Date: Sat, 28 May 2016 16:51:01 +0800 Subject: [PATCH] Keep the original function name and add broadcast_keepdim & reduce_keepdim --- mshadow/extension/broadcast_with_axis.h | 27 +++++++++++++++++++------ mshadow/extension/reduce_with_axis.h | 27 ++++++++++++++++++++----- test/test_tblob.cc | 8 ++++---- 3 files changed, 47 insertions(+), 15 deletions(-) diff --git a/mshadow/extension/broadcast_with_axis.h b/mshadow/extension/broadcast_with_axis.h index a2dfb354..881f99d6 100644 --- a/mshadow/extension/broadcast_with_axis.h +++ b/mshadow/extension/broadcast_with_axis.h @@ -34,8 +34,9 @@ struct BroadcastWithAxisExp: /*! \brief size of the last dimension of src*/ index_t last_; /*! constructor */ - BroadcastWithAxisExp(const SrcExp &src, const int axis, const index_t size, int keepdim) + BroadcastWithAxisExp(const SrcExp &src, const int axis, const index_t size) : src_(src), size_(size) { + bool keepdim = (dimsrc == dimdst); Shape src_shape = ShapeCheck::Check(src_); this->trailing_ = 1; @@ -71,19 +72,33 @@ struct BroadcastWithAxisExp: }; // struct BroadcastWithAxisExp /*! - * \brief Broadcasting the tensor in the given axis. If keepdim is off, insert the broadcasting dim after axis. Otherwise broadcasting axis. - * \param keepdim whether to keepdim + * \brief Broadcasting the tensor after given axis. * \param SrcExp source expression * \tparam DType data type * \tparam etype type of the expression */ -template +template inline BroadcastWithAxisExp::kDim, - ExpInfo::kDim + 1 - keepdim> + ExpInfo::kDim + 1> broadcast_with_axis(const Exp &src, const int axis, const index_t size) { return BroadcastWithAxisExp::kDim, - ExpInfo::kDim + 1 - keepdim>(src.self(), axis, size, keepdim); + ExpInfo::kDim + 1>(src.self(), axis, size); } + +/*! +* \brief Broadcasting the tensor in the given axis (keepdim turned on) +* \param SrcExp source expression +* \tparam DType data type +* \tparam etype type of the expression +*/ +template +inline BroadcastWithAxisExp::kDim, + ExpInfo::kDim> + broadcast_keepdim(const Exp &src, const int axis, const index_t size) { + return BroadcastWithAxisExp::kDim, + ExpInfo::kDim>(src.self(), axis, size); +} + //---------------------- // Execution plan //---------------------- diff --git a/mshadow/extension/reduce_with_axis.h b/mshadow/extension/reduce_with_axis.h index 5378e465..b6ed5488 100644 --- a/mshadow/extension/reduce_with_axis.h +++ b/mshadow/extension/reduce_with_axis.h @@ -32,8 +32,9 @@ struct ReduceWithAxisExp: /*! \brief size of last src dimension */ index_t last_; /*! constructor */ - explicit ReduceWithAxisExp(const SrcExp &src, int axis, int keepdim) + explicit ReduceWithAxisExp(const SrcExp &src, int axis) : src_(src) { + bool keepdim = (dimsrc == dimdst); CHECK(dimsrc > axis) << "reduce axis out of bound"; Shape src_shape = ShapeCheck::Check(src_); for (index_t i = 0; i < axis; ++i) { @@ -63,18 +64,34 @@ struct ReduceWithAxisExp: * \brief reduce out the dimension of src labeled by axis. * \param Reducer type of the reducing operation * \param mask whether to output the unmask indices - * \param keepdim the keepdim flag * \tparam SrcExp source expression * \tparam DType data type * \tparam etype type of the expression */ -template +template inline ReduceWithAxisExp::kDim, mask, - ExpInfo::kDim + keepdim - 1> + ExpInfo::kDim - 1> reduce_with_axis(const Exp &src, int axis) { return ReduceWithAxisExp::kDim, mask, - ExpInfo::kDim + keepdim - 1>(src.self(), axis, keepdim); + ExpInfo::kDim- 1>(src.self(), axis); } + +/*! +* \brief reduce out the dimension of src labeled by axis, keepdim turned on. +* \param Reducer type of the reducing operation +* \param mask whether to output the unmask indices +* \tparam SrcExp source expression +* \tparam DType data type +* \tparam etype type of the expression +*/ +template +inline ReduceWithAxisExp::kDim, mask, + ExpInfo::kDim> + reduce_keepdim(const Exp &src, int axis) { + return ReduceWithAxisExp::kDim, mask, + ExpInfo::kDim>(src.self(), axis); +} + //---------------------- // Execution plan //---------------------- diff --git a/test/test_tblob.cc b/test/test_tblob.cc index 4c2c9c91..68f52fc9 100644 --- a/test/test_tblob.cc +++ b/test/test_tblob.cc @@ -80,7 +80,7 @@ void test_broadcast_with_axis() { input_tensor = 11; mshadow::Tensor n_tensor(NULL, test_shapes[dim + 1]); mshadow::AllocSpace(&n_tensor); - n_tensor = broadcast_with_axis<0>(input_tensor, dim, 5); + n_tensor = broadcast_with_axis(input_tensor, dim, 5); printf("Test for keepdim = 0, dim = %d", dim); for (index_t i = 0; i < n_tensor.shape_[0]; i++) { for (index_t j = 0; j < n_tensor.shape_[1]; j++) { @@ -100,7 +100,7 @@ void test_broadcast_with_axis() { input_tensor = 11; mshadow::Tensor n_tensor(NULL, test_shapes[dim]); mshadow::AllocSpace(&n_tensor); - n_tensor = broadcast_with_axis<1>(input_tensor, dim, 5); + n_tensor = broadcast_keepdim(input_tensor, dim, 5); printf("Test for keepdim = 1, dim = %d", dim); for (index_t i = 0; i < n_tensor.shape_[0]; i++) { for (index_t j = 0; j < n_tensor.shape_[1]; j++) { @@ -136,7 +136,7 @@ void test_reduce_with_axis() { input_tensor = 1; mshadow::Tensor n_tensor(NULL, mshadow::Shape3(2, 3, 4)); mshadow::AllocSpace(&n_tensor); - n_tensor = reduce_with_axis(input_tensor, dim); + n_tensor = reduce_with_axis(input_tensor, dim); printf("Test for keepdim = 0, dim = %d", dim); for (index_t i = 0; i < n_tensor.shape_[0]; i++) { for (index_t j = 0; j < n_tensor.shape_[1]; j++) { @@ -154,7 +154,7 @@ void test_reduce_with_axis() { input_tensor = 1; mshadow::Tensor n_tensor(NULL, keepdim_output_shapes[dim]); mshadow::AllocSpace(&n_tensor); - n_tensor = reduce_with_axis(input_tensor, dim); + n_tensor = reduce_keepdim(input_tensor, dim); printf("Test for keepdim = 1, dim = %d", dim); for (index_t i = 0; i < n_tensor.shape_[0]; i++) { for (index_t j = 0; j < n_tensor.shape_[1]; j++) {