Skip to content

Commit

Permalink
Improve NEON 'shuff' in NEONv2 and add 'compress' NEONv1 support.
Browse files Browse the repository at this point in the history
  • Loading branch information
kouchy committed Apr 27, 2024
1 parent e739f3c commit 35d8fbf
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 117 deletions.
5 changes: 3 additions & 2 deletions TODO..md → TODO.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# TODO

- [ ] Add `compress` for NEONv1 with emulation of `vqtbl1q` based on two `vtbl2`
- [ ] Improve NEONv2 `shuff` operations with `vqtbl1q` instruction
- [ ] Create a docker image with "Intel Software Development Emulator" to enable
      AVX-512 instruction emulation on the runners that do not support
      native AVX-512
- [ ] Find a workaround for the 16-bit SSE `compress` that requires BMI2
extension (remove `_pext_u32` dependency, available from Haswell)
- [ ] Compile the examples in the `CMakeLists.txt`
- [x] Add `compress` for NEONv1 with emulation of `vqtbl1q` based on two `vtbl2`
- [x] Improve NEONv2 `shuff` operations with `vqtbl1q` instruction
250 changes: 152 additions & 98 deletions include/mipp_impl_NEON.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -948,6 +948,37 @@
}

// ---------------------------------------------------------------------------------------------------------- shuff
#ifdef __aarch64__
template <>
inline reg shuff<double>(const reg v, const reg cm) {
	// Full 16-byte table lookup: every byte of 'cm' picks one byte of 'v'.
	const uint8x16_t table   = (uint8x16_t)v;
	const uint8x16_t indices = (uint8x16_t)cm;
	return (reg)vqtbl1q_u8(table, indices);
}

template <>
inline reg shuff<float>(const reg v, const reg cm) {
	// Byte-granular permutation of 'v' driven by the control mask 'cm'.
	const uint8x16_t src = (uint8x16_t)v;
	const uint8x16_t sel = (uint8x16_t)cm;
	return (reg)vqtbl1q_u8(src, sel);
}

template <>
inline reg shuff<int64_t>(const reg v, const reg cm) {
	// Single-instruction shuffle on AArch64: vqtbl1q_u8 indexes 'v' byte by byte.
	const uint8x16_t table   = (uint8x16_t)v;
	const uint8x16_t indices = (uint8x16_t)cm;
	return (reg)vqtbl1q_u8(table, indices);
}

template <>
inline reg shuff<int32_t>(const reg v, const reg cm) {
	// Reinterpret both operands as byte vectors and let vqtbl1q_u8 permute.
	const uint8x16_t src = (uint8x16_t)v;
	const uint8x16_t sel = (uint8x16_t)cm;
	return (reg)vqtbl1q_u8(src, sel);
}

template <>
inline reg shuff<int16_t>(const reg v, const reg cm) {
	// Byte-wise table lookup; 'cm' holds per-byte source indices into 'v'.
	const uint8x16_t table   = (uint8x16_t)v;
	const uint8x16_t indices = (uint8x16_t)cm;
	return (reg)vqtbl1q_u8(table, indices);
}

template <>
inline reg shuff<int8_t>(const reg v, const reg cm) {
	// Natural fit: int8 lanes are bytes, so vqtbl1q_u8 is a direct lane shuffle.
	const uint8x16_t src = (uint8x16_t)v;
	const uint8x16_t sel = (uint8x16_t)cm;
	return (reg)vqtbl1q_u8(src, sel);
}
#else
template <>
inline reg shuff<double>(const reg v, const reg cm) {
uint8x8x2_t v2 = {{vget_low_u8((uint8x16_t)v), vget_high_u8((uint8x16_t)v)}};
Expand Down Expand Up @@ -1001,6 +1032,7 @@

return (reg)vcombine_u8(low, high);
}
#endif

// --------------------------------------------------------------------------------------------------------- shuff2
template <>
Expand Down Expand Up @@ -1459,104 +1491,6 @@
return res;
}

