Skip to content

Commit

Permalink
RavenDB-20103: ARM support for the different variants of VectorizedAnd
Browse files Browse the repository at this point in the history
  • Loading branch information
redknightlois authored and arekpalinski committed Apr 10, 2024
1 parent 4aa90aa commit bb386ab
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 20 deletions.
3 changes: 2 additions & 1 deletion src/Corax/Querying/IndexSearcher.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
using InvalidOperationException = System.InvalidOperationException;
using static Voron.Data.CompactTrees.CompactTree;
using Voron.Util;
using System.Runtime.Intrinsics;

namespace Corax.Querying;

Expand All @@ -47,7 +48,7 @@ public sealed unsafe partial class IndexSearcher : IDisposable
/// </summary>
public bool ForceNonAccelerated { get; set; }

public bool IsAccelerated => Avx2.IsSupported && !ForceNonAccelerated;
public bool IsAccelerated => Vector256.IsHardwareAccelerated && !ForceNonAccelerated;

public long NumberOfEntries => _numberOfEntries ??= _metadataTree?.ReadInt64(Constants.IndexWriter.NumberOfEntriesSlice) ?? 0;

Expand Down
29 changes: 15 additions & 14 deletions src/Corax/Querying/Matches/Meta/MergeHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,17 @@ public static int And(Span<long> dst, Span<long> left, Span<long> right)
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static int And(long* dst, int dstLength, long* left, int leftLength, long* right, int rightLength)
{
if (Avx2.IsSupported)
if (Vector256.IsHardwareAccelerated)
return AndVectorized(dst, dstLength, left, leftLength, right, rightLength);

return AndScalar(dst, dstLength, left, leftLength, right, rightLength);
}

