Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
RavenDB-22076 Vector search client API
Browse files Browse the repository at this point in the history
Lwiel committed Oct 14, 2024
1 parent 8a7f70c commit 3808e6d
Showing 14 changed files with 566 additions and 55 deletions.
16 changes: 16 additions & 0 deletions src/Raven.Client/Documents/Indexes/AbstractIndexCreationTask.cs
Original file line number Diff line number Diff line change
@@ -177,6 +177,22 @@ protected object CreateField(string name, object value)
throw new NotSupportedException("This can only be run on the server side");
}

/// <summary>
/// Declares a vector field in the index whose embeddings are generated from the provided text.
/// </summary>
/// <param name="value">Text the embeddings are generated from.</param>
/// <returns>Never returns normally; the call always throws.</returns>
/// <exception cref="NotSupportedException">Always thrown when the method is invoked directly.</exception>
public object VectorSearch(string value) =>
    throw new NotSupportedException("This method is provided solely to allow query translation on the server");

/// <inheritdoc cref="VectorSearch(string)"/>
/// <param name="values">Sequence of texts the embeddings are generated from.</param>
/// <exception cref="NotSupportedException">Always thrown when the method is invoked directly.</exception>
public object VectorSearch(IEnumerable<string> values) =>
    throw new NotSupportedException("This method is provided solely to allow query translation on the server");

/// <summary>
/// Generates a spatial field in the index, generating a Point from the provided lat/lng coordinates
/// </summary>
139 changes: 139 additions & 0 deletions src/Raven.Client/Documents/Queries/VectorFieldFactory.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
using System;
using System.Linq.Expressions;
using Raven.Client.Documents.Conventions;
using Raven.Client.Extensions;

namespace Raven.Client.Documents.Queries;

/// <summary>
/// Entry point of the fluent API that selects which document field a vector search runs against.
/// </summary>
/// <typeparam name="T">Document type used by the property-selector overloads.</typeparam>
public interface IVectorFieldFactory<T>
{
    /// <summary>Selects a text field by its name.</summary>
    /// <param name="fieldName">Name of the field.</param>
    IVectorEmbeddingTextField WithText(string fieldName);

    /// <summary>Selects a text field through a property expression.</summary>
    /// <param name="propertySelector">Expression pointing at the property.</param>
    IVectorEmbeddingTextField WithText(Expression<Func<T, string>> propertySelector);

    /// <summary>Selects an embedding field by its name.</summary>
    /// <param name="fieldName">Name of the field.</param>
    /// <param name="storedEmbeddingQuantization">Quantization of the embedding as it is stored in the field.</param>
    IVectorEmbeddingField WithEmbedding(string fieldName, EmbeddingQuantizationType storedEmbeddingQuantization = EmbeddingQuantizationType.None);

    /// <summary>Selects an embedding field through a property expression.</summary>
    /// <param name="propertySelector">Expression pointing at the property.</param>
    /// <param name="storedEmbeddingQuantization">Quantization of the embedding as it is stored in the field.</param>
    IVectorEmbeddingField WithEmbedding(Expression<Func<T, string>> propertySelector, EmbeddingQuantizationType storedEmbeddingQuantization = EmbeddingQuantizationType.None);

    /// <summary>Selects, by name, an embedding field whose content is base64-encoded.</summary>
    /// <param name="fieldName">Name of the field.</param>
    /// <param name="storedEmbeddingQuantization">Quantization of the embedding as it is stored in the field.</param>
    IVectorEmbeddingField WithBase64(string fieldName, EmbeddingQuantizationType storedEmbeddingQuantization = EmbeddingQuantizationType.None);

    /// <summary>Selects, through a property expression, an embedding field whose content is base64-encoded.</summary>
    /// <param name="propertySelector">Expression pointing at the property.</param>
    /// <param name="storedEmbeddingQuantization">Quantization of the embedding as it is stored in the field.</param>
    IVectorEmbeddingField WithBase64(Expression<Func<T, string>> propertySelector, EmbeddingQuantizationType storedEmbeddingQuantization = EmbeddingQuantizationType.None);
}

/// <summary>
/// Marker returned by the <c>WithText</c> selectors; it declares no members of its own.
/// </summary>
public interface IVectorEmbeddingTextField
{
}