// ------------------------------------------------------------------------------------------------------- compress
#ifdef MIPP_STATIC_LIB
#ifdef __aarch64__
template <>
inline reg compress<double>(const reg v, const msk m) {
	// Left-packs the mask-selected lanes of 'v' via a precomputed shuffle LUT.
	// Convert mask to integer: each all-ones mask lane contributes its bit
	// (1 or 2), so the horizontal add yields a 2-bit LUT index in [0, 3].
	alignas(16) constexpr uint64x2_t bits = {0x01, 0x02};
	uint64x2_t m64 = vandq_u64((uint64x2_t)m, bits);
	uint32_t idx = vaddvq_u64(m64); // vaddvq is AArch64-only

	// Get shuffle from LUT and apply it as a byte permutation of 'v'.
	int8x16_t shuff = vld1q_s8(vcompress_LUT64x2_NEON[idx]);
	float64x2_t res = (float64x2_t)vqtbl1q_s8((int8x16_t)v, (uint8x16_t)shuff);

	return (reg)res;
}

template <>
inline reg compress<float>(const reg v, const msk m) {
	// Left-packs the mask-selected float lanes of 'v' using a 16-entry LUT.
	// Convert mask to integer: per-lane bit weights 1/2/4/8 summed
	// horizontally give a 4-bit LUT index in [0, 15].
	alignas(16) constexpr uint32x4_t bits = {0x01, 0x02, 0x04, 0x08};
	uint32x4_t m32 = vandq_u32((uint32x4_t)m, bits);
	uint32_t idx = vaddvq_u32(m32); // vaddvq is AArch64-only

	// Get shuffle from LUT and permute the bytes of 'v' accordingly.
	int8x16_t shuff = vld1q_s8(vcompress_LUT32x4_NEON[idx]);
	float32x4_t res = (float32x4_t)vqtbl1q_s8((int8x16_t)v, (uint8x16_t)shuff);

	return (reg) res;
}

template <>
inline reg compress<int64_t>(const reg v, const msk m) {
	// Left-packs the mask-selected 64-bit lanes of 'v' via the 64x2 shuffle LUT.
	// Convert mask to integer: bit weights 1/2 summed -> 2-bit LUT index.
	alignas(16) constexpr uint64x2_t bits = {0x01, 0x02};
	uint64x2_t m64 = vandq_u64((uint64x2_t)m, bits);
	uint32_t idx = vaddvq_u64(m64); // vaddvq is AArch64-only

	// Get shuffle from LUT and apply the byte permutation.
	int8x16_t shuff = vld1q_s8(vcompress_LUT64x2_NEON[idx]);
	int64x2_t res = (int64x2_t)vqtbl1q_s8((int8x16_t)v, (uint8x16_t)shuff);

	return (reg) res;
}

template <>
inline reg compress<int32_t>(const reg v, const msk m) {
	// Left-packs the mask-selected 32-bit lanes of 'v' via the 32x4 shuffle LUT.
	// Convert mask to integer: bit weights 1/2/4/8 summed -> 4-bit LUT index.
	alignas(16) constexpr uint32x4_t bits = {0x01, 0x02, 0x04, 0x08};
	uint32x4_t m32 = vandq_u32((uint32x4_t)m, bits);
	uint32_t idx = vaddvq_u32(m32); // vaddvq is AArch64-only

	// Get shuffle from LUT and apply the byte permutation.
	int8x16_t shuff = vld1q_s8(vcompress_LUT32x4_NEON[idx]);
	int32x4_t res = (int32x4_t)vqtbl1q_s8((int8x16_t)v, (uint8x16_t)shuff);

	return (reg) res;
}

template <>
inline reg compress<int16_t>(const reg v, const msk m) {
	// Left-packs the mask-selected 16-bit lanes of 'v' via the 16x8 shuffle LUT.
	// Convert mask to integer: 8 lanes x bit weights 1..0x80 summed -> 8-bit index.
	alignas(16) constexpr uint16x8_t bits = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80};
	uint16x8_t m32 = vandq_u16((uint16x8_t)m, bits); // NOTE: 'm32' actually holds 16-bit lanes
	uint32_t idx = vaddvq_u16(m32); // vaddvq is AArch64-only

	// Get shuffle from LUT and apply the byte permutation.
	int8x16_t shuff = vld1q_s8(vcompress_LUT16x8_NEON[idx]);
	int16x8_t res = (int16x8_t)vqtbl1q_s8((int8x16_t)v, (uint8x16_t)shuff);

	return (reg) res;
}

