Skip to content

Commit

Permalink
Support sparse vector
Browse files Browse the repository at this point in the history
Signed-off-by: yhmo <[email protected]>
  • Loading branch information
yhmo committed Apr 1, 2024
1 parent b9abebd commit 7af7fad
Show file tree
Hide file tree
Showing 11 changed files with 352 additions and 42 deletions.
15 changes: 9 additions & 6 deletions src/main/java/io/milvus/param/IndexType.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
*/
public enum IndexType {
None(0),
//Only supported for float vectors
// Only supported for float vectors
FLAT(1),
IVF_FLAT(2),
IVF_SQ8(3),
Expand All @@ -37,19 +37,22 @@ public enum IndexType {
AUTOINDEX(11),
SCANN(12),

// GPU index
// GPU indexes only for float vectors
GPU_IVF_FLAT(50),
GPU_IVF_PQ(51),

//Only supported for binary vectors
// Only supported for binary vectors
BIN_FLAT(80),
BIN_IVF_FLAT(81),

//Scalar field index start from here
//Only for varchar type field
// Only for varchar type field
TRIE("Trie", 100),
//Only for scalar type field
// Only for scalar type field
STL_SORT(200),

// Only for sparse vectors
SPARSE_INVERTED_INDEX(300),
SPARSE_WAND(301)
;

@Getter
Expand Down
87 changes: 81 additions & 6 deletions src/main/java/io/milvus/param/ParamUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ public static HashMap<DataType, String> getTypeErrorMsg() {
typeErrMsg.put(DataType.BinaryVector, "Type mismatch for field '%s': Binary vector field's value type must be ByteBuffer");
typeErrMsg.put(DataType.Float16Vector, "Type mismatch for field '%s': Float16 vector field's value type must be ByteBuffer");
typeErrMsg.put(DataType.BFloat16Vector, "Type mismatch for field '%s': BFloat16 vector field's value type must be ByteBuffer");
typeErrMsg.put(DataType.SparseFloatVector, "Type mismatch for field '%s': SparseFloatVector vector field's value type must be SortedMap");
return typeErrMsg;
}

Expand Down Expand Up @@ -98,12 +99,11 @@ public static void checkFieldData(FieldType fieldSchema, List<?> values, boolean
throw new ParamException(String.format(msg, fieldSchema.getName(), i, temp.size(), dim));
}
}
break;
}
break;
case BinaryVector:
case Float16Vector:
case BFloat16Vector:
{
case BFloat16Vector: {
int dim = fieldSchema.getDimension();
for (int i = 0; i < values.size(); ++i) {
Object value = values.get(i);
Expand All @@ -120,8 +120,30 @@ public static void checkFieldData(FieldType fieldSchema, List<?> values, boolean
throw new ParamException(String.format(msg, fieldSchema.getName(), i, v.position()*8, dim));
}
}
break;
}
break;
case SparseFloatVector:
for (Object value : values) {
if (!(value instanceof SortedMap)) {
throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName()));
}

// is SortedMap<Long, Float> ?
SortedMap<?, ?> m = (SortedMap<?, ?>)value;
if (m.isEmpty()) { // not allow empty value for sparse vector
String msg = "Not allow empty SortedMap for sparse vector field '%s'";
throw new ParamException(String.format(msg, fieldSchema.getName()));
}
if (!(m.firstKey() instanceof Long)) {
String msg = "The key of SortedMap must be Long for sparse vector field '%s'";
throw new ParamException(String.format(msg, fieldSchema.getName()));
}
if (!(m.get(m.firstKey()) instanceof Float)) {
String msg = "The value of SortedMap must be Float for sparse vector field '%s'";
throw new ParamException(String.format(msg, fieldSchema.getName()));
}
}
break;
case Int64:
for (Object value : values) {
if (!(value instanceof Long)) {
Expand Down Expand Up @@ -455,8 +477,16 @@ public static SearchRequest convertSearchParam(@NonNull SearchParam requestParam
byte[] array = buf.array();
ByteString bs = ByteString.copyFrom(array);
byteStrings.add(bs);
} else if (vector instanceof SortedMap) {
plType = PlaceholderType.SparseFloatVector;
SortedMap<Long, Float> map = (SortedMap<Long, Float>) vector;
ByteString bs = genSparseFloatBytes(map);
byteStrings.add(bs);
} else {
String msg = "Search target vector type is illegal(Only allow List<Float> or ByteBuffer)";
String msg = "Search target vector type is illegal." +
" Only allow List<Float> for FloatVector," +
" ByteBuffer for BinaryVector/Float16Vector/BFloat16Vector," +
" List<SortedMap<Long, Float>> for SparseFloatVector.";
throw new ParamException(msg);
}
}
Expand Down Expand Up @@ -623,6 +653,7 @@ public static boolean isVectorDataType(DataType dataType) {
add(DataType.BinaryVector);
add(DataType.Float16Vector);
add(DataType.BFloat16Vector);
add(DataType.SparseFloatVector);
}};
return vectorDataType.contains(dataType);
}
Expand All @@ -631,7 +662,6 @@ private static FieldData genFieldData(FieldType fieldType, List<?> objects) {
return genFieldData(fieldType, objects, Boolean.FALSE);
}