/// <summary>
/// Fluent continuation returned once an embedding field has been selected.
/// </summary>
public interface IVectorEmbeddingField
{
    /// <summary>Sets the quantization the stored embedding should be converted to.</summary>
    /// <param name="targetEmbeddingQuantization">Requested target quantization.</param>
    /// <returns>The same fluent object, so further calls can be chained.</returns>
    IVectorEmbeddingField TargetQuantization(EmbeddingQuantizationType targetEmbeddingQuantization);
}

/// <summary>
/// Mutable builder behind the fluent field-selection API. Each fluent call records its
/// arguments in the internal properties and hands back the same instance.
/// </summary>
/// <typeparam name="T">Document type used by the property-selector overloads.</typeparam>
internal sealed class VectorEmbeddingFieldFactory<T> : IVectorFieldFactory<T>, IVectorEmbeddingTextField, IVectorEmbeddingField
{
    /// <summary>Name (or property path) of the selected field.</summary>
    internal string FieldName { get; set; }

    /// <summary>Quantization of the embedding as stored in the field.</summary>
    internal EmbeddingQuantizationType SourceQuantizationType { get; set; }

    /// <summary>Quantization requested through <c>TargetQuantization</c>.</summary>
    internal EmbeddingQuantizationType DestinationQuantizationType { get; set; }

    /// <summary>Set to true by the <c>WithBase64</c> selectors.</summary>
    internal bool IsBase64Encoded { get; set; }

    IVectorEmbeddingTextField IVectorFieldFactory<T>.WithText(Expression<Func<T, string>> propertySelector) =>
        UseField(PathOf(propertySelector));

    IVectorEmbeddingTextField IVectorFieldFactory<T>.WithText(string fieldName) =>
        UseField(fieldName);

    IVectorEmbeddingField IVectorFieldFactory<T>.WithEmbedding(string fieldName, EmbeddingQuantizationType storedEmbeddingQuantization) =>
        UseStoredEmbedding(fieldName, storedEmbeddingQuantization);

    IVectorEmbeddingField IVectorFieldFactory<T>.WithEmbedding(Expression<Func<T, string>> propertySelector, EmbeddingQuantizationType storedEmbeddingQuantization) =>
        UseStoredEmbedding(PathOf(propertySelector), storedEmbeddingQuantization);

    IVectorEmbeddingField IVectorFieldFactory<T>.WithBase64(string fieldName, EmbeddingQuantizationType storedEmbeddingQuantization) =>
        UseStoredEmbedding(fieldName, storedEmbeddingQuantization).MarkBase64();

    IVectorEmbeddingField IVectorFieldFactory<T>.WithBase64(Expression<Func<T, string>> propertySelector, EmbeddingQuantizationType storedEmbeddingQuantization) =>
        UseStoredEmbedding(PathOf(propertySelector), storedEmbeddingQuantization).MarkBase64();

    IVectorEmbeddingField IVectorEmbeddingField.TargetQuantization(EmbeddingQuantizationType targetEmbeddingQuantization)
    {
        DestinationQuantizationType = targetEmbeddingQuantization;
        return this;
    }

    // Resolves a property expression to its path using the default conventions.
    private static string PathOf(Expression<Func<T, string>> selector) =>
        selector.ToPropertyPath(DocumentConventions.Default);

    // Records only the field name; the quantization and base64 state are left untouched.
    private VectorEmbeddingFieldFactory<T> UseField(string path)
    {
        FieldName = path;
        return this;
    }

    // Records the field name together with the quantization of the stored embedding.
    private VectorEmbeddingFieldFactory<T> UseStoredEmbedding(string path, EmbeddingQuantizationType stored)
    {
        FieldName = path;
        SourceQuantizationType = stored;
        return this;
    }

    // Flags the selected field as base64-encoded.
    private VectorEmbeddingFieldFactory<T> MarkBase64()
    {
        IsBase64Encoded = true;
        return this;
    }
}

//////////////////////////

/// <summary>
/// Supplies the queried value for a vector search that targets a text field.
/// </summary>
public interface IVectorEmbeddingTextFieldValueFactory
{
    /// <summary>Uses the given text as the queried value.</summary>
    /// <param name="text">Text to search for.</param>
    void ByText(string text);
}

/// <summary>
/// Supplies the queried value for a vector search that targets an embedding field.
/// </summary>
public interface IVectorEmbeddingFieldValueFactory
{
    /// <summary>Uses a float embedding as the queried value.</summary>
    /// <param name="embedding">Embedding to search for.</param>
    /// <param name="queriedEmbeddingQuantization">Quantization of <paramref name="embedding"/>.</param>
    void ByEmbedding(float[] embedding, EmbeddingQuantizationType queriedEmbeddingQuantization = EmbeddingQuantizationType.None);

