Skip to content

Commit

Permalink
RavenDB-22076 EmbeddingQuantizationType changes
Browse files (browse the repository at this point in the history)
  • Loading branch information
Lwiel committed Oct 17, 2024
1 parent d0f7aa1 commit d2de31b
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 15 deletions.
22 changes: 10 additions & 12 deletions src/Raven.Client/Documents/Queries/VectorFieldFactory.cs
Original file line number | Diff line number | Diff line change
Expand Up @@ -107,19 +107,19 @@ IVectorEmbeddingField IVectorFieldFactory<T>.WithBase64(Expression<Func<T, objec
return this;
}

IVectorField IVectorFieldFactory<T>.WithField(string fieldName, EmbeddingType storedEmbeddingQuantization)
IVectorField IVectorFieldFactory<T>.WithField(string fieldName)
{
FieldName = fieldName;
SourceQuantizationType = storedEmbeddingQuantization;
SourceQuantizationType = EmbeddingQuantizationType.Any;
DestinationQuantizationType = SourceQuantizationType;

return this;
}

IVectorField IVectorFieldFactory<T>.WithField(Expression<Func<T, object>> propertySelector, EmbeddingType storedEmbeddingQuantization)
IVectorField IVectorFieldFactory<T>.WithField(Expression<Func<T, object>> propertySelector)
{
FieldName = propertySelector.ToPropertyPath(DocumentConventions.Default);
SourceQuantizationType = storedEmbeddingQuantization;
SourceQuantizationType = EmbeddingQuantizationType.Any;
DestinationQuantizationType = SourceQuantizationType;

return this;
Expand All @@ -129,19 +129,17 @@ IVectorEmbeddingField IVectorEmbeddingField.TargetQuantization(EmbeddingType tar
{
DestinationQuantizationType = targetEmbeddingQuantization;

if (SourceQuantizationType is EmbeddingType.Int8 or EmbeddingType.Binary && DestinationQuantizationType != SourceQuantizationType)
throw new InvalidDataException("Cannot quantize already quantized embeddings");
if (DestinationQuantizationType == EmbeddingType.Text)
throw new InvalidDataException("Cannot quantize the embedding to Text. This option is only for SourceQuantizationType.");
if (DestinationQuantizationType < SourceQuantizationType)
throw new Exception($"Cannot quantize vector with {SourceQuantizationType} quantization into {DestinationQuantizationType}");

if (SourceQuantizationType == EmbeddingType.Int8 && DestinationQuantizationType == EmbeddingType.Binary)
throw new Exception("Cannot quantize already quantized embeddings");

return this;
}

public IVectorEmbeddingTextField TargetQuantization(EmbeddingType targetEmbeddingQuantization)
public IVectorEmbeddingTextField TargetQuantization(EmbeddingQuantizationType targetEmbeddingQuantization)
{
if (DestinationQuantizationType == EmbeddingType.Text)
throw new InvalidDataException("Cannot quantize the embedding to Text. This option is only for SourceQuantizationType.");
DestinationQuantizationType = targetEmbeddingQuantization;

return this;
Expand Down
11 changes: 8 additions & 3 deletions test/SlowTests/Issues/RavenDB-22076.cs
Original file line number | Diff line number | Diff line change
Expand Up @@ -39,6 +39,11 @@ public void TestRqlGeneration(Options options)
var q4 = session.Advanced.DocumentQuery<Dto>().VectorSearch(x => x.WithField("VectorField"), factory => factory.ByBase64("aaaa==", EmbeddingType.Binary)).ToString();

Assert.Equal("from 'Dtos' where vector.search(VectorField, base64(i1($p0)), 0.8)", q4);

var q5 = session.Advanced.DocumentQuery<Dto>().VectorSearch(x => x.WithText("TextField").TargetQuantization(EmbeddingQuantizationType.I8),
factory => factory.ByText("aaaa")).ToString();

Assert.Equal("from 'Dtos' where vector.search(text_i8(TextField), $p0, 0.8)", q5);
}
}
}
Expand All @@ -59,7 +64,7 @@ public void TestRqlGenerationAsync(Options options)
var ex2 = Assert.Throws<Exception>(() => session.Advanced.AsyncDocumentQuery<Dto>()
.VectorSearch(x => x.WithEmbedding("EmbeddingField", EmbeddingType.Int8).TargetQuantization(EmbeddingType.Float32), factory => factory.ByEmbedding([2.5f, 3.3f], EmbeddingType.Binary), 0.65f).ToString());

Assert.Contains("Cannot quantize vector with I8 quantization into None", ex2.Message);
Assert.Contains("Cannot quantize vector with I8 quantization into F32", ex2.Message);

var q1 = session.Advanced.AsyncDocumentQuery<Dto>()
.VectorSearch(x => x.WithEmbedding("EmbeddingField", EmbeddingType.Int8), factory => factory.ByEmbedding([2.5f, 3.3f], EmbeddingType.Binary), 0.65f).ToString();
Expand All @@ -80,7 +85,7 @@ public void TestRqlGenerationAsync(Options options)
var q4 = session.Advanced.AsyncDocumentQuery<Dto>()
.VectorSearch(x => x.WithText("TextField"), factory => factory.ByText("abc")).ToString();

Assert.Equal("from 'Dtos' where vector.search(TextField, $p0, 0.8)", q4);
Assert.Equal("from 'Dtos' where vector.search(text(TextField), $p0, 0.8)", q4);

var q5 = session.Advanced.AsyncDocumentQuery<Dto>()
.VectorSearch(x => x.WithBase64("Base64Field", EmbeddingType.Binary), factory => factory.ByBase64("ddddd==", EmbeddingType.Int8), 0.85f).ToString();
Expand All @@ -105,7 +110,7 @@ public void TestLinqExtensions(Options options)
{
var q1 = session.Query<Dto>().VectorSearch(x => x.WithText("TextField"), factory => factory.ByText("SomeText")).ToString();

Assert.Equal("from 'Dtos' where vector.search(TextField, $p0, 0.8)", q1);
Assert.Equal("from 'Dtos' where vector.search(text(TextField), $p0, 0.8)", q1);

var q2 = session.Query<Dto>().VectorSearch(x => x.WithEmbedding("EmbeddingField", EmbeddingType.Int8), factory => factory.ByEmbedding([0.2f, -0.3f]), 0.75f).ToString();

Expand Down

0 comments on commit d2de31b

Please sign in to comment.