@SuppressWarnings("unchecked")
private static FieldData genFieldData(FieldType fieldType, List<?> objects, boolean isDynamic) {
if (objects == null) {
throw new ParamException("Cannot generate FieldData from null object");
Expand Down Expand Up @@ -694,11 +724,56 @@ private static VectorField genVectorField(DataType dataType, List<?> objects) {
} else {
return VectorField.newBuilder().setDim(dim).setBfloat16Vector(byteString).build();
}
} else if (dataType == DataType.SparseFloatVector) {
SparseFloatArray sparseArray = genSparseFloatArray(objects);
return VectorField.newBuilder().setDim(sparseArray.getDim()).setSparseFloatVector(sparseArray).build();
}

throw new ParamException("Illegal vector dataType:" + dataType);
}

/**
 * Encodes one sparse vector into the binary form sent to the server.
 *
 * Each entry is written as 8 bytes: the index as a 4-byte little-endian unsigned int,
 * followed by the value as a 4-byte little-endian float. Entries are written in the
 * iteration order of the map's entry set.
 *
 * Every entry is validated here because callers only inspect the first entry of each map
 * before casting it to {@code SortedMap<Long, Float>}.
 *
 * @param sparse the sparse vector as index-to-value pairs
 * @return the encoded bytes, 8 bytes per entry
 * @throws ParamException if an index is not a Long, is negative, or is not less than 2^32-1;
 *                        or if a value is not a Float, is NaN, or is infinite
 */
private static ByteString genSparseFloatBytes(SortedMap<Long, Float> sparse) {
    // Indexes must be strictly less than this bound (2^32-1), as the error message below states.
    final long indexUpperBound = 0xFFFFFFFFL;

    ByteBuffer buf = ByteBuffer.allocate((Integer.BYTES + Float.BYTES) * sparse.size());
    buf.order(ByteOrder.LITTLE_ENDIAN); // Milvus uses little endian by default
    // Iterate with wildcard entries so that a mistyped or null key/value is reported as a
    // ParamException instead of a ClassCastException/NullPointerException on unboxing.
    for (Map.Entry<?, ?> entry : sparse.entrySet()) {
        Object key = entry.getKey();
        if (!(key instanceof Long)) {
            throw new ParamException("Sparse vector index must be Long");
        }
        long k = (Long) key;
        if (k < 0 || k >= indexUpperBound) {
            throw new ParamException("Sparse vector index must be positive and less than 2^32-1");
        }
        // The server requires a binary of unsigned int: the low 32 bits of the long key,
        // written in little-endian order, are exactly that.
        buf.putInt((int) k);

        Object value = entry.getValue();
        if (!(value instanceof Float)) {
            throw new ParamException("Sparse vector value must be Float");
        }
        float v = (Float) value;
        if (Float.isNaN(v) || Float.isInfinite(v)) {
            throw new ParamException("Sparse vector value cannot be NaN or Infinite");
        }
        buf.putFloat(v);
    }

    return ByteString.copyFrom(buf.array());
}

/**
 * Packs a list of sparse vectors into a single SparseFloatArray message.
 *
 * Each element is encoded with genSparseFloatBytes() and appended to the array's contents.
 * The array's dim is set to the largest number of entries found in any one vector, since
 * the real dimension of a sparse vector field is unknown on the client side.
 *
 * @param objects the sparse vectors; every element must be a SortedMap of Long to Float
 * @return the populated SparseFloatArray
 * @throws ParamException if any element is not a SortedMap
 */
private static SparseFloatArray genSparseFloatArray(List<?> objects) {
    SparseFloatArray.Builder arrayBuilder = SparseFloatArray.newBuilder();
    int maxEntryCount = 0; // stands in for the dimension, which cannot be known here
    for (Object row : objects) {
        // checkFieldData() is expected to have validated the type already; re-check before casting
        if (row instanceof SortedMap) {
            SortedMap<Long, Float> entries = (SortedMap<Long, Float>) row;
            if (entries.size() > maxEntryCount) {
                maxEntryCount = entries.size();
            }
            arrayBuilder.addContents(genSparseFloatBytes(entries));
            continue;
        }
        throw new ParamException("SparseFloatVector vector field's value type must be SortedMap");
    }

    arrayBuilder.setDim(maxEntryCount);
    return arrayBuilder.build();
}

