Skip to content
This repository has been archived by the owner on Aug 11, 2020. It is now read-only.

Commit

Permalink
Merge pull request #170 from Lorrainexun/master
Browse files Browse the repository at this point in the history
Add Trinary op template in mshadow
  • Loading branch information
winstywang authored Oct 13, 2016
2 parents 7fd120a + b5fe218 commit 4d70c0e
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 2 deletions.
56 changes: 56 additions & 0 deletions mshadow/expr_engine-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,23 @@ class Plan<TypecastExp<DstDType, SrcDType, EType, etype>, DstDType> {
private:
Plan<EType, SrcDType> src_;
};

// ternary expression
template<typename OP, typename TA, typename TB, typename TC, int etype, typename DType>
class Plan<TernaryMapExp<OP, TA, TB, TC, DType, etype>, DType> {
public:
explicit Plan(const Plan<TA, DType> &item1, const Plan<TB, DType> &item2,
const Plan<TC, DType> &item3)
: item1_(item1), item2_(item2), item3_(item3) {}
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
return OP::Map(item1_.Eval(y, x), item2_.Eval(y, x), item3_.Eval(y, x));
}

private:
Plan<TA, DType> item1_;
Plan<TB, DType> item2_;
Plan<TC, DType> item3_;
};
// binary expression
template<typename OP, typename TA, typename TB, int etype, typename DType>
class Plan<BinaryMapExp<OP, TA, TB, DType, etype>, DType> {
Expand Down Expand Up @@ -161,6 +178,10 @@ template<typename OP, typename TA, typename TB, typename DType, int etype>
inline Plan<BinaryMapExp<OP, TA, TB, DType, etype>, DType>
MakePlan(const BinaryMapExp<OP, TA, TB, DType, etype> &e);

template<typename OP, typename TA, typename TB, typename TC, typename DType, int etype>
inline Plan<TernaryMapExp<OP, TA, TB, TC, DType, etype>, DType>
MakePlan(const TernaryMapExp<OP, TA, TB, TC, DType, etype> &e);

template<typename DType>
inline Plan<ScalarExp<DType>, DType> MakePlan(const ScalarExp<DType> &e) {
return Plan<ScalarExp<DType>, DType>(e.scalar_);
Expand Down Expand Up @@ -201,6 +222,14 @@ MakePlan(const BinaryMapExp<OP, TA, TB, DType, etype> &e) {
return Plan<BinaryMapExp<OP, TA, TB, DType, etype>,
DType>(MakePlan(e.lhs_), MakePlan(e.rhs_));
}

// Ternary
template<typename OP, typename TA, typename TB, typename TC, typename DType, int etype>
inline Plan<TernaryMapExp<OP, TA, TB, TC, DType, etype>, DType>
MakePlan(const TernaryMapExp<OP, TA, TB, TC, DType, etype> &e) {
return Plan<TernaryMapExp<OP, TA, TB, TC, DType, etype>,
DType>(MakePlan(e.item1_), MakePlan(e.item2_), MakePlan(e.item3_));
}
//----------------------------------------------------------------
// Static Type inference and Type Checking
//----------------------------------------------------------------
Expand Down Expand Up @@ -257,6 +286,15 @@ struct ExpInfo<BinaryMapExp<OP, TA, TB, DType, etype> > {
((kDimRhs == 0 || kDimLhs == kDimRhs) ? kDimLhs : -1)) : -1;
static const int kDevMask = ExpInfo<TA>::kDevMask & ExpInfo<TB>::kDevMask;
};
template<typename OP, typename TA, typename TB, typename TC, typename DType, int etype>
struct ExpInfo<TernaryMapExp<OP, TA, TB, TC, DType, etype> > {
static const int kDimItem1 = ExpInfo<TA>::kDim;
static const int kDimItem2 = ExpInfo<TB>::kDim;
static const int kDimItem3 = ExpInfo<TC>::kDim;
static const int kDim = kDimItem1;
static const int kDevMask = ExpInfo<TA>::kDevMask & ExpInfo<TB>::kDevMask & ExpInfo<TC>::kDevMask;
};

/*! \brief template to do type check */
template<typename Device, int dim, typename DType, typename E>
struct TypeCheck {
Expand Down Expand Up @@ -355,6 +393,7 @@ struct ShapeCheck<dim, UnaryMapExp<OP, TA, DType, etype> > {
return s;
}
};

template<int dim, typename OP, typename TA, typename TB,
typename DType, int etype>
struct ShapeCheck<dim, BinaryMapExp<OP, TA, TB, DType, etype> > {
Expand All @@ -369,7 +408,24 @@ struct ShapeCheck<dim, BinaryMapExp<OP, TA, TB, DType, etype> > {
return shape1;
}
};

template<int dim, typename OP, typename TA, typename TB, typename TC,
typename DType, int etype>
struct ShapeCheck<dim, TernaryMapExp<OP, TA, TB, TC, DType, etype> > {
inline static Shape<dim>
Check(const TernaryMapExp<OP, TA, TB, TC, DType, etype> &t) {
Shape<dim> shape1 = ShapeCheck<dim, TA>::Check(t.item1_);
Shape<dim> shape2 = ShapeCheck<dim, TB>::Check(t.item2_);
Shape<dim> shape3 = ShapeCheck<dim, TC>::Check(t.item3_);
bool same = (shape1 == shape2) && (shape2 == shape3);
CHECK(same) << "TernaryMapExp: Shapes of operands are not the same, " <<
"Shape1=" << shape1 << ", Shape2=" << shape2 << ", Shape3=" << shape3;

return shape1;
}
};
} // namespace expr

} // namespace mshadow
// include definition of dot engine
#include "./dot_engine-inl.h"
Expand Down
58 changes: 56 additions & 2 deletions mshadow/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,7 @@ template<typename DstDType, typename SrcDType,
typename EType, int etype>
inline TypecastExp<DstDType, SrcDType, EType, (etype|type::kMapper)>
tcast(const Exp<EType, SrcDType, etype> &exp) {
return TypecastExp<DstDType, SrcDType, EType,
(etype|type::kMapper)>(exp.self());
return TypecastExp<DstDType, SrcDType, EType, (etype|type::kMapper)>(exp.self());
}
/*! \brief represent a transpose expression of a container */
template<typename EType, typename DType>
Expand Down Expand Up @@ -249,6 +248,61 @@ batch_dot(const RValueExp<TA, DType> &lhs, const RValueExp<TB, DType> &rhs) {
lhs.self(), rhs.self(), DType(1.0f));
}
//---------------
// TernaryMapExp
// --------------
/*!
* \brief ternary map expression
* \tparam OP operator
* \tparam TA type of item1
* \tparam TB type of item2
* \tparam etype expression type, sa namespace::type
*/
template<typename OP, typename TA, typename TB, typename TC, typename DType, int etype>
struct TernaryMapExp: public Exp<TernaryMapExp<OP, TA, TB, TC, DType, etype>,
DType, etype> {
/*! \brief first operand */
const TA &item1_;
/*! \brief second operand */
const TB &item2_;
/*! \brief third operand */
const TC &item3_;
/*! \brief constructor */
explicit TernaryMapExp(const TA &item1, const TB &item2, const TC &item3)
:item1_(item1), item2_(item2), item3_(item3) {}
};

