This repository has been archived by the owner on Aug 11, 2020. It is now read-only.

Merge pull request #121 from sxjscience/add_keepdim_reduce_broadcast
Add keepdim reduce broadcast
tqchen committed May 27, 2016
2 parents f3dba81 + a44054e commit a90696e
Showing 3 changed files with 211 additions and 71 deletions.
110 changes: 65 additions & 45 deletions mshadow/extension/broadcast_with_axis.h
@@ -12,77 +12,97 @@
namespace mshadow {
namespace expr {

/*! \brief Backward for tensor dot
* \tparam DataExp type of left expression
* \tparam TopExp type of right expression
* \tparam DType data type
*/
template<typename SrcExp, typename DType, int srcdim>
/*!
* \brief Broadcast the tensor along the given axis. If keepdim is off, a new dimension of the given size is inserted after axis; otherwise the existing size-1 axis is broadcast in place.
* \tparam SrcExp source expression
* \tparam DType data type
* \tparam dimsrc source dimension
* \tparam dimdst destination dimension
*/
template<typename SrcExp, typename DType, int dimsrc, int dimdst>
struct BroadcastWithAxisExp:
public MakeTensorExp<BroadcastWithAxisExp<SrcExp, DType, srcdim>,
SrcExp, srcdim+1, DType> {
public MakeTensorExp<BroadcastWithAxisExp<SrcExp, DType, dimsrc, dimdst>,
SrcExp, dimdst, DType> {
/*! \brief data operand */
const SrcExp &src_;
/*! \brief size of middle dimension */
index_t leading_;
/*! \brief size of middle dimension */
/*! \brief size of the last dimension of dst */
index_t dst_last_;
/*! \brief product of the dimensions after the broadcasting axis */
index_t trailing_;
/*! \brief size of middle dimension */
/*! \brief new dimension of the broadcasting axis */
index_t size_;
/*! \brief size of middle dimension */
/*! \brief size of the last dimension of src */
index_t last_;
/*! constructor */
BroadcastWithAxisExp(const SrcExp &src, const int axis, const index_t size)
BroadcastWithAxisExp(const SrcExp &src, const int axis, const index_t size, int keepdim)
: src_(src), size_(size) {
CHECK(srcdim > axis) << "broadcast axis out of bound";
Shape<srcdim> src_shape = ShapeCheck<srcdim, SrcExp>::Check(src_);
this->leading_ = 1;
for (index_t i = 0; i <= axis; ++i) {
this->leading_ *= src_shape[i];
this->shape_[i] = src_shape[i];
}
this->shape_[axis+1] = size_;
Shape<dimsrc> src_shape = ShapeCheck<dimsrc, SrcExp>::Check(src_);
this->trailing_ = 1;
for (index_t i = axis+1; i < srcdim; ++i) {
this->trailing_ *= src_shape[i];
this->shape_[i+1] = src_shape[i];

if (!keepdim) {
CHECK(dimsrc > axis && axis >= -1) << "broadcast axis (no keepdim) out of bounds, " <<
"axis must be between -1 and " << dimsrc - 1 << ", given=" << axis << ".";
for (int i = 0; i <= axis; ++i) {
this->shape_[i] = src_shape[i];
}
this->shape_[axis + 1] = size_;
for (int i = axis + 1; i < dimsrc; ++i) {
this->trailing_ *= src_shape[i];
this->shape_[i + 1] = src_shape[i];
}
} else {
CHECK(dimdst > axis && axis >= 0) << "broadcast axis (keepdim) out of bounds, " <<
"axis must be between 0 and " << dimdst - 1 << ", given=" << axis << ".";
CHECK_EQ(src_shape[axis], 1) << "The broadcasting axis must have size 1" <<
" when keepdim is on, src_shape[" << axis << "]=" << src_shape[axis] << ".";
for (int i = 0; i <= axis - 1; ++i) {
this->shape_[i] = src_shape[i];
}
this->shape_[axis] = size_;
for (int i = axis + 1; i < dimdst; ++i) {
this->trailing_ *= src_shape[i];
this->shape_[i] = src_shape[i];
}
}
this->last_ = src_shape[srcdim-1];

this->last_ = src_shape[dimsrc - 1];
this->dst_last_ = this->shape_[dimdst - 1];
}
}; // struct BroadcastWithAxisExp

/*!
* \brief pooling subregion results together
* \param data data oprand
* \param top top grad oprand
* \tparam DataExp left expression
* \tparam TopExp right expression
* \tparam DType the content data type
* \brief Broadcast the tensor along the given axis. If keepdim is off, a new dimension of the given size is inserted after axis; otherwise the existing size-1 axis is broadcast in place.
* \tparam keepdim whether to broadcast an existing size-1 axis (1) or insert a new one (0)
* \tparam SrcExp source expression
* \tparam DType data type
* \tparam etype type of the expression
*/
template<typename SrcExp, typename DType, int etype>
inline BroadcastWithAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>
template<int keepdim, typename SrcExp, typename DType, int etype>
inline BroadcastWithAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim,
ExpInfo<SrcExp>::kDim + 1 - keepdim>
broadcast_with_axis(const Exp<SrcExp, DType, etype> &src, const int axis, const index_t size) {
return BroadcastWithAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>(src.self(), axis, size);
return BroadcastWithAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim,
ExpInfo<SrcExp>::kDim + 1 - keepdim>(src.self(), axis, size, keepdim);
}
//----------------------
// Execution plan
//----------------------
template<typename SrcExp, typename DType, int srcdim>
struct Plan<BroadcastWithAxisExp<SrcExp, DType, srcdim>, DType> {
template<typename SrcExp, typename DType, int dimsrc, int dimdst>
struct Plan<BroadcastWithAxisExp<SrcExp, DType, dimsrc, dimdst>, DType> {
public:
explicit Plan(const BroadcastWithAxisExp<SrcExp, DType, srcdim> &e)
: src_(MakePlan(e.src_)), leading_(e.leading_),
trailing_(e.trailing_), size_(e.size_), last_(e.last_) {}
explicit Plan(const BroadcastWithAxisExp<SrcExp, DType, dimsrc, dimdst> &e)
: src_(MakePlan(e.src_)), dst_last_(e.dst_last_),
trailing_(e.trailing_), size_(e.size_), last_(e.last_) {}
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
index_t x = (i*last_+j)/trailing_/size_;
index_t y = (i*last_+j)%trailing_;
index_t z = x*trailing_ + y;
return src_.Eval(z/last_, z%last_);
index_t x = (i * dst_last_ + j) / trailing_ / size_;
index_t y = (i * dst_last_ + j) % trailing_;
index_t z = x * trailing_ + y;
return src_.Eval(z / last_, z % last_);
}

private:
Plan<SrcExp, DType> src_;
const index_t leading_, trailing_, size_, last_;
const index_t dst_last_, trailing_, size_, last_;
};
} // namespace expr
} // namespace mshadow
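
To make the new semantics concrete, here is a minimal CPU usage sketch. It is not part of this commit (the function name broadcast_sketch is illustrative) and simply mirrors the tests added in test/test_tblob.cc below: broadcast_with_axis<0> inserts a new axis of the given size after axis, while broadcast_with_axis<1> broadcasts an existing size-1 axis in place.

#include <mshadow/tensor.h>
using namespace mshadow;

void broadcast_sketch() {
  Tensor<cpu, 3> src(NULL, Shape3(2, 3, 4));
  AllocSpace(&src);
  src = 11;
  // keepdim = 0: a new axis of size 5 is inserted after axis 0,
  // so (2, 3, 4) broadcasts to (2, 5, 3, 4).
  Tensor<cpu, 4> dst(NULL, Shape4(2, 5, 3, 4));
  AllocSpace(&dst);
  dst = broadcast_with_axis<0>(src, 0, 5);  // every entry equals 11

  // keepdim = 1: axis 1 must already have size 1 and is broadcast
  // in place, so (2, 1, 3, 4) becomes (2, 5, 3, 4).
  Tensor<cpu, 4> src_k(NULL, Shape4(2, 1, 3, 4));
  AllocSpace(&src_k);
  src_k = 11;
  dst = broadcast_with_axis<1>(src_k, 1, 5);

  FreeSpace(&src);
  FreeSpace(&src_k);
  FreeSpace(&dst);
}
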
60 changes: 34 additions & 26 deletions mshadow/extension/reduce_with_axis.h
@@ -17,10 +17,10 @@ namespace expr {
* \tparam SrcExp type of source expression
* \tparam DType data type
*/
template<typename Reducer, typename SrcExp, typename DType, int srcdim, bool mask>
template<typename Reducer, typename SrcExp, typename DType, int dimsrc, bool mask, int dimdst>
struct ReduceWithAxisExp:
public MakeTensorExp<ReduceWithAxisExp<Reducer, SrcExp, DType, srcdim, mask>,
SrcExp, srcdim-1, DType> {
public MakeTensorExp<ReduceWithAxisExp<Reducer, SrcExp, DType, dimsrc, mask, dimdst>,
SrcExp, dimdst, DType> {
/*! \brief source operand */
const SrcExp &src_;
/*! \brief size of last destination dimension */
@@ -32,48 +32,56 @@ struct ReduceWithAxisExp:
/*! \brief size of last src dimension */
index_t last_;
/*! constructor */
explicit ReduceWithAxisExp(const SrcExp &src, int axis)
explicit ReduceWithAxisExp(const SrcExp &src, int axis, int keepdim)
: src_(src) {
CHECK(srcdim > axis) << "reduce axis out of bound";
Shape<srcdim> src_shape = ShapeCheck<srcdim, SrcExp>::Check(src_);
CHECK(dimsrc > axis) << "reduce axis out of bounds";
Shape<dimsrc> src_shape = ShapeCheck<dimsrc, SrcExp>::Check(src_);
for (index_t i = 0; i < axis; ++i) {
this->shape_[i] = src_shape[i];
}
this->size_ = src_shape[axis];
this->trailing_ = 1;
for (index_t i = axis + 1; i < srcdim; ++i) {
this->trailing_ *= src_shape[i];
this->shape_[i-1] = src_shape[i];
}
this->last_ = src_shape[srcdim-1];
if (axis == srcdim -1) {
this->last_dst_dim_ = src_shape[srcdim-2];
if (!keepdim) {
for (index_t i = axis + 1; i < dimsrc; ++i) {
this->trailing_ *= src_shape[i];
this->shape_[i - 1] = src_shape[i];
}
} else {
this->last_dst_dim_ = src_shape[srcdim-1];
this->shape_[axis] = 1;
for (index_t i = axis + 1; i < dimsrc; ++i) {
this->trailing_ *= src_shape[i];
this->shape_[i] = src_shape[i];
}
}

this->last_ = src_shape[dimsrc - 1];
this->last_dst_dim_ = this->shape_[dimdst - 1];
}
}; // struct ReduceWithAxisExp

/*!
* \brief pooling subregion results together
* \param lhs left oprand
* \param rhs right oprand
* \tparam LhsExp left expression
* \tparam RhsExp right expression
* \tparam DType the content data type
* \brief Reduce out the dimension of src labeled by axis.
* \tparam Reducer type of the reducing operation
* \tparam mask whether to output the indices of the reduced elements instead of the reduced values
* \tparam keepdim whether to keep the reduced axis with size 1 (1) or drop it (0)
* \tparam SrcExp source expression
* \tparam DType data type
* \tparam etype type of the expression
*/
template<typename Reducer, bool mask, typename SrcExp, typename DType, int etype>
inline ReduceWithAxisExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim, mask>
template<typename Reducer, bool mask, int keepdim, typename SrcExp, typename DType, int etype>
inline ReduceWithAxisExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim, mask,
ExpInfo<SrcExp>::kDim + keepdim - 1>
reduce_with_axis(const Exp<SrcExp, DType, etype> &src, int axis) {
return ReduceWithAxisExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim, mask>(src.self(), axis);
return ReduceWithAxisExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim, mask,
ExpInfo<SrcExp>::kDim + keepdim - 1>(src.self(), axis, keepdim);
}
//----------------------
// Execution plan
//----------------------
template<typename Reducer, typename SrcExp, typename DType, int srcdim, bool mask>
struct Plan<ReduceWithAxisExp<Reducer, SrcExp, DType, srcdim, mask>, DType> {
template<typename Reducer, typename SrcExp, typename DType, int dimsrc, bool mask, int dimdst>
struct Plan<ReduceWithAxisExp<Reducer, SrcExp, DType, dimsrc, mask, dimdst>, DType> {
public:
explicit Plan(const ReduceWithAxisExp<Reducer, SrcExp, DType, srcdim, mask> &e)
explicit Plan(const ReduceWithAxisExp<Reducer, SrcExp, DType, dimsrc, mask, dimdst> &e)
: src_(MakePlan(e.src_)), last_dst_dim_(e.last_dst_dim_), trailing_(e.trailing_),
size_(e.size_), last_(e.last_) {}
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
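
A matching sketch for the reducer, again not part of the commit and mirroring the new tests (reduce_sketch is an illustrative name): a (2, 5, 3, 4) tensor is reduced over axis 1 with red::sum. With keepdim = 0 the reduced axis is dropped; with keepdim = 1 it is kept with size 1, so the result can be broadcast straight back.

#include <mshadow/tensor.h>
using namespace mshadow;

void reduce_sketch() {
  Tensor<cpu, 4> src(NULL, Shape4(2, 5, 3, 4));
  AllocSpace(&src);
  src = 1;

  // keepdim = 0: (2, 5, 3, 4) -> (2, 3, 4); each entry is the sum
  // over the reduced axis of length 5.
  Tensor<cpu, 3> dst3(NULL, Shape3(2, 3, 4));
  AllocSpace(&dst3);
  dst3 = reduce_with_axis<mshadow::red::sum, false, 0>(src, 1);

  // keepdim = 1: (2, 5, 3, 4) -> (2, 1, 3, 4).
  Tensor<cpu, 4> dst4(NULL, Shape4(2, 1, 3, 4));
  AllocSpace(&dst4);
  dst4 = reduce_with_axis<mshadow::red::sum, false, 1>(src, 1);

  FreeSpace(&src);
  FreeSpace(&dst3);
  FreeSpace(&dst4);
}
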
112 changes: 112 additions & 0 deletions test/test_tblob.cc
@@ -1,4 +1,5 @@
#include <mshadow/tensor.h>
#include <vector>

// this namespace contains all data structures, functions
using namespace mshadow;
@@ -59,8 +60,119 @@ void test_tshape() {
}
}

void test_broadcast_with_axis() {
std::vector<mshadow::Shape<4> > test_shapes;
std::vector<mshadow::Shape<4> > keepdim_input_shapes;

test_shapes.push_back(mshadow::Shape4(5, 2, 3, 4));
test_shapes.push_back(mshadow::Shape4(2, 5, 3, 4));
test_shapes.push_back(mshadow::Shape4(2, 3, 5, 4));
test_shapes.push_back(mshadow::Shape4(2, 3, 4, 5));

keepdim_input_shapes.push_back(mshadow::Shape4(1, 2, 3, 4));
keepdim_input_shapes.push_back(mshadow::Shape4(2, 1, 3, 4));
keepdim_input_shapes.push_back(mshadow::Shape4(2, 3, 1, 4));
keepdim_input_shapes.push_back(mshadow::Shape4(2, 3, 4, 1));

for (int dim = -1; dim < 3; ++dim) {
mshadow::Tensor<mshadow::cpu, 3> input_tensor(NULL, mshadow::Shape3(2, 3, 4));
mshadow::AllocSpace(&input_tensor);
input_tensor = 11;
mshadow::Tensor<mshadow::cpu, 4> n_tensor(NULL, test_shapes[dim + 1]);
mshadow::AllocSpace(&n_tensor);
n_tensor = broadcast_with_axis<0>(input_tensor, dim, 5);
printf("Test for keepdim = 0, dim = %d", dim);
for (index_t i = 0; i < n_tensor.shape_[0]; i++) {
for (index_t j = 0; j < n_tensor.shape_[1]; j++) {
for (index_t k = 0; k < n_tensor.shape_[2]; k++) {
for (index_t l = 0; l < n_tensor.shape_[3]; l++) {
CHECK_EQ(n_tensor[i][j][k][l], 11);
}
}
}
}
printf(" Pass!\n");
}

for (int dim = 0; dim < 4; ++dim) {
mshadow::Tensor<mshadow::cpu, 4> input_tensor(NULL, keepdim_input_shapes[dim]);
mshadow::AllocSpace(&input_tensor);
input_tensor = 11;
mshadow::Tensor<mshadow::cpu, 4> n_tensor(NULL, test_shapes[dim]);
mshadow::AllocSpace(&n_tensor);
n_tensor = broadcast_with_axis<1>(input_tensor, dim, 5);
printf("Test for keepdim = 1, dim = %d", dim);
for (index_t i = 0; i < n_tensor.shape_[0]; i++) {
for (index_t j = 0; j < n_tensor.shape_[1]; j++) {
for (index_t k = 0; k < n_tensor.shape_[2]; k++) {
for (index_t l = 0; l < n_tensor.shape_[3]; l++) {
CHECK_EQ(n_tensor[i][j][k][l], 11);
}
}
}
}
printf(" Pass!\n");

}
}

void test_reduce_with_axis() {
std::vector<mshadow::Shape<4> > test_shapes;
std::vector<mshadow::Shape<4> > keepdim_output_shapes;

test_shapes.push_back(mshadow::Shape4(5, 2, 3, 4));
test_shapes.push_back(mshadow::Shape4(2, 5, 3, 4));
test_shapes.push_back(mshadow::Shape4(2, 3, 5, 4));
test_shapes.push_back(mshadow::Shape4(2, 3, 4, 5));

keepdim_output_shapes.push_back(mshadow::Shape4(1, 2, 3, 4));
keepdim_output_shapes.push_back(mshadow::Shape4(2, 1, 3, 4));
keepdim_output_shapes.push_back(mshadow::Shape4(2, 3, 1, 4));
keepdim_output_shapes.push_back(mshadow::Shape4(2, 3, 4, 1));

for (int dim = 0; dim < 4; ++dim) {
mshadow::Tensor<mshadow::cpu, 4> input_tensor(NULL, test_shapes[dim]);
mshadow::AllocSpace(&input_tensor);
input_tensor = 1;
mshadow::Tensor<mshadow::cpu, 3> n_tensor(NULL, mshadow::Shape3(2, 3, 4));
mshadow::AllocSpace(&n_tensor);
n_tensor = reduce_with_axis<mshadow::red::sum, false, 0>(input_tensor, dim);
printf("Test for keepdim = 0, dim = %d", dim);
for (index_t i = 0; i < n_tensor.shape_[0]; i++) {
for (index_t j = 0; j < n_tensor.shape_[1]; j++) {
for (index_t k = 0; k < n_tensor.shape_[2]; k++) {
CHECK_EQ(n_tensor[i][j][k], 5);
}
}
}
printf(" Pass!\n");
}

for (int dim = 0; dim < 4; ++dim) {
mshadow::Tensor<mshadow::cpu, 4> input_tensor(NULL, test_shapes[dim]);
mshadow::AllocSpace(&input_tensor);
input_tensor = 1;
mshadow::Tensor<mshadow::cpu, 4> n_tensor(NULL, keepdim_output_shapes[dim]);
mshadow::AllocSpace(&n_tensor);
n_tensor = reduce_with_axis<mshadow::red::sum, false, 1>(input_tensor, dim);
printf("Test for keepdim = 1, dim = %d", dim);
for (index_t i = 0; i < n_tensor.shape_[0]; i++) {
for (index_t j = 0; j < n_tensor.shape_[1]; j++) {
for (index_t k = 0; k < n_tensor.shape_[2]; k++) {
for (index_t l = 0; l < n_tensor.shape_[3]; l++) {
CHECK_EQ(n_tensor[i][j][k][l], 5);
}
}
}
}
printf(" Pass!\n");
}
}

int main(void) {
test_tshape();
test_broadcast_with_axis();
test_reduce_with_axis();
return 0;
}

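The two primitives compose naturally: with keepdim = 1 a reduction keeps the tensor rank, so its result can be broadcast straight back to the source shape. A final hedged sketch (not in the commit; rowmax_subtract_sketch is an illustrative name, and it assumes mshadow's compound assignment on tensors) subtracts the per-row maximum of a 2-D tensor, the usual first step of a numerically stable softmax.

#include <mshadow/tensor.h>
using namespace mshadow;

void rowmax_subtract_sketch() {
  Tensor<cpu, 2> data(NULL, Shape2(2, 4));
  Tensor<cpu, 2> rowmax(NULL, Shape2(2, 1));
  AllocSpace(&data);
  AllocSpace(&rowmax);
  data = 3.0f;
  // Reduce over axis 1, keeping it with size 1: (2, 4) -> (2, 1).
  rowmax = reduce_with_axis<mshadow::red::maximum, false, 1>(data, 1);
  // Broadcast the size-1 axis back to (2, 4) and subtract;
  // every entry of data becomes 0.
  data -= broadcast_with_axis<1>(rowmax, 1, 4);
  FreeSpace(&data);
  FreeSpace(&rowmax);
}
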
