diff --git a/faiss/impl/ScalarQuantizer.cpp b/faiss/impl/ScalarQuantizer.cpp
index 91e00ee032..8d440926a7 100644
--- a/faiss/impl/ScalarQuantizer.cpp
+++ b/faiss/impl/ScalarQuantizer.cpp
@@ -671,7 +671,7 @@ struct QuantizerBF16<8> : QuantizerBF16<1> {
     FAISS_ALWAYS_INLINE simd8float32
     reconstruct_8_components(const uint8_t* code, int i) const {
-#ifdef __AVX2__
+        // #ifdef __AVX2__
         // reference impl: decode_bf16(((uint16_t*)code)[i]);
         // decode_bf16(v) -> (uint32_t(v) << 16)
         // read 128-bits (16 uint8_t) -> (uint16_t*)code)[i]
@@ -683,18 +683,18 @@ struct QuantizerBF16<8> : QuantizerBF16<1> {
         simd8uint32 shifted_16 = code_256i << 16;
         return as_float32(shifted_16);
-#endif
+        // #endif

-#ifdef __aarch64__
+        // #ifdef __aarch64__
-        uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i));
-        return simd8float32(
-                {vreinterpretq_f32_u32(
-                         vshlq_n_u32(vmovl_u16(codei.val[0]), 16)),
-                 vreinterpretq_f32_u32(
-                         vshlq_n_u32(vmovl_u16(codei.val[1]), 16))});
+        // uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 *
+        // i)); return simd8float32(
+        //         {vreinterpretq_f32_u32(
+        //                  vshlq_n_u32(vmovl_u16(codei.val[0]), 16)),
+        //          vreinterpretq_f32_u32(
+        //                  vshlq_n_u32(vmovl_u16(codei.val[1]), 16))});
-#endif
+        // #endif
     }
 };
diff --git a/faiss/utils/simdlib_emulated.h b/faiss/utils/simdlib_emulated.h
index 2dc61ae34a..2d9a5f66c8 100644
--- a/faiss/utils/simdlib_emulated.h
+++ b/faiss/utils/simdlib_emulated.h
@@ -703,6 +703,13 @@ struct simd8uint32 : simd256bit {
                 u32[0], u32[2], u32[4], u32[6], u32[1], u32[3], u32[5], u32[7]};
         return simd8uint32{ret};
     }
+    inline simd8uint32 load8_16bits_as_uint32(const uint8_t* code, int i) {
+        simd8uint32 res;
+        for (int j = 0; j < 16; j = j + 2) {
+            res.u32[j / 2] = *(const uint16_t*)(code + 2 * i + j);
+        }
+        return res;
+    }
 };

 // Vectorized version of the following code:
@@ -833,221 +840,226 @@ struct simd8float32 : simd256bit {
         ptr[-1] = 0;
         return std::string(res);
     }
-};

-// hadd does not cross lanes
-inline simd8float32 hadd(const simd8float32& a, const simd8float32& b) {
-    simd8float32 c;
-    c.f32[0] = a.f32[0] + a.f32[1];
-    c.f32[1] = a.f32[2] + a.f32[3];
-    c.f32[2] = b.f32[0] + b.f32[1];
-    c.f32[3] = b.f32[2] + b.f32[3];
-
-    c.f32[4] = a.f32[4] + a.f32[5];
-    c.f32[5] = a.f32[6] + a.f32[7];
-    c.f32[6] = b.f32[4] + b.f32[5];
-    c.f32[7] = b.f32[6] + b.f32[7];
+    float accumulate() const {
+        return f32[0] + f32[1] + f32[2] + f32[3] + f32[4] + f32[5] + f32[6] +
+                f32[7];
+    };

-    return c;
-}
+    // hadd does not cross lanes
+    inline simd8float32 hadd(const simd8float32& a, const simd8float32& b) {
+        simd8float32 c;
+        c.f32[0] = a.f32[0] + a.f32[1];
+        c.f32[1] = a.f32[2] + a.f32[3];
+        c.f32[2] = b.f32[0] + b.f32[1];
+        c.f32[3] = b.f32[2] + b.f32[3];

-inline simd8float32 unpacklo(const simd8float32& a, const simd8float32& b) {
-    simd8float32 c;
-    c.f32[0] = a.f32[0];
-    c.f32[1] = b.f32[0];
-    c.f32[2] = a.f32[1];
-    c.f32[3] = b.f32[1];
+        c.f32[4] = a.f32[4] + a.f32[5];
+        c.f32[5] = a.f32[6] + a.f32[7];
+        c.f32[6] = b.f32[4] + b.f32[5];
+        c.f32[7] = b.f32[6] + b.f32[7];

-    c.f32[4] = a.f32[4];
-    c.f32[5] = b.f32[4];
-    c.f32[6] = a.f32[5];
-    c.f32[7] = b.f32[5];
+        return c;
+    }

-    return c;
-}
+    inline simd8float32 unpacklo(const simd8float32& a, const simd8float32& b) {
+        simd8float32 c;
+        c.f32[0] = a.f32[0];
+        c.f32[1] = b.f32[0];
+        c.f32[2] = a.f32[1];
+        c.f32[3] = b.f32[1];

-inline simd8float32 unpackhi(const simd8float32& a, const simd8float32& b) {
-    simd8float32 c;
-    c.f32[0] = a.f32[2];
-    c.f32[1] = b.f32[2];
-    c.f32[2] = a.f32[3];
-    c.f32[3] = b.f32[3];
+        c.f32[4] = a.f32[4];
+        c.f32[5] = b.f32[4];
+        c.f32[6] = a.f32[5];
+        c.f32[7] = b.f32[5];

-    c.f32[4] = a.f32[6];
-    c.f32[5] = b.f32[6];
-    c.f32[6] = a.f32[7];
-    c.f32[7] = b.f32[7];
+        return c;
+    }

-    return c;
-}
+    inline simd8float32 unpackhi(const simd8float32& a, const simd8float32& b) {
+        simd8float32 c;
+        c.f32[0] = a.f32[2];
+        c.f32[1] = b.f32[2];
+        c.f32[2] = a.f32[3];
+        c.f32[3] = b.f32[3];

-// compute a * b + c
-inline simd8float32 fmadd(
-        const simd8float32& a,
-        const simd8float32& b,
-        const simd8float32& c) {
-    simd8float32 res;
-    for (int i = 0; i < 8; i++) {
-        res.f32[i] = a.f32[i] * b.f32[i] + c.f32[i];
+        c.f32[4] = a.f32[6];
+        c.f32[5] = b.f32[6];
+        c.f32[6] = a.f32[7];
+        c.f32[7] = b.f32[7];
+
+        return c;
     }
-    return res;
-}

-inline simd8float32 load8(const uint8_t* code, int i) {
-    simd8float32 res;
-    for (int j = 0; j < 8; j++) {
-        res.f32[i] = *(code + i + j);
+    // compute a * b + c
+    inline simd8float32 fmadd(
+            const simd8float32& a,
+            const simd8float32& b,
+            const simd8float32& c) {
+        simd8float32 res;
+        for (int i = 0; i < 8; i++) {
+            res.f32[i] = a.f32[i] * b.f32[i] + c.f32[i];
+        }
+        return res;
     }
-    return res;
-}

-namespace {
+    inline simd8float32 load8(const uint8_t* code, int i) {
+        simd8float32 res;
+        for (int j = 0; j < 8; j++) {
+            res.f32[j] = *(code + i + j);
+        }
+        return res;
+    }

-// get even float32's of a and b, interleaved
-simd8float32 geteven(const simd8float32& a, const simd8float32& b) {
-    simd8float32 c;
+    namespace {
+    // get even float32's of a and b, interleaved
+    simd8float32 geteven(const simd8float32& a, const simd8float32& b) {
+        simd8float32 c;

-    c.f32[0] = a.f32[0];
-    c.f32[1] = a.f32[2];
-    c.f32[2] = b.f32[0];
-    c.f32[3] = b.f32[2];
+        c.f32[0] = a.f32[0];
+        c.f32[1] = a.f32[2];
+        c.f32[2] = b.f32[0];
+        c.f32[3] = b.f32[2];

-    c.f32[4] = a.f32[4];
-    c.f32[5] = a.f32[6];
-    c.f32[6] = b.f32[4];
-    c.f32[7] = b.f32[6];
+        c.f32[4] = a.f32[4];
+        c.f32[5] = a.f32[6];
+        c.f32[6] = b.f32[4];
+        c.f32[7] = b.f32[6];

-    return c;
-}
+        return c;
+    }

-// get odd float32's of a and b, interleaved
-simd8float32 getodd(const simd8float32& a, const simd8float32& b) {
-    simd8float32 c;
+    // get odd float32's of a and b, interleaved
+    simd8float32 getodd(const simd8float32& a, const simd8float32& b) {
+        simd8float32 c;

-    c.f32[0] = a.f32[1];
-    c.f32[1] = a.f32[3];
-    c.f32[2] = b.f32[1];
-    c.f32[3] = b.f32[3];
+        c.f32[0] = a.f32[1];
+        c.f32[1] = a.f32[3];
+        c.f32[2] = b.f32[1];
+        c.f32[3] = b.f32[3];

-    c.f32[4] = a.f32[5];
-    c.f32[5] = a.f32[7];
-    c.f32[6] = b.f32[5];
-    c.f32[7] = b.f32[7];
+        c.f32[4] = a.f32[5];
+        c.f32[5] = a.f32[7];
+        c.f32[6] = b.f32[5];
+        c.f32[7] = b.f32[7];

-    return c;
-}
+        return c;
+    }

-// 3 cycles
-// if the lanes are a = [a0 a1] and b = [b0 b1], return [a0 b0]
-simd8float32 getlow128(const simd8float32& a, const simd8float32& b) {
-    simd8float32 c;
+    // 3 cycles
+    // if the lanes are a = [a0 a1] and b = [b0 b1], return [a0 b0]
+    simd8float32 getlow128(const simd8float32& a, const simd8float32& b) {
+        simd8float32 c;

-    c.f32[0] = a.f32[0];
-    c.f32[1] = a.f32[1];
-    c.f32[2] = a.f32[2];
-    c.f32[3] = a.f32[3];
+        c.f32[0] = a.f32[0];
+        c.f32[1] = a.f32[1];
+        c.f32[2] = a.f32[2];
+        c.f32[3] = a.f32[3];

-    c.f32[4] = b.f32[0];
-    c.f32[5] = b.f32[1];
-    c.f32[6] = b.f32[2];
-    c.f32[7] = b.f32[3];
+        c.f32[4] = b.f32[0];
+        c.f32[5] = b.f32[1];
+        c.f32[6] = b.f32[2];
+        c.f32[7] = b.f32[3];

-    return c;
-}
+        return c;
+    }

-simd8float32 gethigh128(const simd8float32& a, const simd8float32& b) {
-    simd8float32 c;
+    simd8float32 gethigh128(const simd8float32& a, const simd8float32& b) {
+        simd8float32 c;

-    c.f32[0] = a.f32[4];
-    c.f32[1] = a.f32[5];
-    c.f32[2] = a.f32[6];
-    c.f32[3] = a.f32[7];
+        c.f32[0] = a.f32[4];
+        c.f32[1] = a.f32[5];
+        c.f32[2] = a.f32[6];
+        c.f32[3] = a.f32[7];

-    c.f32[4] = b.f32[4];
-    c.f32[5] = b.f32[5];
-    c.f32[6] = b.f32[6];
-    c.f32[7] = b.f32[7];
+        c.f32[4] = b.f32[4];
+        c.f32[5] = b.f32[5];
+        c.f32[6] = b.f32[6];
+        c.f32[7] = b.f32[7];

-    return c;
-}
+        return c;
+    }

-// The following primitive is a vectorized version of the following code
-// snippet:
-//   float lowestValue = HUGE_VAL;
-//   uint lowestIndex = 0;
-//   for (size_t i = 0; i < n; i++) {
-//     if (values[i] < lowestValue) {
-//       lowestValue = values[i];
-//       lowestIndex = i;
-//     }
-//   }
-// Vectorized version can be implemented via two operations: cmp and blend
-// with something like this:
-//   lowestValues = [HUGE_VAL; 8];
-//   lowestIndices = {0, 1, 2, 3, 4, 5, 6, 7};
-//   for (size_t i = 0; i < n; i += 8) {
-//     auto comparison = cmp(values + i, lowestValues);
-//     lowestValues = blend(
-//         comparison,
-//         values + i,
-//         lowestValues);
-//     lowestIndices = blend(
-//         comparison,
-//         i + {0, 1, 2, 3, 4, 5, 6, 7},
-//         lowestIndices);
-//     lowestIndices += {8, 8, 8, 8, 8, 8, 8, 8};
-//   }
-// The problem is that blend primitive needs very different instruction
-// order for AVX and ARM.
-// So, let's introduce a combination of these two in order to avoid
-// confusion for ppl who write in low-level SIMD instructions. Additionally,
-// these two ops (cmp and blend) are very often used together.
-inline void cmplt_and_blend_inplace(
-        const simd8float32 candidateValues,
-        const simd8uint32 candidateIndices,
-        simd8float32& lowestValues,
-        simd8uint32& lowestIndices) {
-    for (size_t j = 0; j < 8; j++) {
-        bool comparison = (candidateValues.f32[j] < lowestValues.f32[j]);
-        if (comparison) {
-            lowestValues.f32[j] = candidateValues.f32[j];
-            lowestIndices.u32[j] = candidateIndices.u32[j];
+    // The following primitive is a vectorized version of the following code
+    // snippet:
+    //   float lowestValue = HUGE_VAL;
+    //   uint lowestIndex = 0;
+    //   for (size_t i = 0; i < n; i++) {
+    //     if (values[i] < lowestValue) {
+    //       lowestValue = values[i];
+    //       lowestIndex = i;
+    //     }
+    //   }
+    // Vectorized version can be implemented via two operations: cmp and
+    // blend with something like this:
+    //   lowestValues = [HUGE_VAL; 8];
+    //   lowestIndices = {0, 1, 2, 3, 4, 5, 6, 7};
+    //   for (size_t i = 0; i < n; i += 8) {
+    //     auto comparison = cmp(values + i, lowestValues);
+    //     lowestValues = blend(
+    //         comparison,
+    //         values + i,
+    //         lowestValues);
+    //     lowestIndices = blend(
+    //         comparison,
+    //         i + {0, 1, 2, 3, 4, 5, 6, 7},
+    //         lowestIndices);
+    //     lowestIndices += {8, 8, 8, 8, 8, 8, 8, 8};
+    //   }
+    // The problem is that blend primitive needs very different instruction
+    // order for AVX and ARM.
+    // So, let's introduce a combination of these two in order to avoid
+    // confusion for ppl who write in low-level SIMD instructions.
+    // Additionally, these two ops (cmp and blend) are very often used
+    // together.
+    inline void cmplt_and_blend_inplace(
+            const simd8float32 candidateValues,
+            const simd8uint32 candidateIndices,
+            simd8float32& lowestValues,
+            simd8uint32& lowestIndices) {
+        for (size_t j = 0; j < 8; j++) {
+            bool comparison = (candidateValues.f32[j] < lowestValues.f32[j]);
+            if (comparison) {
+                lowestValues.f32[j] = candidateValues.f32[j];
+                lowestIndices.u32[j] = candidateIndices.u32[j];
+            }
         }
     }
-}

-// Vectorized version of the following code:
-//   for (size_t i = 0; i < n; i++) {
-//     bool flag = (candidateValues[i] < currentValues[i]);
-//     minValues[i] = flag ? candidateValues[i] : currentValues[i];
-//     minIndices[i] = flag ? candidateIndices[i] : currentIndices[i];
-//     maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
-//     maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
-//   }
-// Max indices evaluation is inaccurate in case of equal values (the index of
-// the last equal value is saved instead of the first one), but this behavior
-// saves instructions.
-inline void cmplt_min_max_fast(
-        const simd8float32 candidateValues,
-        const simd8uint32 candidateIndices,
-        const simd8float32 currentValues,
-        const simd8uint32 currentIndices,
-        simd8float32& minValues,
-        simd8uint32& minIndices,
-        simd8float32& maxValues,
-        simd8uint32& maxIndices) {
-    for (size_t i = 0; i < 8; i++) {
-        bool flag = (candidateValues.f32[i] < currentValues.f32[i]);
-        minValues.f32[i] = flag ? candidateValues.f32[i] : currentValues.f32[i];
-        minIndices.u32[i] =
-                flag ? candidateIndices.u32[i] : currentIndices.u32[i];
-        maxValues.f32[i] =
-                !flag ? candidateValues.f32[i] : currentValues.f32[i];
-        maxIndices.u32[i] =
-                !flag ? candidateIndices.u32[i] : currentIndices.u32[i];
+    // Vectorized version of the following code:
+    //   for (size_t i = 0; i < n; i++) {
+    //     bool flag = (candidateValues[i] < currentValues[i]);
+    //     minValues[i] = flag ? candidateValues[i] : currentValues[i];
+    //     minIndices[i] = flag ? candidateIndices[i] : currentIndices[i];
+    //     maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
+    //     maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
+    //   }
+    // Max indices evaluation is inaccurate in case of equal values (the
+    // index of the last equal value is saved instead of the first one), but
+    // this behavior saves instructions.
+    inline void cmplt_min_max_fast(
+            const simd8float32 candidateValues,
+            const simd8uint32 candidateIndices,
+            const simd8float32 currentValues,
+            const simd8uint32 currentIndices,
+            simd8float32& minValues,
+            simd8uint32& minIndices,
+            simd8float32& maxValues,
+            simd8uint32& maxIndices) {
+        for (size_t i = 0; i < 8; i++) {
+            bool flag = (candidateValues.f32[i] < currentValues.f32[i]);
+            minValues.f32[i] =
+                    flag ? candidateValues.f32[i] : currentValues.f32[i];
+            minIndices.u32[i] =
+                    flag ? candidateIndices.u32[i] : currentIndices.u32[i];
+            maxValues.f32[i] =
+                    !flag ? candidateValues.f32[i] : currentValues.f32[i];
+            maxIndices.u32[i] =
+                    !flag ? candidateIndices.u32[i] : currentIndices.u32[i];
+        }
     }
-}

-} // namespace
+    } // namespace

 } // namespace faiss
diff --git a/faiss/utils/simdlib_neon.h b/faiss/utils/simdlib_neon.h
index 21bda18898..456a35551e 100644
--- a/faiss/utils/simdlib_neon.h
+++ b/faiss/utils/simdlib_neon.h
@@ -254,6 +254,11 @@ static inline uint32_t cmp_xe32(
     return d0_mask | static_cast<uint32_t>(d1_mask) << 16;
 }

+template <std::uint8_t Shift>
+static inline uint32x4_t vshlq(uint32x4_t vec) {
+    return vshlq_n_u32(vec, Shift);
+}
+
 template <std::uint8_t Shift>
 static inline uint16x8_t vshlq(uint16x8_t vec) {
     return vshlq_n_u16(vec, Shift);
 }
@@ -972,6 +977,63 @@ struct simd8uint32 {
         return ~(*this == other);
     }
+    // shift must be known at compile time
+    simd8uint32 operator<<(const int shift) const {
+        switch (shift) {
+            case 0:
+                return *this;
+            case 1:
+                return simd8uint32{detail::simdlib::unary_func(data)
+                                           .call<detail::simdlib::vshlq<1>>()};
+            case 2:
+                return simd8uint32{detail::simdlib::unary_func(data)
+                                           .call<detail::simdlib::vshlq<2>>()};
+            case 3:
+                return simd8uint32{detail::simdlib::unary_func(data)
+                                           .call<detail::simdlib::vshlq<3>>()};
+            case 4:
+                return simd8uint32{detail::simdlib::unary_func(data)
+                                           .call<detail::simdlib::vshlq<4>>()};
+            case 5:
+                return simd8uint32{detail::simdlib::unary_func(data)
+                                           .call<detail::simdlib::vshlq<5>>()};
+            case 6:
+                return simd8uint32{detail::simdlib::unary_func(data)
+                                           .call<detail::simdlib::vshlq<6>>()};
+            case 7:
+                return simd8uint32{detail::simdlib::unary_func(data)
+                                           .call<detail::simdlib::vshlq<7>>()};
+            case 8:
+                return simd8uint32{detail::simdlib::unary_func(data)
+                                           .call<detail::simdlib::vshlq<8>>()};
+            case 9:
+                return simd8uint32{detail::simdlib::unary_func(data)
+                                           .call<detail::simdlib::vshlq<9>>()};
+            case 10:
+                return simd8uint32{detail::simdlib::unary_func(data)
+                                           .call<detail::simdlib::vshlq<10>>()};
+            case 11:
+                return simd8uint32{detail::simdlib::unary_func(data)
+                                           .call<detail::simdlib::vshlq<11>>()};
+            case 12:
+                return simd8uint32{detail::simdlib::unary_func(data)
+                                           .call<detail::simdlib::vshlq<12>>()};
+            case 13:
+                return simd8uint32{detail::simdlib::unary_func(data)
+                                           .call<detail::simdlib::vshlq<13>>()};
+            case 14:
+                return simd8uint32{detail::simdlib::unary_func(data)
+                                           .call<detail::simdlib::vshlq<14>>()};
+            case 15:
+                return simd8uint32{detail::simdlib::unary_func(data)
+                                           .call<detail::simdlib::vshlq<15>>()};
+            case 16:
+                return simd8uint32{detail::simdlib::unary_func(data)
+                                           .call<detail::simdlib::vshlq<16>>()};
+            default:
+                FAISS_THROW_FMT("Invalid shift %d", shift);
+        }
+    }

     // Checks whether the other holds exactly the same bytes.
     template <typename T>
     bool is_same_as(T other) const {
@@ -1240,6 +1302,13 @@ inline simd8float32 load8(const uint8_t* code, int i) {
             {vcvtq_f32_u32(vmovl_u16(y8_0)), vcvtq_f32_u32(vmovl_u16(y8_1))});
 }
+inline simd8uint32 load8_16bits_as_uint32(const uint8_t* code, int i) {
+    uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i));
+    return simd8uint32({vmovl_u16(codei.val[0]), vmovl_u16(codei.val[1])});
+}
+inline simd8float32 as_float32(simd8uint32 x) {
+    return simd8float32(detail::simdlib::reinterpret_f32(x.data));
+}

 // The following primitive is a vectorized version of the following code
 // snippet:
 //   float lowestValue = HUGE_VAL;
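Note: all of the reconstruct_8_components paths touched above implement the scalar semantics spelled out in the reference comment, decode_bf16(v) -> (uint32_t(v) << 16). A minimal stand-alone sketch of that decode, with a hypothetical helper name and shown only to illustrate what load8_16bits_as_uint32, operator<<(16), and as_float32 compute for eight codes at once:

    #include <cstdint>
    #include <cstring>

    // Place the stored 16 bfloat16 bits into the upper half of a float32.
    static inline float decode_bf16_sketch(uint16_t v) {
        uint32_t bits = uint32_t(v) << 16; // bf16 payload is the high 16 bits
        float f;
        std::memcpy(&f, &bits, sizeof(f)); // bit-cast without aliasing issues
        return f;
    }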