template <>
inline reg compress<int8_t>(const reg v, const msk m) {
	// Left-packs the mask-selected 8-bit lanes of 'v'. A single 16-lane bit sum
	// would overflow 8 bits, so the mask is reduced in two halves (low/high 8
	// lanes) and the two partial sums are concatenated into a 16-bit LUT index.
	// Convert mask to integer
	alignas(16) constexpr uint8x16_t bits0 = {
		0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
	alignas(16) constexpr uint8x16_t bits1 = {
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
		0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80};

	uint8x16_t m32a = vandq_u8((uint8x16_t)m, bits0);
	uint8x16_t m32b = vandq_u8((uint8x16_t)m, bits1);
	uint32_t idx0 = vaddvq_u8(m32a); // sum over the low 8 lanes  (AArch64-only)
	uint32_t idx1 = vaddvq_u8(m32b); // sum over the high 8 lanes (AArch64-only)
	uint32_t idx = idx0 | (idx1 << 8); // 16-bit index into a 65536-entry LUT

	// Get shuffle from LUT and apply the byte permutation.
	int8x16_t shuff = vld1q_s8(vcompress_LUT8x16_NEON[idx]);
	int8x16_t res = (int8x16_t)vqtbl1q_u8((uint8x16_t)v, (uint8x16_t)shuff);

	return (reg) res;
}
#endif
#endif

// ----------------------------------------------------------------------------------------------------------- andb
#ifdef __aarch64__
template <>
Expand Down Expand Up @@ -3821,6 +3755,126 @@
}
#endif

// ------------------------------------------------------------------------------------------------------- compress
#ifdef MIPP_STATIC_LIB
#ifdef __aarch64__
template <>
inline reg compress<double>(const reg v, const msk m) {
	// Left-packs the mask-selected 64-bit lanes of 'v' via a precomputed
	// shuffle LUT. AArch64-only: relies on the vaddvq_u64 horizontal add.
	// Convert mask to integer: each all-ones mask lane contributes its bit
	// (1 or 2), so the sum is a 2-bit LUT index in [0, 3].
	alignas(16) constexpr uint64x2_t bits = {0x01, 0x02};
	uint64x2_t m64 = vandq_u64((uint64x2_t)m, bits);
	uint32_t idx = vaddvq_u64(m64);

	// Get shuffle from LUT; mipp::shuff applies it as a byte permutation
	// (delegates to vqtbl1q on AArch64, see the commented intrinsic call).
	int8x16_t shuff = vld1q_s8(vcompress_LUT64x2_NEON[idx]);
	// float64x2_t res = (float64x2_t)vqtbl1q_s8((int8x16_t)v, (uint8x16_t)shuff);
	float64x2_t res = (float64x2_t)mipp::shuff<double>((reg)v, (reg)shuff);

	return (reg)res;
}
#endif

template <>
inline reg compress<float>(const reg v, const msk m) {
	// Left-packs the mask-selected float lanes of 'v' using a 16-entry LUT.
	// Convert mask to integer: per-lane bit weights 1/2/4/8 summed
	// horizontally give a 4-bit LUT index in [0, 15].
	alignas(16) constexpr uint32x4_t bits = {0x01, 0x02, 0x04, 0x08};
	uint32x4_t m32 = vandq_u32((uint32x4_t)m, bits);
#ifdef __aarch64__
	uint32_t idx = vaddvq_u32(m32); // single-instruction horizontal add (A64)
#else
	// ARMv7 has no vaddvq: fall back to mipp's generic in-register reduction.
	uint32_t idx = mipp::getfirst<uint32_t>(mipp::_reduction<uint32_t, mipp::add<uint32_t>>::apply((reg)m32));
#endif
	// Get shuffle from LUT; mipp::shuff performs the byte permutation
	// (vqtbl1q on A64, the vtbl2-based emulation on NEONv1).
	int8x16_t shuff = vld1q_s8(vcompress_LUT32x4_NEON[idx]);
	// float32x4_t res = (float32x4_t)vqtbl1q_s8((int8x16_t)v, (uint8x16_t)shuff);
	float32x4_t res = (float32x4_t)mipp::shuff<float>((reg)v, (reg)shuff);

	return (reg) res;
}

