diff --git a/src/main/java/io/milvus/param/IndexType.java b/src/main/java/io/milvus/param/IndexType.java index 6b7cc3de5..002f10a2d 100644 --- a/src/main/java/io/milvus/param/IndexType.java +++ b/src/main/java/io/milvus/param/IndexType.java @@ -27,7 +27,7 @@ */ public enum IndexType { None(0), - //Only supported for float vectors + // Only supported for float vectors FLAT(1), IVF_FLAT(2), IVF_SQ8(3), @@ -37,19 +37,22 @@ public enum IndexType { AUTOINDEX(11), SCANN(12), - // GPU index + // GPU indexes only for float vectors GPU_IVF_FLAT(50), GPU_IVF_PQ(51), - //Only supported for binary vectors + // Only supported for binary vectors BIN_FLAT(80), BIN_IVF_FLAT(81), - //Scalar field index start from here - //Only for varchar type field + // Only for varchar type field TRIE("Trie", 100), - //Only for scalar type field + // Only for scalar type field STL_SORT(200), + + // Only for sparse vectors + SPARSE_INVERTED_INDEX(300), + SPARSE_WAND(301) ; @Getter diff --git a/src/main/java/io/milvus/param/ParamUtils.java b/src/main/java/io/milvus/param/ParamUtils.java index e31020b1f..850746627 100644 --- a/src/main/java/io/milvus/param/ParamUtils.java +++ b/src/main/java/io/milvus/param/ParamUtils.java @@ -46,6 +46,7 @@ public static HashMap getTypeErrorMsg() { typeErrMsg.put(DataType.BinaryVector, "Type mismatch for field '%s': Binary vector field's value type must be ByteBuffer"); typeErrMsg.put(DataType.Float16Vector, "Type mismatch for field '%s': Float16 vector field's value type must be ByteBuffer"); typeErrMsg.put(DataType.BFloat16Vector, "Type mismatch for field '%s': BFloat16 vector field's value type must be ByteBuffer"); + typeErrMsg.put(DataType.SparseFloatVector, "Type mismatch for field '%s': SparseFloatVector vector field's value type must be SortedMap"); return typeErrMsg; } @@ -98,12 +99,11 @@ public static void checkFieldData(FieldType fieldSchema, List values, boolean throw new ParamException(String.format(msg, fieldSchema.getName(), i, temp.size(), dim)); } } 
+ break; } - break; case BinaryVector: case Float16Vector: - case BFloat16Vector: - { + case BFloat16Vector: { int dim = fieldSchema.getDimension(); for (int i = 0; i < values.size(); ++i) { Object value = values.get(i); @@ -120,8 +120,30 @@ public static void checkFieldData(FieldType fieldSchema, List values, boolean throw new ParamException(String.format(msg, fieldSchema.getName(), i, v.position()*8, dim)); } } + break; } - break; + case SparseFloatVector: + for (Object value : values) { + if (!(value instanceof SortedMap)) { + throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName())); + } + + // is SortedMap ? + SortedMap m = (SortedMap)value; + if (m.isEmpty()) { // not allow empty value for sparse vector + String msg = "Not allow empty SortedMap for sparse vector field '%s'"; + throw new ParamException(String.format(msg, fieldSchema.getName())); + } + if (!(m.firstKey() instanceof Long)) { + String msg = "The key of SortedMap must be Long for sparse vector field '%s'"; + throw new ParamException(String.format(msg, fieldSchema.getName())); + } + if (!(m.get(m.firstKey()) instanceof Float)) { + String msg = "The value of SortedMap must be Float for sparse vector field '%s'"; + throw new ParamException(String.format(msg, fieldSchema.getName())); + } + } + break; case Int64: for (Object value : values) { if (!(value instanceof Long)) { @@ -455,8 +477,16 @@ public static SearchRequest convertSearchParam(@NonNull SearchParam requestParam byte[] array = buf.array(); ByteString bs = ByteString.copyFrom(array); byteStrings.add(bs); + } else if (vector instanceof SortedMap) { + plType = PlaceholderType.SparseFloatVector; + SortedMap map = (SortedMap) vector; + ByteString bs = genSparseFloatBytes(map); + byteStrings.add(bs); } else { - String msg = "Search target vector type is illegal(Only allow List or ByteBuffer)"; + String msg = "Search target vector type is illegal." 
+ + " Only allow List for FloatVector," + + " ByteBuffer for BinaryVector/Float16Vector/BFloat16Vector," + + " List> for SparseFloatVector."; throw new ParamException(msg); } } @@ -623,6 +653,7 @@ public static boolean isVectorDataType(DataType dataType) { add(DataType.BinaryVector); add(DataType.Float16Vector); add(DataType.BFloat16Vector); + add(DataType.SparseFloatVector); }}; return vectorDataType.contains(dataType); } @@ -631,7 +662,6 @@ private static FieldData genFieldData(FieldType fieldType, List objects) { return genFieldData(fieldType, objects, Boolean.FALSE); } - @SuppressWarnings("unchecked") private static FieldData genFieldData(FieldType fieldType, List objects, boolean isDynamic) { if (objects == null) { throw new ParamException("Cannot generate FieldData from null object"); @@ -694,11 +724,56 @@ private static VectorField genVectorField(DataType dataType, List objects) { } else { return VectorField.newBuilder().setDim(dim).setBfloat16Vector(byteString).build(); } + } else if (dataType == DataType.SparseFloatVector) { + SparseFloatArray sparseArray = genSparseFloatArray(objects); + return VectorField.newBuilder().setDim(sparseArray.getDim()).setSparseFloatVector(sparseArray).build(); } throw new ParamException("Illegal vector dataType:" + dataType); } + private static ByteString genSparseFloatBytes(SortedMap sparse) { + ByteBuffer buf = ByteBuffer.allocate((Integer.BYTES + Float.BYTES) * sparse.size()); + buf.order(ByteOrder.LITTLE_ENDIAN); // Milvus uses little endian by default + for (Map.Entry entry : sparse.entrySet()) { + long k = entry.getKey(); + if (k < 0 || k > (long)Math.pow(2.0, 32.0)-1) { + throw new ParamException("Sparse vector index must be positive and less than 2^32-1"); + } + // here we construct a binary from the long key + ByteBuffer lBuf = ByteBuffer.allocate(Long.BYTES); + lBuf.order(ByteOrder.LITTLE_ENDIAN); + lBuf.putLong(k); + // the server requires a binary of unsigned int, append the first 4 bytes + buf.put(lBuf.array(), 
0, 4); + + float v = entry.getValue(); + if (Float.isNaN(v) || Float.isInfinite(v)) { + throw new ParamException("Sparse vector value cannot be NaN or Infinite"); + } + buf.putFloat(entry.getValue()); + } + + return ByteString.copyFrom(buf.array()); + } + + private static SparseFloatArray genSparseFloatArray(List objects) { + int dim = 0; // the real dim is unknown, set the max size as dim + SparseFloatArray.Builder builder = SparseFloatArray.newBuilder(); + // each object must be SortedMap, which is already validated by checkFieldData() + for (Object object : objects) { + if (!(object instanceof SortedMap)) { + throw new ParamException("SparseFloatVector vector field's value type must be SortedMap"); + } + SortedMap sparse = (SortedMap) object; + dim = Math.max(dim, sparse.size()); + ByteString byteString = genSparseFloatBytes(sparse); + builder.addContents(byteString); + } + + return builder.setDim(dim).build(); + } + private static ScalarField genScalarField(FieldType fieldType, List objects) { if (fieldType.getDataType() == DataType.Array) { ArrayArray.Builder builder = ArrayArray.newBuilder(); diff --git a/src/main/java/io/milvus/param/QueryNodeSingleSearch.java b/src/main/java/io/milvus/param/QueryNodeSingleSearch.java index b8e12e720..396c04ef5 100644 --- a/src/main/java/io/milvus/param/QueryNodeSingleSearch.java +++ b/src/main/java/io/milvus/param/QueryNodeSingleSearch.java @@ -101,6 +101,7 @@ public Builder withVectorFieldName(@NonNull String vectorFieldName) { * @param vectors list of target vectors: * if vector type is FloatVector, vectors is List of List Float * if vector type is BinaryVector/Float16Vector/BFloat16Vector, vectors is List of ByteBuffer + * if vector type is SparseFloatVector, values is List of SortedMap[Long, Float] * @return Builder */ public Builder withVectors(@NonNull List vectors) { diff --git a/src/main/java/io/milvus/param/collection/FieldType.java b/src/main/java/io/milvus/param/collection/FieldType.java index 
6db652060..0ba92ae7b 100644 --- a/src/main/java/io/milvus/param/collection/FieldType.java +++ b/src/main/java/io/milvus/param/collection/FieldType.java @@ -271,7 +271,8 @@ public FieldType build() throws ParamException { throw new ParamException("String type is not supported, use Varchar instead"); } - if (ParamUtils.isVectorDataType(dataType)) { + // SparseVector has no dimension, other vector types must have dimension + if (ParamUtils.isVectorDataType(dataType) && dataType != DataType.SparseFloatVector) { if (!typeParams.containsKey(Constant.VECTOR_DIM)) { throw new ParamException("Vector field dimension must be specified"); } diff --git a/src/main/java/io/milvus/param/dml/InsertParam.java b/src/main/java/io/milvus/param/dml/InsertParam.java index c82b22dc4..13757af8d 100644 --- a/src/main/java/io/milvus/param/dml/InsertParam.java +++ b/src/main/java/io/milvus/param/dml/InsertParam.java @@ -218,7 +218,9 @@ protected void checkRows() { * If dataType is Varchar, values is List of String; * If dataType is FloatVector, values is List of List Float; * If dataType is BinaryVector/Float16Vector/BFloat16Vector, values is List of ByteBuffer; + * If dataType is SparseFloatVector, values is List of SortedMap[Long, Float]; * If dataType is Array, values can be List of List Boolean/Integer/Short/Long/Float/Double/String; + * If dataType is JSON, values is List of JSONObject; * * Note: * If dataType is Int8/Int16/Int32, values is List of Integer or Short diff --git a/src/main/java/io/milvus/param/dml/SearchParam.java b/src/main/java/io/milvus/param/dml/SearchParam.java index f3e089d37..355a4812d 100644 --- a/src/main/java/io/milvus/param/dml/SearchParam.java +++ b/src/main/java/io/milvus/param/dml/SearchParam.java @@ -32,6 +32,7 @@ import java.nio.ByteBuffer; import java.util.List; +import java.util.SortedMap; /** * Parameters for search interface. 
@@ -238,6 +239,7 @@ public Builder addOutField(@NonNull String fieldName) { * @param vectors list of target vectors: * if vector type is FloatVector, vectors is List of List Float; * if vector type is BinaryVector/Float16Vector/BFloat16Vector, vectors is List of ByteBuffer; + * if vector type is SparseFloatVector, values is List of SortedMap[Long, Float]; * @return Builder */ public Builder withVectors(@NonNull List vectors) { @@ -310,6 +312,7 @@ public SearchParam build() throws ParamException { if (vectors.get(0) instanceof List) { // float vectors + // TODO: here only check the first element, potential risk List first = (List) vectors.get(0); if (!(first.get(0) instanceof Float)) { throw new ParamException("Float vector field's value must be Lst"); @@ -324,6 +327,7 @@ public SearchParam build() throws ParamException { } } else if (vectors.get(0) instanceof ByteBuffer) { // binary vectors + // TODO: here only check the first element, potential risk ByteBuffer first = (ByteBuffer) vectors.get(0); int dim = first.position(); for (int i = 1; i < vectors.size(); ++i) { @@ -332,8 +336,18 @@ public SearchParam build() throws ParamException { throw new ParamException("Target vector dimension must be equal"); } } + } else if (vectors.get(0) instanceof SortedMap) { + // sparse vectors, must be SortedMap + // TODO: here only check the first element, potential risk + SortedMap map = (SortedMap) vectors.get(0); + + } else { - throw new ParamException("Target vector type must be List or ByteBuffer"); + String msg = "Search target vector type is illegal." 
+ + " Only allow List for FloatVector," + + " ByteBuffer for BinaryVector/Float16Vector/BFloat16Vector," + + " List> for SparseFloatVector."; + throw new ParamException(msg); } return new SearchParam(this); diff --git a/src/main/java/io/milvus/response/FieldDataWrapper.java b/src/main/java/io/milvus/response/FieldDataWrapper.java index 77d920876..ee3609c3c 100644 --- a/src/main/java/io/milvus/response/FieldDataWrapper.java +++ b/src/main/java/io/milvus/response/FieldDataWrapper.java @@ -3,18 +3,18 @@ import com.alibaba.fastjson.JSONObject; import com.google.protobuf.ProtocolStringList; import io.milvus.exception.ParamException; -import io.milvus.grpc.ArrayArray; -import io.milvus.grpc.DataType; -import io.milvus.grpc.FieldData; +import io.milvus.grpc.*; import io.milvus.exception.IllegalResponseException; -import io.milvus.grpc.ScalarField; import io.milvus.param.ParamUtils; import lombok.NonNull; import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.util.ArrayList; import java.util.List; +import java.util.SortedMap; +import java.util.TreeMap; import java.util.stream.Collectors; import com.google.protobuf.ByteString; @@ -56,6 +56,23 @@ public int getDim() throws IllegalResponseException { return (int) fieldData.getVectors().getDim(); } + // this method returns bytes size of each vector according to vector type + private int checkDim(DataType dt, ByteString data, int dim) { + if (dt == DataType.BinaryVector) { + if ((data.size()*8) % dim != 0) { + throw new IllegalResponseException("Returned binary vector field data array size doesn't match dimension"); + } + return dim/8; + } else if (dt == DataType.Float16Vector || dt == DataType.BFloat16Vector) { + if (data.size() != dim*2) { + throw new IllegalResponseException("Returned float16 vector field data array size doesn't match dimension"); + } + return dim*2; + } + + return 0; + } + /** * Gets the row count of a field. * * Throws {@link IllegalResponseException} if the field type is illegal. 
@@ -75,13 +92,26 @@ public long getRowCount() throws IllegalResponseException { return data.size()/dim; } case BinaryVector: { + // for binary vector, each dimension is one bit, each byte is 8 dim int dim = getDim(); ByteString data = fieldData.getVectors().getBinaryVector(); - if ((data.size()*8) % dim != 0) { - throw new IllegalResponseException("Returned binary vector field data array size doesn't match dimension"); - } + int bytePerVec = checkDim(dt, data, dim); - return (data.size()*8)/dim; + return data.size()/bytePerVec; + } + case Float16Vector: + case BFloat16Vector: { + // for float16 vector, each dimension 2 bytes + int dim = getDim(); + ByteString data = (dt == DataType.Float16Vector) ? + fieldData.getVectors().getFloat16Vector() : fieldData.getVectors().getBfloat16Vector(); + int bytePerVec = checkDim(dt, data, dim); + + return data.size()/bytePerVec; + } + case SparseFloatVector: { + // for sparse vector, each content is a vector + return fieldData.getVectors().getSparseFloatVector().getContentsCount(); } case Int64: return fieldData.getScalars().getLongData().getDataCount(); @@ -109,15 +139,17 @@ public long getRowCount() throws IllegalResponseException { /** * Returns the field data according to its type: - * float vector field return List of List Float, - * binary vector field return List of ByteBuffer - * int64 field return List of Long - * int32/int16/int8 field return List of Integer - * boolean field return List of Boolean - * float field return List of Float - * double field return List of Double - * varchar field return List of String - * array field return List of List + * FloatVector field returns List of List Float, + * BinaryVector/Float16Vector/BFloat16Vector fields return List of ByteBuffer + * SparseFloatVector field returns List of SortedMap[Long, Float] + * Int64 field returns List of Long + * Int32/Int16/Int8 fields return List of Integer + * Bool field returns List of Boolean + * Float field returns List of Float + * Double field 
returns List of Double + * Varchar field returns List of String + * Array field returns List of List + * JSON field returns List of String; * etc. * * Throws {@link IllegalResponseException} if the field type is illegal. @@ -141,16 +173,22 @@ public List getFieldData() throws IllegalResponseException { } return packData; } - case BinaryVector: { + case BinaryVector: + case Float16Vector: + case BFloat16Vector: { int dim = getDim(); - ByteString data = fieldData.getVectors().getBinaryVector(); - if ((data.size()*8) % dim != 0) { - throw new IllegalResponseException("Returned binary vector field data array size doesn't match dimension"); + ByteString data = null; + if (dt == DataType.BinaryVector) { + data = fieldData.getVectors().getBinaryVector(); + } else if (dt == DataType.Float16Vector) { + data = fieldData.getVectors().getFloat16Vector(); + } else if (dt == DataType.BFloat16Vector) { + data = fieldData.getVectors().getBfloat16Vector(); } - List packData = new ArrayList<>(); - int bytePerVec = dim/8; + int bytePerVec = checkDim(dt, data, dim); int count = data.size()/bytePerVec; + List packData = new ArrayList<>(); for (int i = 0; i < count; ++i) { ByteBuffer bf = ByteBuffer.allocate(bytePerVec); bf.put(data.substring(i * bytePerVec, (i + 1) * bytePerVec).toByteArray()); @@ -158,6 +196,40 @@ public List getFieldData() throws IllegalResponseException { } return packData; } + case SparseFloatVector: { + // in Java sdk, each sparse vector is pairs of long+float + // in server side, each sparse vector is stored as uint+float (8 bytes) + // don't use sparseArray.getDim() because the dim is the max index of each rows + SparseFloatArray sparseArray = fieldData.getVectors().getSparseFloatVector(); + List> packData = new ArrayList<>(); + for (int i = 0; i < sparseArray.getContentsCount(); ++i) { + ByteString bs = sparseArray.getContents(i); + ByteBuffer bf = ByteBuffer.wrap(bs.toByteArray()); + bf.order(ByteOrder.LITTLE_ENDIAN); + SortedMap sparse = new TreeMap<>(); + 
/**
 * Parses the raw value of a JSON field into a JSONObject.
 * Accepts both representations a JSON value may arrive in: a String, or a
 * byte array holding the UTF-8 encoded JSON text.
 *
 * @param object the raw JSON value, either a String or a byte[]
 * @return the parsed JSONObject
 * @throws IllegalResponseException if the value is null or of any other type
 */
public static JSONObject ParseJSONObject(Object object) {
    if (object instanceof String) {
        return JSONObject.parseObject((String) object);
    } else if (object instanceof byte[]) {
        // Decode explicitly as UTF-8; the no-charset String constructor depends on the
        // platform default and would corrupt non-ASCII content on non-UTF-8 systems.
        return JSONObject.parseObject(
                new String((byte[]) object, java.nio.charset.StandardCharsets.UTF_8));
    } else {
        throw new IllegalResponseException("Illegal type value for JSON parser");
    }
}
a/src/main/java/io/milvus/response/SearchResultsWrapper.java b/src/main/java/io/milvus/response/SearchResultsWrapper.java index f8eadba89..264965358 100644 --- a/src/main/java/io/milvus/response/SearchResultsWrapper.java +++ b/src/main/java/io/milvus/response/SearchResultsWrapper.java @@ -187,7 +187,7 @@ public List getIDScore(int indexOfTarget) throws ParamException, Illega Object value = wrapper.valueByIdx((int)offset + n); if (wrapper.isJsonField()) { - idScores.get(n).put(field.getFieldName(), JSONObject.parseObject(new String((byte[])value))); + idScores.get(n).put(field.getFieldName(), FieldDataWrapper.ParseJSONObject(value)); } else { idScores.get(n).put(field.getFieldName(), value); } diff --git a/src/main/java/io/milvus/response/basic/RowRecordWrapper.java b/src/main/java/io/milvus/response/basic/RowRecordWrapper.java index 4080bd93e..32bf5e3d4 100644 --- a/src/main/java/io/milvus/response/basic/RowRecordWrapper.java +++ b/src/main/java/io/milvus/response/basic/RowRecordWrapper.java @@ -47,7 +47,7 @@ protected QueryResultsWrapper.RowRecord buildRowRecord(QueryResultsWrapper.RowRe } Object value = wrapper.valueByIdx((int)index); if (wrapper.isJsonField()) { - JSONObject jsonField = JSONObject.parseObject(new String((byte[])value)); + JSONObject jsonField = FieldDataWrapper.ParseJSONObject(value); if (wrapper.isDynamicField()) { for (String key: jsonField.keySet()) { record.put(key, jsonField.get(key)); diff --git a/src/main/java/io/milvus/v2/utils/DataUtils.java b/src/main/java/io/milvus/v2/utils/DataUtils.java index e4b780530..3933bc4c4 100644 --- a/src/main/java/io/milvus/v2/utils/DataUtils.java +++ b/src/main/java/io/milvus/v2/utils/DataUtils.java @@ -461,6 +461,7 @@ public static HashMap getTypeErrorMsg() { typeErrMsg.put(DataType.BinaryVector, "Type mismatch for field '%s': Binary vector field's value type must be ByteBuffer"); typeErrMsg.put(DataType.Float16Vector, "Type mismatch for field '%s': Float16 vector field's value type must be ByteBuffer"); 
typeErrMsg.put(DataType.BFloat16Vector, "Type mismatch for field '%s': BFloat16 vector field's value type must be ByteBuffer"); + typeErrMsg.put(DataType.SparseFloatVector, "Type mismatch for field '%s': SparseFloatVector vector field's value type must be SortedMap"); return typeErrMsg; } } diff --git a/src/test/java/io/milvus/client/MilvusClientDockerTest.java b/src/test/java/io/milvus/client/MilvusClientDockerTest.java index 5114323d1..d9a0ffaff 100644 --- a/src/test/java/io/milvus/client/MilvusClientDockerTest.java +++ b/src/test/java/io/milvus/client/MilvusClientDockerTest.java @@ -69,7 +69,7 @@ class MilvusClientDockerTest { private static final int dimension = 128; @Container - private static final MilvusContainer milvus = new MilvusContainer("milvusdb/milvus:v2.3.10"); + private static final MilvusContainer milvus = new MilvusContainer("milvusdb/milvus:v2.4.0-rc.1"); @BeforeAll public static void setUp() { @@ -142,6 +142,21 @@ private List generateBinaryVectors(int count) { } + private List> generateSparseVectors(int count) { + Random ran = new Random(); + List> vectors = new ArrayList<>(); + for (int n = 0; n < count; ++n) { + SortedMap sparse = new TreeMap<>(); + int dim = ran.nextInt(10) + 1; + for (int i = 0; i < dim; ++i) { + sparse.put((long)ran.nextInt(1000000), ran.nextFloat()); + } + vectors.add(sparse); + } + return vectors; + + } + @Test void testFloatVectors() { String randomCollectionName = generator.generate(10); @@ -665,6 +680,122 @@ void testBinaryVectors() { Assertions.assertEquals(R.Status.Success.getCode(), dropR.getStatus().intValue()); } + @Test + void testSparseVector() { + String randomCollectionName = generator.generate(10); + + // collection schema + String field1Name = "field1"; + String field2Name = "field2"; + FieldType field1 = FieldType.newBuilder() + .withPrimaryKey(true) + .withAutoID(false) + .withDataType(DataType.Int64) + .withName(field1Name) + .build(); + + FieldType field2 = FieldType.newBuilder() + 
.withDataType(DataType.SparseFloatVector) + .withName(field2Name) + .build(); + + // create collection + CreateCollectionParam createParam = CreateCollectionParam.newBuilder() + .withCollectionName(randomCollectionName) + .addFieldType(field1) + .addFieldType(field2) + .build(); + + R createR = client.createCollection(createParam); + Assertions.assertEquals(R.Status.Success.getCode(), createR.getStatus().intValue()); + + int rowCount = 10000; + List ids = new ArrayList<>(); + for (int i = 0; i < rowCount; i++) { + ids.add((long)i); + } + List> vectors = generateSparseVectors(rowCount); + List fields = new ArrayList<>(); + fields.add(new InsertParam.Field(field1Name, ids)); + fields.add(new InsertParam.Field(field2Name, vectors)); + + InsertParam insertParam = InsertParam.newBuilder() + .withCollectionName(randomCollectionName) + .withFields(fields) + .build(); + + R insertR = client.insert(insertParam); + Assertions.assertEquals(R.Status.Success.getCode(), insertR.getStatus().intValue()); + + // create index + CreateIndexParam indexParam = CreateIndexParam.newBuilder() + .withCollectionName(randomCollectionName) + .withFieldName(field2Name) + .withIndexType(IndexType.SPARSE_INVERTED_INDEX) + .withMetricType(MetricType.IP) + .withExtraParam("{\"drop_ratio_build\":0.2}") + .build(); + + R createIndexR = client.createIndex(indexParam); + Assertions.assertEquals(R.Status.Success.getCode(), createIndexR.getStatus().intValue()); + + // load collection + R loadR = client.loadCollection(LoadCollectionParam.newBuilder() + .withCollectionName(randomCollectionName) + .build()); + Assertions.assertEquals(R.Status.Success.getCode(), loadR.getStatus().intValue()); + + // pick some vectors to search with index + int nq = 5; + List targetVectorIDs = new ArrayList<>(); + List> targetVectors = new ArrayList<>(); + Random ran = new Random(); + int randomIndex = ran.nextInt(rowCount); + for (int i = randomIndex; i < randomIndex + nq; ++i) { + targetVectorIDs.add(ids.get(i)); + 
targetVectors.add(vectors.get(i)); + } + + System.out.println("Search target IDs:" + targetVectorIDs); + System.out.println("Search target vectors:" + targetVectors); + + int topK = 5; + SearchParam searchParam = SearchParam.newBuilder() + .withCollectionName(randomCollectionName) + .withMetricType(MetricType.IP) + .withTopK(topK) + .withVectors(targetVectors) + .withVectorFieldName(field2Name) + .addOutField(field2Name) + .withParams("{\"drop_ratio_search\":0.2}") + .build(); + + R searchR = client.search(searchParam); +// System.out.println(searchR); + Assertions.assertEquals(R.Status.Success.getCode(), searchR.getStatus().intValue()); + + // verify the search result + SearchResultsWrapper results = new SearchResultsWrapper(searchR.getData().getResults()); + for (int i = 0; i < targetVectors.size(); ++i) { + List scores = results.getIDScore(i); + System.out.println("The result of No." + i + " target vector(ID = " + targetVectorIDs.get(i) + "):"); + System.out.println(scores); + Assertions.assertEquals(targetVectorIDs.get(i).longValue(), scores.get(0).getLongID()); + + Object v = scores.get(0).get(field2Name); + SortedMap sparse = (SortedMap)v; + Assertions.assertTrue(sparse.equals(targetVectors.get(i))); + } + + // drop collection + DropCollectionParam dropParam = DropCollectionParam.newBuilder() + .withCollectionName(randomCollectionName) + .build(); + + R dropR = client.dropCollection(dropParam); + Assertions.assertEquals(R.Status.Success.getCode(), dropR.getStatus().intValue()); + } + @Test void testAsyncMethods() { String randomCollectionName = generator.generate(10);