    /// <summary>Uses a base64-encoded embedding as the queried value.</summary>
    /// <param name="base64Embedding">Base64 string holding the embedding to search for.</param>
    /// <param name="queriedEmbeddingQuantization">Quantization of the encoded embedding.</param>
    void ByBase64(string base64Embedding, EmbeddingQuantizationType queriedEmbeddingQuantization = EmbeddingQuantizationType.None);
}

/// <summary>
/// Mutable holder for the queried value of a vector search. Each fluent call stores only
/// its own arguments; values recorded by earlier calls are left in place.
/// </summary>
internal class VectorEmbeddingValueFactory : IVectorEmbeddingFieldValueFactory, IVectorEmbeddingTextFieldValueFactory
{
    /// <summary>Embedding recorded by <c>ByEmbedding</c>.</summary>
    public float[] Embedding { get; set; }

    /// <summary>Text recorded by <c>ByText</c>.</summary>
    public string Text { get; set; }

    /// <summary>Base64 string recorded by <c>ByBase64</c>.</summary>
    public string Base64Embedding { get; set; }

    /// <summary>Quantization recorded by <c>ByEmbedding</c> or <c>ByBase64</c>.</summary>
    public EmbeddingQuantizationType EmbeddingQuantizationType { get; set; }

    void IVectorEmbeddingTextFieldValueFactory.ByText(string text) => Text = text;

    void IVectorEmbeddingFieldValueFactory.ByEmbedding(float[] embedding, EmbeddingQuantizationType queriedEmbeddingQuantization) =>
        (Embedding, EmbeddingQuantizationType) = (embedding, queriedEmbeddingQuantization);

    void IVectorEmbeddingFieldValueFactory.ByBase64(string base64Embedding, EmbeddingQuantizationType queriedEmbeddingQuantization) =>
        (Base64Embedding, EmbeddingQuantizationType) = (base64Embedding, queriedEmbeddingQuantization);
}

/// <summary>
/// Quantization of an embedding's values.
/// </summary>
/// <remarks>
/// <c>None</c> and <c>F32</c> share the underlying value 0, so the two cannot be told apart at runtime.
/// </remarks>
public enum EmbeddingQuantizationType
{
/// <summary>No quantization specified; same underlying value as <c>F32</c>.</summary>
None = 0,
/// <summary>Alias of <c>None</c>; presumably 32-bit float components — TODO confirm.</summary>
F32 = None,
/// <summary>Presumably 8-bit integer quantization — TODO confirm.</summary>
I8 = 1,
/// <summary>Presumably 1-bit (binary) quantization — TODO confirm.</summary>
I1 = 2
}
44 changes: 44 additions & 0 deletions src/Raven.Client/Documents/Session/AbstractDocumentQuery.cs
Original file line number Diff line number Diff line change
@@ -1513,6 +1513,50 @@ public string GetMemberQueryPath(Expression expression)
return propertyName;
}

