Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix python API + small fixes #32

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions benchmark/reference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,35 @@

#include "defines.h"

// Check if 'floatType' is complex
template <typename floatType>
class is_complex {
template <
template <typename> typename _floatType, typename T,
typename = typename std::enable_if<
std::is_same<_floatType<T>, std::complex<T>>::value, bool>::type>
bool static constexpr _is_complex(_floatType<T>) {
return true;
}
template <typename _floatType>
bool static constexpr _is_complex(_floatType) {
return false;
}

public:
static constexpr auto value =
_is_complex(typename std::remove_reference<
typename std::remove_cv<floatType>::type>::type());
};


// Re-define 'conj' to make sure it return a floating point if
// argument is a floating point
template <typename floatType, typename = typename std::enable_if<
!is_complex<floatType>::value, bool>::type>
floatType conj(floatType &&x) {
return x;
}

template<typename floatType>
void transpose_ref( uint32_t *size, uint32_t *perm, int dim,
Expand Down Expand Up @@ -57,13 +86,13 @@ void transpose_ref( uint32_t *size, uint32_t *perm, int dim,
if( beta == (floatType) 0 )
for(int i=0; i < sizeInner; ++i)
if( conjA )
B_[i] = alpha * std::conj(A_[i * strideAinner]);
B_[i] = alpha * conj(A_[i * strideAinner]);
else
B_[i] = alpha * A_[i * strideAinner];
else
for(int i=0; i < sizeInner; ++i)
if( conjA )
B_[i] = alpha * std::conj(A_[i * strideAinner]) + beta * B_[i];
B_[i] = alpha * conj(A_[i * strideAinner]) + beta * B_[i];
else
B_[i] = alpha * A_[i * strideAinner] + beta * B_[i];
}
Expand Down
16 changes: 11 additions & 5 deletions pythonAPI/hptt/hptt.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,16 @@ def tensorTransposeAndUpdate(perm, alpha, A, beta, B, numThreads=-1):
raise ValueError("Unsupported dtype: {}.".format(A.dtype))

# tranpose!
tranpose_fn(permc, ctypes.c_int32(A.ndim),
scalar_fn(alpha), dataA, sizeA, outerSizeA,
scalar_fn(beta), dataB, outerSizeB,
ctypes.c_int32(numThreads), ctypes.c_int32(useRowMajor))
if 'float' in str(A.dtype):
tranpose_fn(permc, ctypes.c_int32(A.ndim),
scalar_fn(alpha), dataA, sizeA, outerSizeA,
scalar_fn(beta), dataB, outerSizeB,
ctypes.c_int32(numThreads), ctypes.c_int32(useRowMajor))
else:
tranpose_fn(permc, ctypes.c_int32(A.ndim),
scalar_fn(alpha), False, dataA, sizeA, outerSizeA,
scalar_fn(beta), dataB, outerSizeB,
ctypes.c_int32(numThreads), ctypes.c_int32(useRowMajor))


def tensorTranspose(perm, alpha, A, numThreads=-1):
Expand Down Expand Up @@ -168,7 +174,7 @@ def transpose(a, axes=None):
``a`` with its axes permuted.
"""
if axes is None:
axes = reversed(range(a.ndim))
axes = list(reversed(range(a.ndim)))

return tensorTranspose(axes, 1.0, a)

Expand Down
8 changes: 4 additions & 4 deletions src/hptt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,8 @@ void dTensorTranspose( const int *perm, const int dim,
}

void cTensorTranspose( const int *perm, const int dim,
const float _Complex alpha, bool conjA, const float _Complex *A, const int *sizeA, const int *outerSizeA,
const float _Complex beta, float _Complex *B, const int *outerSizeB,
const float alpha, bool conjA, const float *A, const int *sizeA, const int *outerSizeA,
const float beta, float *B, const int *outerSizeB,
const int numThreads, const int useRowMajor)
{
auto plan(std::make_shared<hptt::Transpose<hptt::FloatComplex> >(sizeA, perm, outerSizeA, outerSizeB, dim,
Expand All @@ -182,8 +182,8 @@ void cTensorTranspose( const int *perm, const int dim,
}

void zTensorTranspose( const int *perm, const int dim,
const double _Complex alpha, bool conjA, const double _Complex *A, const int *sizeA, const int *outerSizeA,
const double _Complex beta, double _Complex *B, const int *outerSizeB,
const double alpha, bool conjA, const double *A, const int *sizeA, const int *outerSizeA,
const double beta, double *B, const int *outerSizeB,
const int numThreads, const int useRowMajor)
{
auto plan(std::make_shared<hptt::Transpose<hptt::DoubleComplex> >(sizeA, perm, outerSizeA, outerSizeB, dim,
Expand Down