Skip to content

Commit

Permalink
Support sparse vector
Browse files Browse the repository at this point in the history
Signed-off-by: yhmo <[email protected]>
  • Loading branch information
yhmo committed Mar 29, 2024
1 parent 9439a12 commit 96960ad
Show file tree
Hide file tree
Showing 7 changed files with 190 additions and 32 deletions.
86 changes: 80 additions & 6 deletions src/main/java/io/milvus/param/ParamUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ public static HashMap<DataType, String> getTypeErrorMsg() {
typeErrMsg.put(DataType.BinaryVector, "Type mismatch for field '%s': Binary vector field's value type must be ByteBuffer");
typeErrMsg.put(DataType.Float16Vector, "Type mismatch for field '%s': Float16 vector field's value type must be ByteBuffer");
typeErrMsg.put(DataType.BFloat16Vector, "Type mismatch for field '%s': BFloat16 vector field's value type must be ByteBuffer");
typeErrMsg.put(DataType.SparseFloatVector, "Type mismatch for field '%s': SparseFloatVector vector field's value type must be SortedMap");
return typeErrMsg;
}

Expand Down Expand Up @@ -98,12 +99,11 @@ public static void checkFieldData(FieldType fieldSchema, List<?> values, boolean
throw new ParamException(String.format(msg, fieldSchema.getName(), i, temp.size(), dim));
}
}
break;
}
break;
case BinaryVector:
case Float16Vector:
case BFloat16Vector:
{
case BFloat16Vector: {
int dim = fieldSchema.getDimension();
for (int i = 0; i < values.size(); ++i) {
Object value = values.get(i);
Expand All @@ -120,8 +120,30 @@ public static void checkFieldData(FieldType fieldSchema, List<?> values, boolean
throw new ParamException(String.format(msg, fieldSchema.getName(), i, v.position()*8, dim));
}
}
break;
}
break;
case SparseFloatVector:
for (Object value : values) {
if (!(value instanceof SortedMap)) {
throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
}

// is SortedMap<Long, Float> ?
SortedMap<?, ?> m = (SortedMap<?, ?>)value;
if (m.isEmpty()) { // not allow empty value for sparse vector
String msg = "Not allow empty SortedMap for sparse vector field '%s'";
throw new ParamException(String.format(msg, fieldSchema.getName()));
}
if (!(m.firstKey() instanceof Long)) {
String msg = "The key of SortedMap must be Long for sparse vector field '%s'";
throw new ParamException(String.format(msg, fieldSchema.getName()));
}
if (!(m.get(m.firstKey()) instanceof Float)) {
String msg = "The value of SortedMap must be Float for sparse vector field '%s'";
throw new ParamException(String.format(msg, fieldSchema.getName()));
}
}
break;
case Int64:
for (Object value : values) {
if (!(value instanceof Long)) {
Expand Down Expand Up @@ -455,8 +477,16 @@ public static SearchRequest convertSearchParam(@NonNull SearchParam requestParam
byte[] array = buf.array();
ByteString bs = ByteString.copyFrom(array);
byteStrings.add(bs);
} else if (vector instanceof SortedMap) {
plType = PlaceholderType.SparseFloatVector;
SortedMap<Long, Float> map = (SortedMap<Long, Float>) vector;
ByteString bs = genSparseFloatBytes(map);
byteStrings.add(bs);
} else {
String msg = "Search target vector type is illegal(Only allow List<Float> or ByteBuffer)";
String msg = "Search target vector type is illegal." +
" Only allow List<Float> for FloatVector," +
" ByteBuffer for BinaryVector/Float16Vector/BFloat16Vector," +
" List<SortedMap<Long, Float>> for SparseFloatVector.";
throw new ParamException(msg);
}
}
Expand Down Expand Up @@ -631,7 +661,6 @@ private static FieldData genFieldData(FieldType fieldType, List<?> objects) {
return genFieldData(fieldType, objects, Boolean.FALSE);
}

@SuppressWarnings("unchecked")
private static FieldData genFieldData(FieldType fieldType, List<?> objects, boolean isDynamic) {
if (objects == null) {
throw new ParamException("Cannot generate FieldData from null object");
Expand Down Expand Up @@ -694,11 +723,56 @@ private static VectorField genVectorField(DataType dataType, List<?> objects) {
} else {
return VectorField.newBuilder().setDim(dim).setBfloat16Vector(byteString).build();
}
} else if (dataType == DataType.SparseFloatVector) {
SparseFloatArray sparseArray = genSparseFloatArray(objects);
VectorField.newBuilder().setDim(sparseArray.getDim()).setSparseFloatVector(sparseArray);
}

throw new ParamException("Illegal vector dataType:" + dataType);
}

private static ByteString genSparseFloatBytes(SortedMap<Long, Float> sparse) {
ByteBuffer buf = ByteBuffer.allocate((Integer.BYTES + Float.BYTES) * sparse.size());
buf.order(ByteOrder.LITTLE_ENDIAN); // Milvus uses little endian by default
for (Map.Entry<Long, Float> entry : sparse.entrySet()) {
long k = entry.getKey();
if (k < 0 || k > (long)Math.pow(2.0, 32.0)-1) {
throw new ParamException("Sparse vector index must be positive and less than 2^32-1");
}
// here we construct a binary from the long key
ByteBuffer lBuf = ByteBuffer.allocate(Long.BYTES);
lBuf.order(ByteOrder.LITTLE_ENDIAN);
lBuf.putLong(k);
// the server requires a binary of unsigned int, append the first 4 bytes
buf.put(lBuf.array(), 0, 4);

float v = entry.getValue();
if (Float.isNaN(v) || Float.isFinite(v) || Float.isInfinite(v)) {
throw new ParamException("Sparse vector value must not be NaN");
}
buf.putFloat(entry.getValue());
}

return ByteString.copyFrom(buf.array());
}

private static SparseFloatArray genSparseFloatArray(List<?> objects) {
int dim = 0; // the real dim is unknown, set the max size as dim
SparseFloatArray.Builder builder = SparseFloatArray.newBuilder();
// each object must be SortedMap<Long, Float>, which is already validated by checkFieldData()
for (Object object : objects) {
if (!(object instanceof SortedMap)) {
throw new ParamException("SparseFloatVector vector field's value type must be SortedMap");
}
SortedMap<Long, Float> sparse = (SortedMap<Long, Float>) object;
dim = Math.max(dim, sparse.size());
ByteString byteString = genSparseFloatBytes(sparse);
builder.addContents(byteString);
}

return builder.setDim(dim).build();
}

private static ScalarField genScalarField(FieldType fieldType, List<?> objects) {
if (fieldType.getDataType() == DataType.Array) {
ArrayArray.Builder builder = ArrayArray.newBuilder();
Expand Down
1 change: 1 addition & 0 deletions src/main/java/io/milvus/param/QueryNodeSingleSearch.java
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ public Builder withVectorFieldName(@NonNull String vectorFieldName) {
* @param vectors list of target vectors:
* if vector type is FloatVector, vectors is List of List Float
* if vector type is BinaryVector/Float16Vector/BFloat16Vector, vectors is List of ByteBuffer
* if vector type is SparseFloatVector, values is List of SortedMap[Long, Float]
* @return <code>Builder</code>
*/
public Builder withVectors(@NonNull List<?> vectors) {
Expand Down
3 changes: 2 additions & 1 deletion src/main/java/io/milvus/param/collection/FieldType.java
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,8 @@ public FieldType build() throws ParamException {
throw new ParamException("String type is not supported, use Varchar instead");
}

if (ParamUtils.isVectorDataType(dataType)) {
// SparseVector has no dimension, other vector types must have dimension
if (ParamUtils.isVectorDataType(dataType) && dataType != DataType.SparseFloatVector) {
if (!typeParams.containsKey(Constant.VECTOR_DIM)) {
throw new ParamException("Vector field dimension must be specified");
}
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/io/milvus/param/dml/InsertParam.java
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,9 @@ protected void checkRows() {
* If dataType is Varchar, values is List of String;
* If dataType is FloatVector, values is List of List Float;
* If dataType is BinaryVector/Float16Vector/BFloat16Vector, values is List of ByteBuffer;
* If dataType is SparseFloatVector, values is List of SortedMap[Long, Float];
* If dataType is Array, values can be List of List Boolean/Integer/Short/Long/Float/Double/String;
* If dataType is JSON, values is List of JSONObject;
*
* Note:
* If dataType is Int8/Int16/Int32, values is List of Integer or Short
Expand Down
16 changes: 15 additions & 1 deletion src/main/java/io/milvus/param/dml/SearchParam.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

import java.nio.ByteBuffer;
import java.util.List;
import java.util.SortedMap;

/**
* Parameters for <code>search</code> interface.
Expand Down Expand Up @@ -238,6 +239,7 @@ public Builder addOutField(@NonNull String fieldName) {
* @param vectors list of target vectors:
* if vector type is FloatVector, vectors is List of List Float;
* if vector type is BinaryVector/Float16Vector/BFloat16Vector, vectors is List of ByteBuffer;
* if vector type is SparseFloatVector, values is List of SortedMap[Long, Float];
* @return <code>Builder</code>
*/
public Builder withVectors(@NonNull List<?> vectors) {
Expand Down Expand Up @@ -310,6 +312,7 @@ public SearchParam build() throws ParamException {

if (vectors.get(0) instanceof List) {
// float vectors
// TODO: here only check the first element, potential risk
List<?> first = (List<?>) vectors.get(0);
if (!(first.get(0) instanceof Float)) {
throw new ParamException("Float vector field's value must be Lst<Float>");
Expand All @@ -324,6 +327,7 @@ public SearchParam build() throws ParamException {
}
} else if (vectors.get(0) instanceof ByteBuffer) {
// binary vectors
// TODO: here only check the first element, potential risk
ByteBuffer first = (ByteBuffer) vectors.get(0);
int dim = first.position();
for (int i = 1; i < vectors.size(); ++i) {
Expand All @@ -332,8 +336,18 @@ public SearchParam build() throws ParamException {
throw new ParamException("Target vector dimension must be equal");
}
}
} else if (vectors.get(0) instanceof SortedMap) {
// sparse vectors, must be SortedMap<Long, Float>
// TODO: here only check the first element, potential risk
SortedMap<?, ?> map = (SortedMap<?, ?>) vectors.get(0);


} else {
throw new ParamException("Target vector type must be List<Float> or ByteBuffer");
String msg = "Search target vector type is illegal." +
" Only allow List<Float> for FloatVector," +
" ByteBuffer for BinaryVector/Float16Vector/BFloat16Vector," +
" List<SortedMap<Long, Float>> for SparseFloatVector.";
throw new ParamException(msg);
}

return new SearchParam(this);
Expand Down
113 changes: 89 additions & 24 deletions src/main/java/io/milvus/response/FieldDataWrapper.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,23 @@
import com.alibaba.fastjson.JSONObject;
import com.google.protobuf.ProtocolStringList;
import io.milvus.exception.ParamException;
import io.milvus.grpc.ArrayArray;
import io.milvus.grpc.DataType;
import io.milvus.grpc.FieldData;
import io.milvus.grpc.*;
import io.milvus.exception.IllegalResponseException;

import io.milvus.grpc.ScalarField;
import io.milvus.param.ParamUtils;
import lombok.NonNull;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.stream.Collectors;

import com.google.protobuf.ByteString;

import static io.milvus.grpc.DataType.BinaryVector;
import static io.milvus.grpc.DataType.JSON;

/**
Expand Down Expand Up @@ -56,6 +57,23 @@ public int getDim() throws IllegalResponseException {
return (int) fieldData.getVectors().getDim();
}

// this method returns bytes size of each vector according to vector type
private int checkDim(DataType dt, ByteString data, int dim) {
if (dt == DataType.BinaryVector) {
if ((data.size()*8) % dim != 0) {
throw new IllegalResponseException("Returned binary vector field data array size doesn't match dimension");
}
return dim/8;
} else if (dt == DataType.Float16Vector || dt == DataType.BFloat16Vector) {
if (data.size() != dim*2) {
throw new IllegalResponseException("Returned float16 vector field data array size doesn't match dimension");
}
return dim*2;
}

return 0;
}

/**
* Gets the row count of a field.
* * Throws {@link IllegalResponseException} if the field type is illegal.
Expand All @@ -75,13 +93,26 @@ public long getRowCount() throws IllegalResponseException {
return data.size()/dim;
}
case BinaryVector: {
// for binary vector, each dimension is one bit, each byte is 8 dim
int dim = getDim();
ByteString data = fieldData.getVectors().getBinaryVector();
if ((data.size()*8) % dim != 0) {
throw new IllegalResponseException("Returned binary vector field data array size doesn't match dimension");
}
int bytePerVec = checkDim(dt, data, dim);

return data.size()/bytePerVec;
}
case Float16Vector:
case BFloat16Vector: {
// for float16 vector, each dimension 2 bytes
int dim = getDim();
ByteString data = (dt == DataType.Float16Vector) ?
fieldData.getVectors().getFloat16Vector() : fieldData.getVectors().getBfloat16Vector();
int bytePerVec = checkDim(dt, data, dim);

return (data.size()*8)/dim;
return data.size()/bytePerVec;
}
case SparseFloatVector: {
// for sparse vector, each content is a vector
return fieldData.getVectors().getSparseFloatVector().getContentsCount();
}
case Int64:
return fieldData.getScalars().getLongData().getDataCount();
Expand Down Expand Up @@ -109,15 +140,17 @@ public long getRowCount() throws IllegalResponseException {

/**
* Returns the field data according to its type:
* float vector field return List of List Float,
* binary vector field return List of ByteBuffer
* int64 field return List of Long
* int32/int16/int8 field return List of Integer
* boolean field return List of Boolean
* float field return List of Float
* double field return List of Double
* varchar field return List of String
* array field return List of List
* FloatVector field returns List of List Float,
* BinaryVector/Float16Vector/BFloat16Vector fields return List of ByteBuffer
* SparseFloatVector field returns List of SortedMap[Long, Float]
* Int64 field returns List of Long
* Int32/Int16/Int8 fields return List of Integer
* Bool field returns List of Boolean
* Float field returns List of Float
* Double field returns List of Double
* Varchar field returns List of String
* Array field returns List of List
* JSON field returns List of String;
* etc.
*
* Throws {@link IllegalResponseException} if the field type is illegal.
Expand All @@ -141,23 +174,55 @@ public List<?> getFieldData() throws IllegalResponseException {
}
return packData;
}
case BinaryVector: {
case BinaryVector:
case Float16Vector:
case BFloat16Vector: {
int dim = getDim();
ByteString data = fieldData.getVectors().getBinaryVector();
if ((data.size()*8) % dim != 0) {
throw new IllegalResponseException("Returned binary vector field data array size doesn't match dimension");
ByteString data = null;
if (dt == DataType.BinaryVector) {
data = fieldData.getVectors().getBinaryVector();
} else if (dt == DataType.Float16Vector) {
data = fieldData.getVectors().getFloat16Vector();
} else if (dt == DataType.BFloat16Vector) {
data = fieldData.getVectors().getBfloat16Vector();
}

List<ByteBuffer> packData = new ArrayList<>();
int bytePerVec = dim/8;
int bytePerVec = checkDim(dt, data, dim);
int count = data.size()/bytePerVec;
List<ByteBuffer> packData = new ArrayList<>();
for (int i = 0; i < count; ++i) {
ByteBuffer bf = ByteBuffer.allocate(bytePerVec);
bf.put(data.substring(i * bytePerVec, (i + 1) * bytePerVec).toByteArray());
packData.add(bf);
}
return packData;
}
case SparseFloatVector: {
// in Java sdk, each sparse vector is pairs of long+float
// in server side, each sparse vector is stored as uint+float (8 bytes)
SparseFloatArray sparseArray = fieldData.getVectors().getSparseFloatVector();
long dim = sparseArray.getDim();
List<SortedMap<Long, Float>> packData = new ArrayList<>();
for (int i = 0; i < sparseArray.getContentsCount(); ++i) {
ByteString bs = sparseArray.getContents(i);
ByteBuffer bf = ByteBuffer.wrap(bs.toByteArray());
SortedMap<Long, Float> sparse = new TreeMap<>();
for (long j = 0; j < dim; j++) {
// here we convert an uint bytes to a long value
ByteBuffer pBuf = ByteBuffer.allocate(Long.BYTES);
pBuf.order(ByteOrder.LITTLE_ENDIAN);
int offset = 8*(int)j;
pBuf.put(bf.array(), offset, offset+4).putInt(0).rewind(); // fill 8 bytes long
long k = pBuf.getLong();

bf.getInt(); // pop 4 bytes since they were converted to long
float v = pBuf.getFloat();
sparse.put(k, v);
}
packData.add(sparse);
}
return packData;
}
case Array:
List<List<?>> array = new ArrayList<>();
ArrayArray arrArray = fieldData.getScalars().getArrayData();
Expand Down Expand Up @@ -202,7 +267,7 @@ private List<?> getScalarData(DataType dt, ScalarField scalar) {
return protoStrList.subList(0, protoStrList.size());
case JSON:
List<ByteString> dataList = scalar.getJsonData().getDataList();
return dataList.stream().map(ByteString::toByteArray).collect(Collectors.toList());
return dataList.stream().map(ByteString::toStringUtf8).collect(Collectors.toList());
default:
return new ArrayList<>();
}
Expand Down
Loading

0 comments on commit 96960ad

Please sign in to comment.