/// <summary>
/// Turns the configured field and queried value into a <see cref="VectorSearchToken"/>
/// and appends it to the where tokens.
/// </summary>
/// <param name="embeddingFieldFactory">Builder holding the field name, the stored (source) quantization, the requested target quantization and the base64 flag of the field.</param>
/// <param name="queriedEmbeddingFactory">Builder holding the queried value: text, a float embedding or a base64-encoded embedding, checked in that order.</param>
/// <param name="minimumSimilarity">Similarity threshold forwarded to the token.</param>
/// <exception cref="InvalidOperationException">
/// No field name was configured, the source/target quantization combination is not allowed,
/// or no queried value was supplied.
/// </exception>
internal void VectorSearch(VectorEmbeddingFieldFactory<T> embeddingFieldFactory, VectorEmbeddingValueFactory queriedEmbeddingFactory,
    float minimumSimilarity)
{
    var fieldName = embeddingFieldFactory.FieldName;

    // Without a field name the generated clause would have an empty first argument.
    if (string.IsNullOrWhiteSpace(fieldName))
        throw new InvalidOperationException("Vector search requires a field name, but none was configured");

    var sourceQuantizationType = embeddingFieldFactory.SourceQuantizationType;
    var targetQuantizationType = embeddingFieldFactory.DestinationQuantizationType;

    // None and F32 share the value 0, so an unset target and an explicit F32 target
    // both fall back to the source quantization here.
    if (targetQuantizationType == EmbeddingQuantizationType.None)
        targetQuantizationType = sourceQuantizationType;

    // TODO
    if (targetQuantizationType < sourceQuantizationType)
        throw new InvalidOperationException($"Cannot quantize vector with {sourceQuantizationType} quantization into {targetQuantizationType}");

    if (sourceQuantizationType == EmbeddingQuantizationType.I8 && targetQuantizationType == EmbeddingQuantizationType.I1)
        throw new InvalidOperationException("Cannot quantize already quantized embeddings");

    var isSourceBase64Encoded = embeddingFieldFactory.IsBase64Encoded;
    var vectorQuantizationType = queriedEmbeddingFactory.EmbeddingQuantizationType;

    string queryParameterName;
    var isVectorBase64Encoded = false;

    // Precedence when several values were recorded: text, then float embedding, then base64.
    if (queriedEmbeddingFactory.Text != null)
    {
        queryParameterName = AddQueryParameter(queriedEmbeddingFactory.Text);
    }
    else if (queriedEmbeddingFactory.Embedding != null)
    {
        queryParameterName = AddQueryParameter(queriedEmbeddingFactory.Embedding);
    }
    else if (queriedEmbeddingFactory.Base64Embedding != null)
    {
        queryParameterName = AddQueryParameter(queriedEmbeddingFactory.Base64Embedding);
        isVectorBase64Encoded = true;
    }
    else
    {
        // Previously a null parameter was registered and flagged as base64-encoded.
        throw new InvalidOperationException("Vector search requires a queried value (text, embedding or base64 embedding), but none was provided");
    }

    var vectorSearchToken = new VectorSearchToken(fieldName, queryParameterName, sourceQuantizationType, targetQuantizationType, vectorQuantizationType, isSourceBase64Encoded, isVectorBase64Encoded, minimumSimilarity);

    // NOTE(review): the token is appended as-is; confirm whether an operator or negation
    // step is expected before it when other where-clauses are already present.
    WhereTokens.AddLast(vectorSearchToken);
}

public void Distinct()
{
if (IsDistinct)
28 changes: 28 additions & 0 deletions src/Raven.Client/Documents/Session/AsyncDocumentQuery.cs
Original file line number Diff line number Diff line change
@@ -336,6 +336,34 @@ IAsyncDocumentQuery<T> IFilterDocumentQueryBase<T, IAsyncDocumentQuery<T>>.Where
return this;
}

/// <inheritdoc />
IAsyncDocumentQuery<T> IFilterDocumentQueryBase<T, IAsyncDocumentQuery<T>>.VectorSearch(Func<IVectorFieldFactory<T>, IVectorEmbeddingTextField> textFieldFactory, Action<IVectorEmbeddingTextFieldValueFactory> queriedTextFactory, float minimumSimilarity)
{
    // The caller's delegates fill in fresh builders: the field selection first, then the queried text.
    var field = new VectorEmbeddingFieldFactory<T>();
    textFieldFactory(field);

    var queried = new VectorEmbeddingValueFactory();
    queriedTextFactory(queried);

    // The shared base implementation converts both builders into a where token.
    VectorSearch(field, queried, minimumSimilarity);
    return this;
}

/// <inheritdoc />
IAsyncDocumentQuery<T> IFilterDocumentQueryBase<T, IAsyncDocumentQuery<T>>.VectorSearch(Func<IVectorFieldFactory<T>, IVectorEmbeddingField> embeddingFieldFactory, Action<IVectorEmbeddingFieldValueFactory> queriedEmbeddingFactory, float minimumSimilarity)
{
    // The caller's delegates fill in fresh builders: the field selection first, then the queried embedding.
    var field = new VectorEmbeddingFieldFactory<T>();
    embeddingFieldFactory(field);

    var queried = new VectorEmbeddingValueFactory();
    queriedEmbeddingFactory(queried);

    // The shared base implementation converts both builders into a where token.
    VectorSearch(field, queried, minimumSimilarity);
    return this;
}

