Skip to content

Commit

Permalink
feat(java): support take api for java module (#3316)
Browse files Browse the repository at this point in the history
  • Loading branch information
yanghua authored Dec 31, 2024
1 parent 898396d commit 8767c10
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 5 deletions.
58 changes: 56 additions & 2 deletions java/core/lance-jni/src/blocking_dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,16 @@ use arrow::datatypes::Schema;
use arrow::ffi::FFI_ArrowSchema;
use arrow::ffi_stream::ArrowArrayStreamReader;
use arrow::ffi_stream::FFI_ArrowArrayStream;
use arrow::ipc::writer::StreamWriter;
use arrow::record_batch::RecordBatchIterator;
use arrow_schema::DataType;
use jni::objects::{JMap, JString, JValue};
use jni::sys::jlong;
use jni::sys::{jboolean, jint};
use jni::sys::{jbyteArray, jlong};
use jni::{objects::JObject, JNIEnv};
use lance::dataset::builder::DatasetBuilder;
use lance::dataset::transaction::Operation;
use lance::dataset::{ColumnAlteration, Dataset, ReadParams, WriteParams};
use lance::dataset::{ColumnAlteration, Dataset, ProjectionRequest, ReadParams, WriteParams};
use lance::io::{ObjectStore, ObjectStoreParams};
use lance::table::format::Fragment;
use lance::table::format::Index;
Expand Down Expand Up @@ -683,6 +684,59 @@ fn inner_list_indexes<'local>(
Ok(array_list)
}

/// JNI entry point backing `Dataset.nativeTake` on the Java side.
///
/// Delegates to [`inner_take`]. On success the Java `byte[]` produced there is
/// returned unchanged; on failure a `java.lang.RuntimeException` carrying the
/// debug rendering of the error is raised and a null array reference is returned.
#[no_mangle]
pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeTake(
    mut env: JNIEnv,
    java_dataset: JObject,
    indices_obj: JObject, // List<Long>
    columns_obj: JObject, // List<String>
) -> jbyteArray {
    inner_take(&mut env, java_dataset, indices_obj, columns_obj).unwrap_or_else(|err| {
        // Raising the exception is best effort: a failure to throw is ignored
        // and the null return value is handed back either way.
        let _ = env.throw_new("java/lang/RuntimeException", format!("{:?}", err));
        std::ptr::null_mut()
    })
}

fn inner_take(
env: &mut JNIEnv,
java_dataset: JObject,
indices_obj: JObject, // List<Long>
columns_obj: JObject, // List<String>
) -> Result<jbyteArray> {
let indices: Vec<i64> = env.get_longs(&indices_obj)?;
let indices_u64: Vec<u64> = indices.iter().map(|&x| x as u64).collect();
let indices_slice: &[u64] = &indices_u64;
let columns: Vec<String> = env.get_strings(&columns_obj)?;

let result = {
let dataset_guard =
unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }?;
let dataset = &dataset_guard.inner;

let projection = ProjectionRequest::from_columns(columns, dataset.schema());

match RT.block_on(dataset.take(indices_slice, projection)) {
Ok(res) => res,
Err(e) => {
return Err(e.into());
}
}
};

let mut buffer = Vec::new();
{
let mut writer = StreamWriter::try_new(&mut buffer, &result.schema())?;
writer.write(&result)?;
writer.finish()?;
}

let byte_array = env.byte_array_from_slice(&buffer)?;
Ok(**byte_array)
}

//////////////////////////////
// Schema evolution Methods //
//////////////////////////////
Expand Down
24 changes: 24 additions & 0 deletions java/core/lance-jni/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ pub trait JNIEnvExt {
/// Get integers from Java List<Integer> object.
fn get_integers(&mut self, obj: &JObject) -> Result<Vec<i32>>;

/// Get longs from Java List<Long> object.
fn get_longs(&mut self, obj: &JObject) -> Result<Vec<i64>>;

/// Get strings from Java List<String> object.
fn get_strings(&mut self, obj: &JObject) -> Result<Vec<String>>;

Expand Down Expand Up @@ -127,6 +130,18 @@ impl JNIEnvExt for JNIEnv<'_> {
Ok(results)
}

/// Reads every element of a Java `List<Long>` into a `Vec<i64>` by calling
/// `longValue()` on each element.
///
/// Returns an error if `obj` cannot be viewed as a list, if iteration fails, or
/// if the `longValue` call fails for any element.
fn get_longs(&mut self, obj: &JObject) -> Result<Vec<i64>> {
    let list = self.get_list(obj)?;
    let mut iter = list.iter(self)?;
    let mut results = Vec::with_capacity(list.size(self)? as usize);
    while let Some(elem) = iter.next(self)? {
        // Every call to `next` creates a new JNI local reference. Wrapping it in
        // an `AutoLocal` deletes it at the end of the iteration (including on the
        // error path), so long lists do not pile up local references.
        let elem = self.auto_local(elem);
        let long_value = self.call_method(&elem, "longValue", "()J", &[])?.j()?;
        results.push(long_value);
    }
    Ok(results)
}

