Skip to content

Commit

Permalink
RavenDB-22076 Added IndexingStrategy
Browse files Browse the repository at this point in the history
  • Loading branch information
Lwiel committed Oct 17, 2024
1 parent 3ae659b commit 2935992
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 63 deletions.
1 change: 1 addition & 0 deletions src/Raven.Client/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,7 @@ public static class VectorSearch

public const float MinimumSimilarity = 0.8F;
public const EmbeddingType DefaultEmbeddingType = EmbeddingType.Float32;
public const VectorIndexingStrategy DefaultIndexingStrategy = VectorIndexingStrategy.Exact;
}
}
}
52 changes: 28 additions & 24 deletions src/Raven.Client/Documents/Queries/VectorFieldFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@ namespace Raven.Client.Documents.Queries;

public interface IVectorFieldFactory<T>
{
public IVectorEmbeddingTextField WithText(string fieldName);
public IVectorEmbeddingTextField WithText(string fieldName, VectorIndexingStrategy vectorIndexingStrategy = Constants.VectorSearch.DefaultIndexingStrategy);

public IVectorEmbeddingTextField WithText(Expression<Func<T, object>> propertySelector);
public IVectorEmbeddingTextField WithText(Expression<Func<T, object>> propertySelector, VectorIndexingStrategy vectorIndexingStrategy = Constants.VectorSearch.DefaultIndexingStrategy);

public IVectorEmbeddingField WithEmbedding(string fieldName, EmbeddingType storedEmbeddingQuantization = Constants.VectorSearch.DefaultEmbeddingType);
public IVectorEmbeddingField WithEmbedding(string fieldName, EmbeddingType storedEmbeddingQuantization = Constants.VectorSearch.DefaultEmbeddingType, VectorIndexingStrategy vectorIndexingStrategy = Constants.VectorSearch.DefaultIndexingStrategy);

public IVectorEmbeddingField WithEmbedding(Expression<Func<T, object>> propertySelector, EmbeddingType storedEmbeddingQuantization = Constants.VectorSearch.DefaultEmbeddingType);
public IVectorEmbeddingField WithBase64(string fieldName, EmbeddingType storedEmbeddingQuantization = Constants.VectorSearch.DefaultEmbeddingType);
public IVectorEmbeddingField WithEmbedding(Expression<Func<T, object>> propertySelector, EmbeddingType storedEmbeddingQuantization = Constants.VectorSearch.DefaultEmbeddingType, VectorIndexingStrategy vectorIndexingStrategy = Constants.VectorSearch.DefaultIndexingStrategy);
public IVectorEmbeddingField WithBase64(string fieldName, EmbeddingType storedEmbeddingQuantization = Constants.VectorSearch.DefaultEmbeddingType, VectorIndexingStrategy vectorIndexingStrategy = Constants.VectorSearch.DefaultIndexingStrategy);

public IVectorEmbeddingField WithBase64(Expression<Func<T, object>> propertySelector, EmbeddingType storedEmbeddingQuantization = Constants.VectorSearch.DefaultEmbeddingType);
public IVectorEmbeddingField WithBase64(Expression<Func<T, object>> propertySelector, EmbeddingType storedEmbeddingQuantization = Constants.VectorSearch.DefaultEmbeddingType, VectorIndexingStrategy vectorIndexingStrategy = Constants.VectorSearch.DefaultIndexingStrategy);

public IVectorField WithField(string fieldName);

Expand Down Expand Up @@ -48,67 +48,74 @@ internal sealed class VectorEmbeddingFieldFactory<T> : IVectorFieldFactory<T>, I
internal EmbeddingType SourceQuantizationType { get; set; } = Constants.VectorSearch.DefaultEmbeddingType;
internal EmbeddingType DestinationQuantizationType { get; set; } = Constants.VectorSearch.DefaultEmbeddingType;
internal bool IsBase64Encoded { get; set; }
internal VectorIndexingStrategy VectorIndexingStrategy { get; set; } = Constants.VectorSearch.DefaultIndexingStrategy;

IVectorEmbeddingTextField IVectorFieldFactory<T>.WithText(Expression<Func<T, object>> propertySelector)
IVectorEmbeddingTextField IVectorFieldFactory<T>.WithText(Expression<Func<T, object>> propertySelector, VectorIndexingStrategy vectorIndexingStrategy)
{
FieldName = propertySelector.ToPropertyPath(DocumentConventions.Default);
SourceQuantizationType = EmbeddingType.Text;
DestinationQuantizationType = EmbeddingType.Float32;
DestinationQuantizationType = Constants.VectorSearch.DefaultEmbeddingType;
VectorIndexingStrategy = vectorIndexingStrategy;

return this;
}

IVectorEmbeddingTextField IVectorFieldFactory<T>.WithText(string fieldName)
IVectorEmbeddingTextField IVectorFieldFactory<T>.WithText(string fieldName, VectorIndexingStrategy vectorIndexingStrategy)
{
FieldName = fieldName;
SourceQuantizationType = EmbeddingType.Text;
DestinationQuantizationType = EmbeddingType.Float32;
DestinationQuantizationType = Constants.VectorSearch.DefaultEmbeddingType;
VectorIndexingStrategy = vectorIndexingStrategy;

return this;
}

IVectorEmbeddingField IVectorFieldFactory<T>.WithEmbedding(string fieldName, EmbeddingType storedEmbeddingQuantization)
IVectorEmbeddingField IVectorFieldFactory<T>.WithEmbedding(string fieldName, EmbeddingType storedEmbeddingQuantization, VectorIndexingStrategy vectorIndexingStrategy)
{
FieldName = fieldName;
SourceQuantizationType = storedEmbeddingQuantization;
DestinationQuantizationType = SourceQuantizationType;
VectorIndexingStrategy = vectorIndexingStrategy;

return this;
}

IVectorEmbeddingField IVectorFieldFactory<T>.WithEmbedding(Expression<Func<T, object>> propertySelector, EmbeddingType storedEmbeddingQuantization)
IVectorEmbeddingField IVectorFieldFactory<T>.WithEmbedding(Expression<Func<T, object>> propertySelector, EmbeddingType storedEmbeddingQuantization, VectorIndexingStrategy vectorIndexingStrategy)
{
FieldName = propertySelector.ToPropertyPath(DocumentConventions.Default);
SourceQuantizationType = storedEmbeddingQuantization;
DestinationQuantizationType = SourceQuantizationType;
VectorIndexingStrategy = vectorIndexingStrategy;

return this;
}

IVectorEmbeddingField IVectorFieldFactory<T>.WithBase64(string fieldName, EmbeddingType storedEmbeddingQuantization)
IVectorEmbeddingField IVectorFieldFactory<T>.WithBase64(string fieldName, EmbeddingType storedEmbeddingQuantization, VectorIndexingStrategy vectorIndexingStrategy)
{
FieldName = fieldName;
SourceQuantizationType = storedEmbeddingQuantization;
DestinationQuantizationType = SourceQuantizationType;
IsBase64Encoded = true;
VectorIndexingStrategy = vectorIndexingStrategy;

return this;
}

IVectorEmbeddingField IVectorFieldFactory<T>.WithBase64(Expression<Func<T, object>> propertySelector, EmbeddingType storedEmbeddingQuantization)
IVectorEmbeddingField IVectorFieldFactory<T>.WithBase64(Expression<Func<T, object>> propertySelector, EmbeddingType storedEmbeddingQuantization, VectorIndexingStrategy vectorIndexingStrategy)
{
FieldName = propertySelector.ToPropertyPath(DocumentConventions.Default);
SourceQuantizationType = storedEmbeddingQuantization;
DestinationQuantizationType = SourceQuantizationType;
IsBase64Encoded = true;
VectorIndexingStrategy = vectorIndexingStrategy;

return this;
}

IVectorField IVectorFieldFactory<T>.WithField(string fieldName)
{
FieldName = fieldName;
SourceQuantizationType = EmbeddingQuantizationType.Any;
SourceQuantizationType = Constants.VectorSearch.DefaultEmbeddingType;
DestinationQuantizationType = SourceQuantizationType;

return this;
Expand All @@ -117,7 +124,7 @@ IVectorField IVectorFieldFactory<T>.WithField(string fieldName)
IVectorField IVectorFieldFactory<T>.WithField(Expression<Func<T, object>> propertySelector)
{
FieldName = propertySelector.ToPropertyPath(DocumentConventions.Default);
SourceQuantizationType = EmbeddingQuantizationType.Any;
SourceQuantizationType = Constants.VectorSearch.DefaultEmbeddingType;
DestinationQuantizationType = SourceQuantizationType;

return this;
Expand All @@ -127,7 +134,7 @@ IVectorEmbeddingField IVectorEmbeddingField.TargetQuantization(EmbeddingType tar
{
DestinationQuantizationType = targetEmbeddingQuantization;

if (DestinationQuantizationType < SourceQuantizationType)
if (DestinationQuantizationType > SourceQuantizationType)
throw new Exception($"Cannot quantize vector with {SourceQuantizationType} quantization into {DestinationQuantizationType}");

if (SourceQuantizationType == EmbeddingType.Int8 && DestinationQuantizationType == EmbeddingType.Binary)
Expand All @@ -151,13 +158,13 @@ public interface IVectorEmbeddingTextFieldValueFactory

public interface IVectorEmbeddingFieldValueFactory
{
public void ByEmbedding<T>(IEnumerable<T> embedding, EmbeddingType queriedEmbeddingQuantization = Constants.VectorSearch.DefaultEmbeddingType) where T : unmanaged
public void ByEmbedding<T>(IEnumerable<T> embedding) where T : unmanaged
#if NET7_0_OR_GREATER
, INumber<T>
#endif
;

public void ByBase64(string base64Embedding, EmbeddingType queriedEmbeddingQuantization = Constants.VectorSearch.DefaultEmbeddingType);
public void ByBase64(string base64Embedding);
}

public interface IVectorFieldValueFactory : IVectorEmbeddingTextFieldValueFactory, IVectorEmbeddingFieldValueFactory
Expand All @@ -170,9 +177,8 @@ internal class VectorFieldValueFactory : IVectorFieldValueFactory
public object Embedding { get; set; }
public string Text { get; set; }
public string Base64Embedding { get; set; }
public EmbeddingType EmbeddingType { get; set; }

void IVectorEmbeddingFieldValueFactory.ByEmbedding<T>(IEnumerable<T> embedding, EmbeddingType queriedEmbeddingQuantization)
void IVectorEmbeddingFieldValueFactory.ByEmbedding<T>(IEnumerable<T> embedding)
{
#if NET7_0_OR_GREATER == FALSE
// For >=NET7, INumber<T> is the guardian.
Expand All @@ -183,13 +189,11 @@ void IVectorEmbeddingFieldValueFactory.ByEmbedding<T>(IEnumerable<T> embedding,
#endif

Embedding = embedding;
EmbeddingType = queriedEmbeddingQuantization;
}

void IVectorEmbeddingFieldValueFactory.ByBase64(string base64Embedding, EmbeddingType queriedEmbeddingQuantization)
void IVectorEmbeddingFieldValueFactory.ByBase64(string base64Embedding)
{
Base64Embedding = base64Embedding;
EmbeddingType = queriedEmbeddingQuantization;
}

void IVectorEmbeddingTextFieldValueFactory.ByText(string text)
Expand Down
5 changes: 2 additions & 3 deletions src/Raven.Client/Documents/Session/AbstractDocumentQuery.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1520,8 +1520,7 @@ internal void VectorSearch(VectorEmbeddingFieldFactory<T> embeddingFieldFactory,
var sourceQuantizationType = embeddingFieldFactory.SourceQuantizationType;
var targetQuantizationType = embeddingFieldFactory.DestinationQuantizationType;
var isSourceBase64Encoded = embeddingFieldFactory.IsBase64Encoded;

var queriedVectorQuantizationType = queriedFactory.EmbeddingType;
var indexingStrategy = embeddingFieldFactory.VectorIndexingStrategy;

string queryParameterName;
var isVectorBase64Encoded = false;
Expand All @@ -1541,7 +1540,7 @@ internal void VectorSearch(VectorEmbeddingFieldFactory<T> embeddingFieldFactory,
isVectorBase64Encoded = true;
}

var vectorSearchToken = new VectorSearchToken(fieldName, queryParameterName, sourceQuantizationType, targetQuantizationType, queriedVectorQuantizationType, isSourceBase64Encoded, isVectorBase64Encoded, minimumSimilarity);
var vectorSearchToken = new VectorSearchToken(fieldName, queryParameterName, sourceQuantizationType, targetQuantizationType, isSourceBase64Encoded, isVectorBase64Encoded, minimumSimilarity, indexingStrategy);

WhereTokens.AddLast(vectorSearchToken);
}
Expand Down
15 changes: 11 additions & 4 deletions src/Raven.Client/Documents/Session/Tokens/VectorSearchToken.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,33 @@ public sealed class VectorSearchToken : WhereToken
private float SimilarityThreshold { get; set; }
private EmbeddingType SourceQuantizationType { get; set; }
private EmbeddingType TargetQuantizationType { get; set; }
private EmbeddingType QueriedVectorQuantizationType { get; set; }
private bool IsSourceBase64Encoded { get; set; }
private bool IsVectorBase64Encoded { get; set; }
private VectorIndexingStrategy IndexingStrategy { get; set; }

public VectorSearchToken(string fieldName, string parameterName, EmbeddingType sourceQuantizationType, EmbeddingType targetQuantizationType, EmbeddingType queriedQueriedVectorQuantizationType, bool isSourceBase64Encoded, bool isVectorBase64Encoded, float similarityThreshold)
public VectorSearchToken(string fieldName, string parameterName, EmbeddingType sourceQuantizationType, EmbeddingType targetQuantizationType, bool isSourceBase64Encoded, bool isVectorBase64Encoded, float similarityThreshold, VectorIndexingStrategy indexingStrategy)
{
FieldName = fieldName;
ParameterName = parameterName;

SourceQuantizationType = sourceQuantizationType;
TargetQuantizationType = targetQuantizationType;
QueriedVectorQuantizationType = queriedQueriedVectorQuantizationType;

IsSourceBase64Encoded = isSourceBase64Encoded;
IsVectorBase64Encoded = isVectorBase64Encoded;

SimilarityThreshold = similarityThreshold;

IndexingStrategy = indexingStrategy;
}

public override void WriteTo(StringBuilder writer)
{
writer.Append("vector.search(");

if (IndexingStrategy != Constants.VectorSearch.DefaultIndexingStrategy)
writer.Append($"{IndexingStrategy}(");

if (SourceQuantizationType is EmbeddingType.Float32 && TargetQuantizationType is EmbeddingType.Float32)
writer.Append(FieldName);
else
Expand All @@ -46,14 +51,16 @@ public override void WriteTo(StringBuilder writer)
(EmbeddingType.Text, EmbeddingType.Int8) => Constants.VectorSearch.EmbeddingTextInt8,
(EmbeddingType.Text, EmbeddingType.Binary) => Constants.VectorSearch.EmbeddingTextInt1,
(EmbeddingType.Int8, EmbeddingType.Int8) => Constants.VectorSearch.EmbeddingInt8,
(EmbeddingType.Binary, EmbeddingType.Binary) => Constants.VectorSearch.EmbeddingInt8,
(EmbeddingType.Binary, EmbeddingType.Binary) => Constants.VectorSearch.EmbeddingInt1,
_ => throw new InvalidOperationException(
$"Cannot create vector field with SourceQuantizationType {SourceQuantizationType} and TargetQuantizationType {TargetQuantizationType}")
};

writer.Append($"{methodName}({FieldName})");
}

if (IndexingStrategy != Constants.VectorSearch.DefaultIndexingStrategy)
writer.Append(')');

writer.Append($", ${ParameterName}");

Expand Down
Loading

0 comments on commit 2935992

Please sign in to comment.