diff --git a/src/Raven.Client/Documents/Linq/RavenQueryProviderProcessor.cs b/src/Raven.Client/Documents/Linq/RavenQueryProviderProcessor.cs index 9f6937b8b491..9e65dd20aea4 100644 --- a/src/Raven.Client/Documents/Linq/RavenQueryProviderProcessor.cs +++ b/src/Raven.Client/Documents/Linq/RavenQueryProviderProcessor.cs @@ -1638,6 +1638,20 @@ private void VisitLinqExtensionsMethodCall(MethodCallExpression expression) break; } + case Func, IVectorField> fieldFactory: + { + LinqPathProvider.GetValueFromExpressionWithoutConversion(expression.Arguments[2], out var embeddingFieldValueFactoryObject); + + fieldFactory.Invoke(fieldBuilder); + + if (embeddingFieldValueFactoryObject is not Action fieldValueFactory) + throw new Exception(); + + fieldValueFactory.Invoke(valueBuilder); + + break; + } + default: throw new Exception(); } diff --git a/src/Raven.Client/Documents/LinqExtensions.cs b/src/Raven.Client/Documents/LinqExtensions.cs index 2a971bfd9e7d..e7b1b25021e8 100644 --- a/src/Raven.Client/Documents/LinqExtensions.cs +++ b/src/Raven.Client/Documents/LinqExtensions.cs @@ -1252,6 +1252,18 @@ public static IRavenQueryable VectorSearch(this IQueryable source, Func return (IRavenQueryable)queryable; } + + public static IRavenQueryable VectorSearch(this IQueryable source, Func, IVectorField> embeddingFieldFactory, Action embeddingValueFactory, float minimumSimilarity = 0.8f) + { + var currentMethod = (MethodInfo)MethodBase.GetCurrentMethod(); + + currentMethod = ConvertMethodIfNecessary(currentMethod, typeof(T)); + var expression = ConvertExpressionIfNecessary(source); + + var queryable = source.Provider.CreateQuery(Expression.Call(null, currentMethod, expression, Expression.Constant(embeddingFieldFactory), Expression.Constant(embeddingValueFactory), Expression.Constant(minimumSimilarity))); + + return (IRavenQueryable)queryable; + } [MethodImpl(MethodImplOptions.AggressiveInlining)] private static Expression ConvertExpressionIfNecessary(IQueryable source) diff --git a/src/Raven.Client/Documents/Queries/VectorFieldFactory.cs b/src/Raven.Client/Documents/Queries/VectorFieldFactory.cs index a6dea6f0c237..178bc60bef2c 100644 --- a/src/Raven.Client/Documents/Queries/VectorFieldFactory.cs +++ b/src/Raven.Client/Documents/Queries/VectorFieldFactory.cs @@ -18,6 +18,7 @@ public interface IVectorFieldFactory public IVectorEmbeddingField WithEmbedding(string fieldName, EmbeddingType storedEmbeddingQuantization = Constants.VectorSearch.DefaultEmbeddingType, VectorIndexingStrategy vectorIndexingStrategy = Constants.VectorSearch.DefaultIndexingStrategy); public IVectorEmbeddingField WithEmbedding(Expression> propertySelector, EmbeddingType storedEmbeddingQuantization = Constants.VectorSearch.DefaultEmbeddingType, VectorIndexingStrategy vectorIndexingStrategy = Constants.VectorSearch.DefaultIndexingStrategy); + public IVectorEmbeddingField WithBase64(string fieldName, EmbeddingType storedEmbeddingQuantization = Constants.VectorSearch.DefaultEmbeddingType, VectorIndexingStrategy vectorIndexingStrategy = Constants.VectorSearch.DefaultIndexingStrategy); public IVectorEmbeddingField WithBase64(Expression> propertySelector, EmbeddingType storedEmbeddingQuantization = Constants.VectorSearch.DefaultEmbeddingType, VectorIndexingStrategy vectorIndexingStrategy = Constants.VectorSearch.DefaultIndexingStrategy); @@ -37,12 +38,12 @@ public interface IVectorEmbeddingField public IVectorEmbeddingField TargetQuantization(EmbeddingType targetEmbeddingQuantization); } -public interface IVectorField : IVectorEmbeddingField, IVectorEmbeddingTextField +public interface IVectorField { } -internal sealed class VectorEmbeddingFieldFactory : IVectorFieldFactory, IVectorField +internal sealed class VectorEmbeddingFieldFactory : IVectorFieldFactory, IVectorField, IVectorEmbeddingField, IVectorEmbeddingTextField { internal string FieldName { get; set; } internal EmbeddingType SourceQuantizationType { get; set; } = Constants.VectorSearch.DefaultEmbeddingType; diff --git a/test/SlowTests/Issues/RavenDB-22076.cs b/test/SlowTests/Issues/RavenDB-22076.cs index e5e23468192f..1826179052f0 100644 --- a/test/SlowTests/Issues/RavenDB-22076.cs +++ b/test/SlowTests/Issues/RavenDB-22076.cs @@ -106,6 +106,8 @@ public void TestRqlGenerationAsync(Options options) .VectorSearch(x => x.WithBase64("Base64Field", EmbeddingType.Int8), factory => factory.ByEmbedding([0.2f, 0.3f])).ToString(); Assert.Equal("from 'Dtos' where vector.search(embedding.i8(Base64Field), $p0)", q6); + + session.Advanced.AsyncDocumentQuery().VectorSearch(x => x.WithField("aaa"), factory => factory.ByEmbedding([0.2f, 0.3f])); } } }