fn get_strings(&mut self, obj: &JObject) -> Result<Vec<String>> {
let list = self.get_list(obj)?;
let mut iter = list.iter(self)?;
Expand Down Expand Up @@ -348,6 +363,15 @@ pub extern "system" fn Java_com_lancedb_lance_test_JniTestHelper_parseInts(
ok_or_throw_without_return!(env, env.get_integers(&list_obj));
}

/// JNI test hook behind `JniTestHelper.parseLongs`.
///
/// Runs the given Java `List<Long>` through `get_longs` and discards the parsed
/// values. The outcome is handed to `ok_or_throw_without_return!` (presumably
/// raises a Java exception when parsing fails -- confirm against the macro
/// definition).
#[no_mangle]
pub extern "system" fn Java_com_lancedb_lance_test_JniTestHelper_parseLongs(
mut env: JNIEnv,
_obj: JObject,
list_obj: JObject, // List<Long>
) {
ok_or_throw_without_return!(env, env.get_longs(&list_obj));
}

#[no_mangle]
pub extern "system" fn Java_com_lancedb_lance_test_JniTestHelper_parseIntsOpt(
mut env: JNIEnv,
Expand Down
32 changes: 32 additions & 0 deletions java/core/src/main/java/com/lancedb/lance/Dataset.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,15 @@
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.types.pojo.Schema;

import java.io.ByteArrayInputStream;
import java.io.Closeable;
import java.io.IOException;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -315,6 +321,32 @@ public LanceScanner newScan(ScanOptions options) {
}
}

/**
 * Select rows of data by index.
 *
 * <p>The selected rows are returned by the native layer as a serialized Arrow stream held in
 * memory; the returned reader decodes that buffer.
 *
 * @param indices the indices of the rows to take, must not be null
 * @param columns the names of the columns to take, must not be null
 * @return an ArrowReader over the selected rows; the caller is responsible for closing it
 * @throws IOException declared for failures while setting up or closing the reader
 */
public ArrowReader take(List<Long> indices, List<String> columns) throws IOException {
  Preconditions.checkNotNull(indices, "indices must not be null");
  Preconditions.checkNotNull(columns, "columns must not be null");
  try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) {
    // Check the handle while holding the read lock so it cannot be released between the
    // check and the native call (presumably close() takes the matching write lock -- confirm).
    Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed");
    byte[] arrowData = nativeTake(indices, columns);
    ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(arrowData);
    ReadableByteChannel readChannel = Channels.newChannel(byteArrayInputStream);
    // The data is fully materialized in arrowData, so the reader stays usable after the
    // read lock is released on return.
    return new ArrowStreamReader(readChannel, allocator) {
      @Override
      public void close() throws IOException {
        super.close();
        readChannel.close();
        byteArrayInputStream.close();
      }
    };
  }
}

// Native counterpart of take(): implemented in lance-jni by
// Java_com_lancedb_lance_Dataset_nativeTake, which returns the selected rows
// serialized as an Arrow IPC stream.
private native byte[] nativeTake(List<Long> indices, List<String> columns);

/**
* Gets the URI of the dataset.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ public class JniTestHelper {
*/
public static native void parseInts(List<Integer> intsList);

/**
 * JNI parse longs test: hands the list to the native side, which reads it as a
 * {@code List<Long>} and discards the parsed values.
 *
 * @param longsList the given list of longs
 */
public static native void parseLongs(List<Long> longsList);

/**
* JNI parse ints opts test.
*
Expand Down
35 changes: 32 additions & 3 deletions java/core/src/test/java/com/lancedb/lance/DatasetTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
Expand All @@ -25,10 +27,9 @@

import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.channels.ClosedChannelException;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.*;
import java.util.stream.Collectors;

import static org.junit.jupiter.api.Assertions.*;
Expand Down Expand Up @@ -307,4 +308,32 @@ void testDropPath() {
Dataset.drop(datasetPath, new HashMap<>());
}
}

/**
 * Writes a small dataset, takes two rows by index and checks that the returned reader yields
 * exactly those rows with the requested columns.
 */
@Test
void testTake() throws IOException, ClosedChannelException {
  String testMethodName = new Object() {}.getClass().getEnclosingMethod().getName();
  String datasetPath = tempDir.resolve(testMethodName).toString();
  try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
    TestUtils.SimpleTestDataset testDataset =
        new TestUtils.SimpleTestDataset(allocator, datasetPath);
    dataset = testDataset.createEmptyDataset();

    try (Dataset dataset2 = testDataset.write(1, 5)) {
      List<Long> indices = Arrays.asList(1L, 4L);
      List<String> columns = Arrays.asList("id", "name");
      try (ArrowReader reader = dataset2.take(indices, columns)) {
        int batchCount = 0;
        while (reader.loadNextBatch()) {
          batchCount++;
          VectorSchemaRoot result = reader.getVectorSchemaRoot();
          assertNotNull(result);
          assertEquals(indices.size(), result.getRowCount());

          for (int i = 0; i < indices.size(); i++) {
            assertEquals(indices.get(i).intValue(), result.getVector("id").getObject(i));
            assertNotNull(result.getVector("name").getObject(i));
          }
        }
        // Without this check the test would pass even if the reader produced no data at all.
        assertTrue(batchCount > 0, "take should return at least one batch");
      }
    }
  }
}
}
5 changes: 5 additions & 0 deletions java/core/src/test/java/com/lancedb/lance/JNITest.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ public void testInts() {
JniTestHelper.parseInts(Arrays.asList(1, 2, 3));
}

// Passes a List<Long>, including Long.MAX_VALUE, through the native parseLongs helper;
// the test contains no assertions beyond the call itself.
@Test
public void testLongs() {
JniTestHelper.parseLongs(Arrays.asList(1L, 2L, 3L, Long.MAX_VALUE));
}

@Test
public void testIntsOpt() {
JniTestHelper.parseIntsOpt(Optional.of(Arrays.asList(1, 2, 3)));
Expand Down

0 comments on commit 8767c10

Please sign in to comment.