/*! \brief make expression */
template<typename OP, typename TA, typename TB, typename TC, typename DType, int ta, int tb, int tc>
inline TernaryMapExp<OP, TA, TB, TC, DType, (ta|tb|tc|type::kMapper)>
MakeExp(const Exp<TA, DType, ta> &item1, const Exp<TB, DType, tb> &item2,
const Exp<TC, DType, tc> &item3) {
return TernaryMapExp<OP, TA, TB, TC, DType,
(ta|tb|tc|type::kMapper)>(item1.self(), item2.self(), item3.self());
}
/*!
* \brief short hand for MakeExp, usage F<op>(item1,item2,item3). create a ternary operation expression
* \param item1 first operand
* \param item2 second operand
* \param item3 third operand
* \return the result expression
* \tparam ternary operator
* \tparam TA item1 expression
* \tparam ta item1 expression type
* \tparam TB item2 expression
* \tparam tb item2 expression type
* \tparam TC item3 expression
* \tparam tc item3 expression type
* \sa mshadow::op
*/

// Ternary
template<typename OP, typename TA, typename TB, typename TC, typename DType, int ta, int tb, int tc>
inline TernaryMapExp<OP, TA, TB, TC, DType, (ta|tb|tc|type::kMapper)>
F(const Exp<TA, DType, ta> &item1, const Exp<TB, DType, tb> &item2,
const Exp<TC, DType, tc> &item3) {
return MakeExp<OP>(item1, item2, item3);
}
//---------------
// BinaryMapExp
// --------------
/*!
Expand Down

0 comments on commit 4d70c0e

Please sign in to comment.