/// <summary>
/// AVX2 implementation of vectorized AND.
/// Vector256 implementation of vectorized AND that works on both Intel/AMD and ARM.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static unsafe int AndVectorized(long* dst, int dstLength, long* left, int leftLength, long* right, int rightLength)
internal static int AndVectorized(long* dst, int dstLength, long* left, int leftLength, long* right, int rightLength)
{
// This is effectively a constant.
uint N = (uint)Vector256<ulong>.Count;
Expand Down Expand Up @@ -64,8 +65,8 @@ internal static unsafe int AndVectorized(long* dst, int dstLength, long* left, i
{
while (true)
{
// TODO: In here we can do SIMD galloping with gather operations. Therefore we will be able to do
// multiple checks at once and find the right amount of skipping using a table.
// TODO: In here we can do SIMD galloping with gather operations. Therefore, we will be able to do
// multiple checks at once and find the right amount of skipping using a table.

// If the value to compare is bigger than the biggest element in the block, we advance the block.
if ((ulong)*smallerPtr > (ulong)*(largerPtr + N - 1))
Expand All @@ -91,10 +92,10 @@ internal static unsafe int AndVectorized(long* dst, int dstLength, long* left, i
break; //In case when block is smaller than N we've to use scalar version.

Vector256<ulong> value = Vector256.Create((ulong)*smallerPtr);
Vector256<ulong> blockValues = Avx.LoadVector256((ulong*)largerPtr);
Vector256<ulong> blockValues = Vector256.Load((ulong*)largerPtr);

// We are going to select which direction we are going to be moving forward.
if (!Avx2.CompareEqual(value, blockValues).Equals(Vector256<ulong>.Zero))
if (Vector256.EqualsAny(value, blockValues))
{
// We found the value, therefore we need to store this value in the destination.
*dstPtr = *smallerPtr;
Expand All @@ -107,7 +108,7 @@ internal static unsafe int AndVectorized(long* dst, int dstLength, long* left, i
}
}

// The scalar version. This shouldnt cost much either way.
// The scalar version. This shouldn't cost much either way.
while (smallerPtr < smallerEndPtr && largerPtr < largerEndPtr)
{
ulong leftValue = (ulong)*smallerPtr;
Expand Down Expand Up @@ -160,7 +161,7 @@ internal static int AndScalar(Span<long> dst, Span<long> left, Span<long> right)
/// is also used for testing purposes.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static unsafe int AndScalar(long* dst, int dstLength, long* left, int leftLength, long* right, int rightLength)
internal static int AndScalar(long* dst, int dstLength, long* left, int leftLength, long* right, int rightLength)
{
long* dstPtr = dst;
long* leftPtr = left;
Expand Down Expand Up @@ -209,7 +210,7 @@ public static int Or(Span<long> dst, Span<long> left, Span<long> right)


[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static unsafe int Or(long* dst, int dstLength, long* left, int leftLength, long* right, int rightLength)
public static int Or(long* dst, int dstLength, long* left, int leftLength, long* right, int rightLength)
{
if (Sse2.IsSupported)
return OrNonTemporal(dst, dstLength, left, leftLength, right, rightLength);
Expand All @@ -220,7 +221,7 @@ public static unsafe int Or(long* dst, int dstLength, long* left, int leftLength
/// dst and left may *not* be the same buffer
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static unsafe int OrNonTemporal(long* dst, int dstLength, long* left, int leftLength, long* right, int rightLength)
public static int OrNonTemporal(long* dst, int dstLength, long* left, int leftLength, long* right, int rightLength)
{
long* dstPtr = dst;
long* dstEndPtr = dst + dstLength;
Expand Down Expand Up @@ -280,7 +281,7 @@ public static unsafe int OrNonTemporal(long* dst, int dstLength, long* left, int
/// dst and left may *not* be the same buffer
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static unsafe int OrScalar(long* dst, int dstLength, long* left, int leftLength, long* right, int rightLength)
public static int OrScalar(long* dst, int dstLength, long* left, int leftLength, long* right, int rightLength)
{
long* dstPtr = dst;
long* dstEndPtr = dst + dstLength;
Expand Down Expand Up @@ -347,7 +348,7 @@ public static int AndNot(Span<long> dst, Span<long> left, Span<long> right)
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static unsafe int AndNot(long* dst, int dstLength, long* left, int leftLength, long* right, int rightLength)
public static int AndNot(long* dst, int dstLength, long* left, int leftLength, long* right, int rightLength)
{
// PERF: This can be improved implementing support Sse2 implementation. This type of algorithms
// are very suitable for instruction level parallelism.
Expand All @@ -359,7 +360,7 @@ public static unsafe int AndNot(long* dst, int dstLength, long* left, int leftLe
/// is also used for testing purposes.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static unsafe int AndNotScalar(long* dst, int dstLength, long* left, int leftLength, long* right, int rightLength)
internal static int AndNotScalar(long* dst, int dstLength, long* left, int leftLength, long* right, int rightLength)
{
long* dstPtr = dst;
long* leftPtr = left;
Expand Down
10 changes: 5 additions & 5 deletions src/Corax/Querying/Matches/TermMatch.cs
Original file line number Diff line number Diff line change
Expand Up @@ -477,12 +477,12 @@ static int AndWithVectorizedFunc<TBoostingMode>(ref TermMatch term, Span<long> b

if (largerEndPtr - largerPtr < N)
break; // boundary guardian for vector load.

Vector256<ulong> value = Vector256.Create((ulong)*smallerPtr);
Vector256<ulong> blockValues = Avx.LoadVector256((ulong*)largerPtr);
Vector256<ulong> blockValues = Vector256.Load((ulong*)largerPtr);

// We are going to select which direction we are going to be moving forward.
if (!Avx2.CompareEqual(value, blockValues).Equals(Vector256<ulong>.Zero))
if (Vector256.EqualsAny(value, blockValues))
{
// We found the value, therefore we need to store this value in the destination.
*dstPtr = *smallerPtr;
Expand Down Expand Up @@ -567,7 +567,7 @@ static void ScoreFunc(ref TermMatch term, Span<long> matches, Span<float> scores
term._bm25Relevance.Score(matches, scores, boostFactor);
}

if (Avx2.IsSupported == false)
if (Vector256.IsHardwareAccelerated == false)
useAccelerated = false;

var bm25Relevance = isBoosting
Expand All @@ -577,7 +577,7 @@ static void ScoreFunc(ref TermMatch term, Span<long> matches, Span<float> scores

var isStored = isBoosting && bm25Relevance.IsStored;

// We will select the AVX version if supported.
// We will select the Vector256 version if supported.
return new TermMatch(indexSearcher, ctx, postingList.State.NumberOfEntries,
(isBoosting, isStored) switch
{
Expand Down

0 comments on commit bb386ab

Please sign in to comment.