Skip to content

Commit

Permalink
RavenDB-22076 Removed TargetQuantization from WithField
Browse files Browse the repository at this point in the history
  • Loading branch information
Lwiel committed Oct 18, 2024
1 parent 2935992 commit 61d6166
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 2 deletions.
14 changes: 14 additions & 0 deletions src/Raven.Client/Documents/Linq/RavenQueryProviderProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1638,6 +1638,20 @@ private void VisitLinqExtensionsMethodCall(MethodCallExpression expression)

break;
}
case Func<IVectorFieldFactory<T>, IVectorField> fieldFactory:
{
LinqPathProvider.GetValueFromExpressionWithoutConversion(expression.Arguments[2], out var embeddingFieldValueFactoryObject);

fieldFactory.Invoke(fieldBuilder);

if (embeddingFieldValueFactoryObject is not Action<IVectorFieldValueFactory> fieldValueFactory)
throw new Exception();

fieldValueFactory.Invoke(valueBuilder);

break;
}

default:
throw new Exception();
}
Expand Down
12 changes: 12 additions & 0 deletions src/Raven.Client/Documents/LinqExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1252,6 +1252,18 @@ public static IRavenQueryable<T> VectorSearch<T>(this IQueryable<T> source, Func

return (IRavenQueryable<T>)queryable;
}

public static IRavenQueryable<T> VectorSearch<T>(this IQueryable<T> source, Func<IVectorFieldFactory<T>, IVectorField> embeddingFieldFactory, Action<IVectorFieldValueFactory> embeddingValueFactory, float minimumSimilarity = 0.8f)
{
var currentMethod = (MethodInfo)MethodBase.GetCurrentMethod();

currentMethod = ConvertMethodIfNecessary(currentMethod, typeof(T));
var expression = ConvertExpressionIfNecessary(source);

var queryable = source.Provider.CreateQuery(Expression.Call(null, currentMethod, expression, Expression.Constant(embeddingFieldFactory), Expression.Constant(embeddingValueFactory), Expression.Constant(minimumSimilarity)));

return (IRavenQueryable<T>)queryable;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static Expression ConvertExpressionIfNecessary<T>(IQueryable<T> source)
Expand Down
5 changes: 3 additions & 2 deletions src/Raven.Client/Documents/Queries/VectorFieldFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ public interface IVectorFieldFactory<T>
public IVectorEmbeddingField WithEmbedding(string fieldName, EmbeddingType storedEmbeddingQuantization = Constants.VectorSearch.DefaultEmbeddingType, VectorIndexingStrategy vectorIndexingStrategy = Constants.VectorSearch.DefaultIndexingStrategy);

public IVectorEmbeddingField WithEmbedding(Expression<Func<T, object>> propertySelector, EmbeddingType storedEmbeddingQuantization = Constants.VectorSearch.DefaultEmbeddingType, VectorIndexingStrategy vectorIndexingStrategy = Constants.VectorSearch.DefaultIndexingStrategy);

public IVectorEmbeddingField WithBase64(string fieldName, EmbeddingType storedEmbeddingQuantization = Constants.VectorSearch.DefaultEmbeddingType, VectorIndexingStrategy vectorIndexingStrategy = Constants.VectorSearch.DefaultIndexingStrategy);

public IVectorEmbeddingField WithBase64(Expression<Func<T, object>> propertySelector, EmbeddingType storedEmbeddingQuantization = Constants.VectorSearch.DefaultEmbeddingType, VectorIndexingStrategy vectorIndexingStrategy = Constants.VectorSearch.DefaultIndexingStrategy);
Expand All @@ -37,12 +38,12 @@ public interface IVectorEmbeddingField
public IVectorEmbeddingField TargetQuantization(EmbeddingType targetEmbeddingQuantization);
}

public interface IVectorField : IVectorEmbeddingField, IVectorEmbeddingTextField
public interface IVectorField
{

}

internal sealed class VectorEmbeddingFieldFactory<T> : IVectorFieldFactory<T>, IVectorField
internal sealed class VectorEmbeddingFieldFactory<T> : IVectorFieldFactory<T>, IVectorField, IVectorEmbeddingField, IVectorEmbeddingTextField
{
internal string FieldName { get; set; }
internal EmbeddingType SourceQuantizationType { get; set; } = Constants.VectorSearch.DefaultEmbeddingType;
Expand Down
2 changes: 2 additions & 0 deletions test/SlowTests/Issues/RavenDB-22076.cs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ public void TestRqlGenerationAsync(Options options)
.VectorSearch(x => x.WithBase64("Base64Field", EmbeddingType.Int8), factory => factory.ByEmbedding([0.2f, 0.3f])).ToString();

Assert.Equal("from 'Dtos' where vector.search(embedding.i8(Base64Field), $p0)", q6);

session.Advanced.AsyncDocumentQuery<Dto>().VectorSearch(x => x.WithField("aaa"), factory => factory.ByEmbedding([0.2f, 0.3f]));
}
}
}
Expand Down

0 comments on commit 61d6166

Please sign in to comment.