Skip to content

Commit

Permalink
RavenDB-22076 EmbeddingQuantizationType -> EmbeddingType
Browse files Browse the repository at this point in the history
  • Loading branch information
Lwiel committed Oct 17, 2024
1 parent d2de31b commit 3ae659b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
8 changes: 3 additions & 5 deletions src/Raven.Client/Documents/Queries/VectorFieldFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
using Raven.Client.Documents.Conventions;
using Raven.Client.Documents.Indexes.Vector;
using Raven.Client.Extensions;
using Sparrow;
using Sparrow.Binary;

namespace Raven.Client.Documents.Queries;

Expand All @@ -24,9 +22,9 @@ public interface IVectorFieldFactory<T>

public IVectorEmbeddingField WithBase64(Expression<Func<T, object>> propertySelector, EmbeddingType storedEmbeddingQuantization = Constants.VectorSearch.DefaultEmbeddingType);

public IVectorField WithField(string fieldName, EmbeddingType storedEmbeddingQuantization = Constants.VectorSearch.DefaultEmbeddingType);
public IVectorField WithField(string fieldName);

public IVectorField WithField(Expression<Func<T, object>> propertySelector, EmbeddingType storedEmbeddingQuantization = Constants.VectorSearch.DefaultEmbeddingType);
public IVectorField WithField(Expression<Func<T, object>> propertySelector);
}

public interface IVectorEmbeddingTextField
Expand Down Expand Up @@ -138,7 +136,7 @@ IVectorEmbeddingField IVectorEmbeddingField.TargetQuantization(EmbeddingType tar
return this;
}

public IVectorEmbeddingTextField TargetQuantization(EmbeddingQuantizationType targetEmbeddingQuantization)
public IVectorEmbeddingTextField TargetQuantization(EmbeddingType targetEmbeddingQuantization)
{
DestinationQuantizationType = targetEmbeddingQuantization;

Expand Down
11 changes: 6 additions & 5 deletions test/SlowTests/Issues/RavenDB-22076.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public void TestRqlGeneration(Options options)

Assert.Equal("from 'Dtos' where vector.search(VectorField, base64(i1($p0)), 0.8)", q4);

var q5 = session.Advanced.DocumentQuery<Dto>().VectorSearch(x => x.WithText("TextField").TargetQuantization(EmbeddingQuantizationType.I8),
var q5 = session.Advanced.DocumentQuery<Dto>().VectorSearch(x => x.WithText("TextField").TargetQuantization(EmbeddingType.Int8),
factory => factory.ByText("aaaa")).ToString();

Assert.Equal("from 'Dtos' where vector.search(text_i8(TextField), $p0, 0.8)", q5);
Expand Down Expand Up @@ -108,14 +108,15 @@ public void TestLinqExtensions(Options options)
{
using (var session = store.OpenSession())
{
/*
var q1 = session.Query<Dto>().VectorSearch(x => x.WithText("TextField"), factory => factory.ByText("SomeText")).ToString();
Assert.Equal("from 'Dtos' where vector.search(text(TextField), $p0, 0.8)", q1);
Assert.Equal("from 'Dtos' where vector.search(embedding.text(TextField), $p0, 0.8)", q1);
var q2 = session.Query<Dto>().VectorSearch(x => x.WithEmbedding("EmbeddingField", EmbeddingType.Int8), factory => factory.ByEmbedding([0.2f, -0.3f]), 0.75f).ToString();
Assert.Equal("from 'Dtos' where vector.search(i8(EmbeddingField), $p0, 0.75)", q2);

Assert.Equal("from 'Dtos' where vector.search(embedding.i8(EmbeddingField), $p0)", q2);
*/
var q3 = session.Query<Dto>().VectorSearch(x => x.WithEmbedding("EmbeddingField").TargetQuantization(EmbeddingType.Int8), factory => factory.ByEmbedding([0.2f, -0.3f], EmbeddingType.Int8)).ToString();

Assert.Equal("from 'Dtos' where vector.search(f32_i8(EmbeddingField), i8($p0), 0.8)", q3);
Expand Down

0 comments on commit 3ae659b

Please sign in to comment.