#ifdef __aarch64__
template <>
inline reg compress<int64_t>(const reg v, const msk m) {
	// Left-packs the mask-selected 64-bit lanes of 'v' via the 64x2 shuffle
	// LUT. AArch64-only: relies on the vaddvq_u64 horizontal add.
	// Convert mask to integer: bit weights 1/2 summed -> 2-bit LUT index.
	alignas(16) constexpr uint64x2_t bits = {0x01, 0x02};
	uint64x2_t m64 = vandq_u64((uint64x2_t)m, bits);
	uint32_t idx = vaddvq_u64(m64);

	// Get shuffle from LUT; mipp::shuff applies the byte permutation.
	int8x16_t shuff = vld1q_s8(vcompress_LUT64x2_NEON[idx]);
	// int64x2_t res = (int64x2_t)vqtbl1q_s8((int8x16_t)v, (uint8x16_t)shuff);
	int64x2_t res = (int64x2_t)mipp::shuff<int64_t>((reg)v, (reg)shuff);

	return (reg) res;
}
#endif

template <>
inline reg compress<int32_t>(const reg v, const msk m) {
	// Left-packs the mask-selected 32-bit lanes of 'v' via the 32x4 shuffle LUT.
	// Convert mask to integer: bit weights 1/2/4/8 summed -> 4-bit LUT index.
	alignas(16) constexpr uint32x4_t bits = {0x01, 0x02, 0x04, 0x08};
	uint32x4_t m32 = vandq_u32((uint32x4_t)m, bits);
#ifdef __aarch64__
	uint32_t idx = vaddvq_u32(m32); // single-instruction horizontal add (A64)
#else
	// ARMv7 has no vaddvq: fall back to mipp's generic in-register reduction.
	uint32_t idx = mipp::getfirst<uint32_t>(mipp::_reduction<uint32_t, mipp::add<uint32_t>>::apply((reg)m32));
#endif
	// Get shuffle from LUT; mipp::shuff performs the byte permutation.
	int8x16_t shuff = vld1q_s8(vcompress_LUT32x4_NEON[idx]);
	// int32x4_t res = (int32x4_t)vqtbl1q_s8((int8x16_t)v, (uint8x16_t)shuff);
	int32x4_t res = (int32x4_t)mipp::shuff<int32_t>((reg)v, (reg)shuff);

	return (reg) res;
}

template <>
inline reg compress<int16_t>(const reg v, const msk m) {
	// Left-packs the mask-selected 16-bit lanes of 'v' via the 16x8 shuffle LUT.
	// Convert mask to integer: 8 lanes x bit weights 1..0x80 summed -> 8-bit index.
	alignas(16) constexpr uint16x8_t bits = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80};
	uint16x8_t m32 = vandq_u16((uint16x8_t)m, bits); // NOTE: 'm32' actually holds 16-bit lanes
#ifdef __aarch64__
	uint16_t idx = vaddvq_u16(m32); // single-instruction horizontal add (A64)
#else
	// ARMv7 has no vaddvq: fall back to mipp's generic in-register reduction.
	uint16_t idx = mipp::getfirst<uint16_t>(mipp::_reduction<uint16_t, mipp::add<uint16_t>>::apply((reg)m32));
#endif
	// Get shuffle from LUT; mipp::shuff performs the byte permutation.
	int8x16_t shuff = vld1q_s8(vcompress_LUT16x8_NEON[idx]);
	// int16x8_t res = (int16x8_t)vqtbl1q_s8((int8x16_t)v, (uint8x16_t)shuff);
	int16x8_t res = (int16x8_t)mipp::shuff<int16_t>((reg)v, (reg)shuff);

	return (reg) res;
}