private static ScalarField genScalarField(FieldType fieldType, List<?> objects) {
if (fieldType.getDataType() == DataType.Array) {
ArrayArray.Builder builder = ArrayArray.newBuilder();
Expand Down
1 change: 1 addition & 0 deletions src/main/java/io/milvus/param/QueryNodeSingleSearch.java
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ public Builder withVectorFieldName(@NonNull String vectorFieldName) {
* @param vectors list of target vectors:
* if vector type is FloatVector, vectors is List of List Float
* if vector type is BinaryVector/Float16Vector/BFloat16Vector, vectors is List of ByteBuffer
* if vector type is SparseFloatVector, vectors is List of SortedMap[Long, Float]
* @return <code>Builder</code>
*/
public Builder withVectors(@NonNull List<?> vectors) {
Expand Down
3 changes: 2 additions & 1 deletion src/main/java/io/milvus/param/collection/FieldType.java
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,8 @@ public FieldType build() throws ParamException {
throw new ParamException("String type is not supported, use Varchar instead");
}

if (ParamUtils.isVectorDataType(dataType)) {
// SparseVector has no dimension, other vector types must have dimension
if (ParamUtils.isVectorDataType(dataType) && dataType != DataType.SparseFloatVector) {
if (!typeParams.containsKey(Constant.VECTOR_DIM)) {
throw new ParamException("Vector field dimension must be specified");
}
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/io/milvus/param/dml/InsertParam.java
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,9 @@ protected void checkRows() {
* If dataType is Varchar, values is List of String;
* If dataType is FloatVector, values is List of List Float;
* If dataType is BinaryVector/Float16Vector/BFloat16Vector, values is List of ByteBuffer;
* If dataType is SparseFloatVector, values is List of SortedMap[Long, Float];
* If dataType is Array, values can be List of List Boolean/Integer/Short/Long/Float/Double/String;
* If dataType is JSON, values is List of JSONObject;
*
* Note:
* If dataType is Int8/Int16/Int32, values is List of Integer or Short
Expand Down
16 changes: 15 additions & 1 deletion src/main/java/io/milvus/param/dml/SearchParam.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

import java.nio.ByteBuffer;
import java.util.List;
import java.util.SortedMap;

/**
* Parameters for <code>search</code> interface.
Expand Down Expand Up @@ -238,6 +239,7 @@ public Builder addOutField(@NonNull String fieldName) {
* @param vectors list of target vectors:
* if vector type is FloatVector, vectors is List of List Float;
* if vector type is BinaryVector/Float16Vector/BFloat16Vector, vectors is List of ByteBuffer;
* if vector type is SparseFloatVector, vectors is List of SortedMap[Long, Float];
* @return <code>Builder</code>
*/
public Builder withVectors(@NonNull List<?> vectors) {
Expand Down Expand Up @@ -310,6 +312,7 @@ public SearchParam build() throws ParamException {

if (vectors.get(0) instanceof List) {
// float vectors
// TODO: here only check the first element, potential risk
List<?> first = (List<?>) vectors.get(0);
if (!(first.get(0) instanceof Float)) {
throw new ParamException("Float vector field's value must be Lst<Float>");
Expand All @@ -324,6 +327,7 @@ public SearchParam build() throws ParamException {
}
} else if (vectors.get(0) instanceof ByteBuffer) {
// binary vectors
// TODO: here only check the first element, potential risk
ByteBuffer first = (ByteBuffer) vectors.get(0);
int dim = first.position();
for (int i = 1; i < vectors.size(); ++i) {
Expand All @@ -332,8 +336,18 @@ public SearchParam build() throws ParamException {
throw new ParamException("Target vector dimension must be equal");
}
}
} else if (vectors.get(0) instanceof SortedMap) {
// sparse vectors, must be SortedMap<Long, Float>
// TODO: here only check the first element, potential risk
SortedMap<?, ?> map = (SortedMap<?, ?>) vectors.get(0);


} else {
throw new ParamException("Target vector type must be List<Float> or ByteBuffer");
String msg = "Search target vector type is illegal." +
" Only allow List<Float> for FloatVector," +
" ByteBuffer for BinaryVector/Float16Vector/BFloat16Vector," +
" List<SortedMap<Long, Float>> for SparseFloatVector.";
throw new ParamException(msg);
}

return new SearchParam(this);
Expand Down
Loading

0 comments on commit 7af7fad

Please sign in to comment.