/// <inheritdoc />
IAsyncDocumentQuery<T> IFilterDocumentQueryBase<T, IAsyncDocumentQuery<T>>.AndAlso()
{
28 changes: 28 additions & 0 deletions src/Raven.Client/Documents/Session/DocumentQuery.cs
Original file line number Diff line number Diff line change
@@ -662,6 +662,34 @@ IDocumentQuery<T> IFilterDocumentQueryBase<T, IDocumentQuery<T>>.WhereRegex(stri
return this;
}

/// <inheritdoc />
IDocumentQuery<T> IFilterDocumentQueryBase<T, IDocumentQuery<T>>.VectorSearch(Func<IVectorFieldFactory<T>, IVectorEmbeddingTextField> textFieldFactory, Action<IVectorEmbeddingTextFieldValueFactory> queriedTextFactory, float minimumSimilarity)
{
    // The caller's delegates fill in fresh builders: the field selection first, then the queried text.
    var field = new VectorEmbeddingFieldFactory<T>();
    textFieldFactory(field);

    var queried = new VectorEmbeddingValueFactory();
    queriedTextFactory(queried);

    // The shared base implementation converts both builders into a where token.
    VectorSearch(field, queried, minimumSimilarity);
    return this;
}

/// <inheritdoc />
IDocumentQuery<T> IFilterDocumentQueryBase<T, IDocumentQuery<T>>.VectorSearch(Func<IVectorFieldFactory<T>, IVectorEmbeddingField> embeddingFieldFactory, Action<IVectorEmbeddingFieldValueFactory> queriedEmbeddingFactory, float minimumSimilarity)
{
    // The caller's delegates fill in fresh builders: the field selection first, then the queried embedding.
    var field = new VectorEmbeddingFieldFactory<T>();
    embeddingFieldFactory(field);

    var queried = new VectorEmbeddingValueFactory();
    queriedEmbeddingFactory(queried);

    // The shared base implementation converts both builders into a where token.
    VectorSearch(field, queried, minimumSimilarity);
    return this;
}

/// <inheritdoc />
IDocumentQuery<T> IFilterDocumentQueryBase<T, IDocumentQuery<T>>.AndAlso()
{
4 changes: 4 additions & 0 deletions src/Raven.Client/Documents/Session/IDocumentQueryBase.cs
Original file line number Diff line number Diff line change
@@ -440,6 +440,10 @@ public interface IFilterDocumentQueryBase<T, TSelf> where TSelf : IDocumentQuery
/// <inheritdoc cref="MoreLikeThisBase"/>
/// <param name="moreLikeThis">Specified MoreLikeThisQuery.</param>
TSelf MoreLikeThis(MoreLikeThisBase moreLikeThis);

/// <summary>
/// Adds a vector search clause that matches a text field against queried text.
/// </summary>
/// <param name="textFieldFactory">Selects the text field to search, by name or by property expression.</param>
/// <param name="textValueFactory">Supplies the text to search for.</param>
/// <param name="minimumSimilarity">Similarity threshold passed along with the search; defaults to 0.8.</param>
TSelf VectorSearch(Func<IVectorFieldFactory<T>, IVectorEmbeddingTextField> textFieldFactory, Action<IVectorEmbeddingTextFieldValueFactory> textValueFactory, float minimumSimilarity = 0.8f);

/// <summary>
/// Adds a vector search clause that matches an embedding field against a queried embedding.
/// </summary>
/// <param name="embeddingFieldFactory">Selects the embedding field to search (plain or base64-encoded), by name or by property expression.</param>
/// <param name="embeddingValueFactory">Supplies the embedding to search for, as a float array or a base64 string.</param>
/// <param name="minimumSimilarity">Similarity threshold passed along with the search; defaults to 0.8.</param>
TSelf VectorSearch(Func<IVectorFieldFactory<T>, IVectorEmbeddingField> embeddingFieldFactory, Action<IVectorEmbeddingFieldValueFactory> embeddingValueFactory, float minimumSimilarity = 0.8f);
}

public interface IGroupByDocumentQueryBase<T, TSelf> where TSelf : IDocumentQueryBase<T, TSelf>
90 changes: 90 additions & 0 deletions src/Raven.Client/Documents/Session/Tokens/VectorSearchToken.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
using System.Text;
using Raven.Client.Documents.Queries;

namespace Raven.Client.Documents.Session.Tokens;

/// <summary>
/// Where token that renders a <c>vector.search(field, value, threshold)</c> clause.
/// </summary>
/// <remarks>
/// All text is produced culture-invariantly, so the rendered clause does not depend on
/// the current thread culture (lower-casing of enum names, decimal separator of the threshold).
/// </remarks>
public sealed class VectorSearchToken : WhereToken
{
    private float SimilarityThreshold { get; }

    private EmbeddingQuantizationType SourceQuantizationType { get; }

    private EmbeddingQuantizationType TargetQuantizationType { get; }

    private EmbeddingQuantizationType VectorQuantizationType { get; }

    private bool IsSourceBase64Encoded { get; }

    private bool IsVectorBase64Encoded { get; }

    /// <summary>
    /// Creates the token.
    /// </summary>
    /// <param name="fieldName">Field the search runs against.</param>
    /// <param name="parameterName">Name of the query parameter carrying the queried value (written with a leading <c>$</c>).</param>
    /// <param name="sourceQuantizationType">Quantization of the embedding stored in the field.</param>
    /// <param name="targetQuantizationType">Quantization the stored embedding is converted to.</param>
    /// <param name="vectorQuantizationType">Quantization of the queried value.</param>
    /// <param name="isSourceBase64Encoded">Whether the field is wrapped in <c>base64(...)</c>.</param>
    /// <param name="isVectorBase64Encoded">Whether the queried value is wrapped in <c>base64(...)</c>.</param>
    /// <param name="similarityThreshold">Threshold written as the last argument of the clause.</param>
    public VectorSearchToken(string fieldName, string parameterName, EmbeddingQuantizationType sourceQuantizationType, EmbeddingQuantizationType targetQuantizationType, EmbeddingQuantizationType vectorQuantizationType, bool isSourceBase64Encoded, bool isVectorBase64Encoded, float similarityThreshold)
    {
        FieldName = fieldName;
        ParameterName = parameterName;

        SourceQuantizationType = sourceQuantizationType;
        TargetQuantizationType = targetQuantizationType;
        VectorQuantizationType = vectorQuantizationType;

        IsSourceBase64Encoded = isSourceBase64Encoded;
        IsVectorBase64Encoded = isVectorBase64Encoded;

        SimilarityThreshold = similarityThreshold;
    }

    /// <summary>
    /// Appends the <c>vector.search(...)</c> clause to <paramref name="writer"/>.
    /// </summary>
    /// <param name="writer">Builder receiving the query text.</param>
    public override void WriteTo(StringBuilder writer)
    {
        writer.Append("vector.search(");

        if (IsSourceBase64Encoded)
            writer.Append("base64(");

        // The field gets a quantization wrapper only when a target quantization is set:
        // f32_<target>( for an unquantized source, <source>( for an already quantized one.
        var wrapSource = TargetQuantizationType != EmbeddingQuantizationType.None;
        if (wrapSource)
        {
            // TODO
            if (SourceQuantizationType == EmbeddingQuantizationType.None)
                writer.Append("f32_").Append(LowerName(TargetQuantizationType)).Append('(');
            else
                writer.Append(LowerName(SourceQuantizationType)).Append('(');
        }

        writer.Append(FieldName);

        if (IsSourceBase64Encoded)
            writer.Append(')');

        if (wrapSource)
            writer.Append(')');

        writer.Append(", ");

        if (IsVectorBase64Encoded)
            writer.Append("base64(");

        var wrapVector = VectorQuantizationType != EmbeddingQuantizationType.F32;
        if (wrapVector)
            writer.Append(LowerName(VectorQuantizationType)).Append('(');

        writer.Append('$').Append(ParameterName);

        if (wrapVector)
            writer.Append(')');

        if (IsVectorBase64Encoded)
            writer.Append(')');

        // Invariant culture keeps '.' as the decimal separator; a ',' would split the argument list.
        writer.Append(", ").Append(SimilarityThreshold.ToString(System.Globalization.CultureInfo.InvariantCulture));

        writer.Append(')');
    }

    // Lower-cases the enum name without culture rules (ToLower under tr-TR maps 'I' to a dotless 'ı').
    private static string LowerName(EmbeddingQuantizationType type) => type.ToString().ToLowerInvariant();
}
3 changes: 2 additions & 1 deletion src/Raven.Client/Documents/Session/Tokens/WhereOperator.cs
Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@ public enum WhereOperator
Spatial_Contains,
Spatial_Disjoint,
Spatial_Intersects,
Regex
Regex,
Vector_Search
}
}
Loading

0 comments on commit 3808e6d

Please sign in to comment.