template <>
inline reg compress<int8_t>(const reg v, const msk m) {
	// Left-packs the mask-selected 8-bit lanes of 'v'. A single 16-lane bit sum
	// would overflow 8 bits, so the mask is reduced in two halves (low/high 8
	// lanes) and the two partial sums are concatenated into a 16-bit LUT index.
	// Convert mask to integer
	alignas(16) constexpr uint8x16_t bits0 = {
		0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
	alignas(16) constexpr uint8x16_t bits1 = {
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
		0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80};

	uint8x16_t m32a = vandq_u8((uint8x16_t)m, bits0);
	uint8x16_t m32b = vandq_u8((uint8x16_t)m, bits1);
#ifdef __aarch64__
	uint8_t idx0 = vaddvq_u8(m32a); // sum over the low 8 lanes  (A64-only)
	uint8_t idx1 = vaddvq_u8(m32b); // sum over the high 8 lanes (A64-only)
#else
	// ARMv7 has no vaddvq: fall back to mipp's generic in-register reduction.
	uint8_t idx0 = mipp::getfirst<uint8_t>(mipp::_reduction<uint8_t, mipp::add<uint8_t>>::apply((reg)m32a));
	uint8_t idx1 = mipp::getfirst<uint8_t>(mipp::_reduction<uint8_t, mipp::add<uint8_t>>::apply((reg)m32b));
#endif
	uint32_t idx = idx0 | (idx1 << 8); // 16-bit index into a 65536-entry LUT

	// Get shuffle from LUT; mipp::shuff performs the byte permutation.
	int8x16_t shuff = vld1q_s8(vcompress_LUT8x16_NEON[idx]);
	// int8x16_t res = (int8x16_t)vqtbl1q_u8((uint8x16_t)v, (uint8x16_t)shuff);
	int8x16_t res = (int8x16_t)mipp::shuff<int8_t>((reg)v, (reg)shuff);

	return (reg) res;
}
#endif

// ------------------------------------------------------------------------------------------------------ transpose
template <>
inline void transpose<int16_t>(reg tab[nElReg<int16_t>()]) {
Expand Down
28 changes: 11 additions & 17 deletions tests/src/memory_operations/compress.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ template <typename T>
void test_reg_compress()
{
constexpr int N = mipp::N<T>();

T inputs1[N];
T expected[N];
bool mask1[N];

std::iota(inputs1, inputs1 + N, (T)1);
mipp::reg r1 = mipp::load<T>(inputs1);
mipp::reg r2 = mipp::set0<T>();

std::mt19937 g;
for (auto t = 0; t < 1000; t++)
{
Expand All @@ -30,26 +30,23 @@ void test_reg_compress()
{
bool bit = (g() & 1) ? false : true; // Generate random bit
mask1[i] = bit;
if (bit) {
expected[k] = i + (T)1;
k++;
}
if (bit)
expected[k++] = i + (T)1;
}

mipp::msk mask = mipp::set<N>(mask1);

r2 = mipp::compress<T>(r1, mask);

for (auto i = 0; i < N; i++) {
for (auto i = 0; i < N; i++)
REQUIRE(mipp::get<T>(r2, i) == expected[i]);
}
}
}

#if defined(MIPP_STATIC_LIB) && !defined(MIPP_NO)
TEST_CASE("Compress - mipp::reg", "[mipp::compress]")
{
#if (defined(MIPP_SSE) && MIPP_INSTR_VERSION >= 31) || defined(MIPP_AVX512) || defined(MIPP_NEONV2) || (defined(MIPP_AVX2) && defined(MIPP_BMI2))
#if (defined(MIPP_SSE) && MIPP_INSTR_VERSION >= 31) || defined(MIPP_AVX512) || defined(MIPP_NEON) || (defined(MIPP_AVX2) && defined(MIPP_BMI2))
#if defined(MIPP_64BIT)
SECTION("datatype = double") { test_reg_compress<double>(); }
#endif
Expand Down Expand Up @@ -93,26 +90,23 @@ void test_Reg_compress()
{
bool bit = (g() & 1) ? false : true; // Generate random bit
mask1[i] = bit;
if (bit) {
expected[k] = i + (T)1;
k++;
}
if (bit)
expected[k++] = i + (T)1;
}

mipp::Msk<mipp::N<T>()> mask = mask1;

r2 = mipp::compress(r1, mask);

for (auto i = 0; i < N; i++) {
for (auto i = 0; i < N; i++)
REQUIRE(r2[i] == expected[i]);
}
}
}

#if defined(MIPP_STATIC_LIB) && !defined(MIPP_NO)
TEST_CASE("Compress - mipp::Reg", "[mipp::compress]")
{
#if (defined(MIPP_SSE) && MIPP_INSTR_VERSION >= 31) || defined(MIPP_AVX512) || defined(MIPP_NEONV2) || (defined(MIPP_AVX2) && defined(MIPP_BMI2))
#if (defined(MIPP_SSE) && MIPP_INSTR_VERSION >= 31) || defined(MIPP_AVX512) || defined(MIPP_NEON) || (defined(MIPP_AVX2) && defined(MIPP_BMI2))
#if defined(MIPP_64BIT)
SECTION("datatype = double") { test_Reg_compress<double>(); }
#endif
Expand Down

0 comments on commit 35d8fbf

Please sign in to comment.