Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flight stream impl #93

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package org.opensearch.arrow.spi;

import org.opensearch.common.annotation.ExperimentalApi;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.tasks.TaskId;

import java.io.Closeable;
Expand Down Expand Up @@ -95,6 +96,14 @@ public interface StreamProducer<VectorRoot, Allocator> extends Closeable {
*/
BatchedJob<VectorRoot> createJob(Allocator allocator);

/**
* Returns the maximum duration allotted for executing the job.
* Once this much time has elapsed, the job should be considered expired
* and eligible for cancellation and cleanup.
* <p>
* NOTE(review): {@code TimeValue} is a duration, not an absolute instant —
* "deadline" here means a relative timeout measured from job start; confirm
* callers interpret it that way.
*
* @return TimeValue representing the job's execution timeout
*/
TimeValue getJobDeadline();

/**
* Provides an estimate of the total number of rows that will be produced.
*
Expand Down Expand Up @@ -150,6 +159,6 @@ interface FlushSignal {
*
* @param timeout Maximum milliseconds to wait
*/
void awaitConsumption(int timeout);
void awaitConsumption(TimeValue timeout);
}
}
4 changes: 4 additions & 0 deletions plugins/arrow-flight-rpc/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ dependencies {
}
}

// Arrow's memory module (org.apache.arrow.memory.core) needs reflective access
// to java.base/java.nio internals; open that package for internalClusterTest JVMs.
tasks.internalClusterTest {
jvmArgs += ["--add-opens", "java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED"]
}

tasks.named('test').configure {
jacoco {
excludes = ['org/apache/arrow/flight/**']
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,39 @@
package org.opensearch.arrow.flight;

import org.apache.arrow.flight.CallOptions;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightRuntimeException;
import org.apache.arrow.flight.OSFlightClient;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.opensearch.arrow.flight.bootstrap.FlightClientManager;
import org.opensearch.arrow.flight.bootstrap.FlightService;
import org.opensearch.arrow.flight.bootstrap.FlightStreamPlugin;
import org.opensearch.arrow.spi.StreamManager;
import org.opensearch.arrow.spi.StreamProducer;
import org.opensearch.arrow.spi.StreamReader;
import org.opensearch.arrow.spi.StreamTicket;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.FeatureFlags;
import org.opensearch.plugins.Plugin;
import org.opensearch.test.FeatureFlagSetter;
import org.opensearch.test.OpenSearchIntegTestCase;
import org.junit.BeforeClass;

import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.SUITE, numDataNodes = 5)
@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.SUITE, numDataNodes = 3)
public class ArrowFlightServerIT extends OpenSearchIntegTestCase {

private FlightClientManager flightClientManager;

@BeforeClass
public static void setupFeatureFlags() {
FeatureFlagSetter.set(FeatureFlags.ARROW_STREAMS_SETTING.getKey());
Expand All @@ -44,16 +57,251 @@ public void setUp() throws Exception {
super.setUp();
ensureGreen();
Thread.sleep(1000);
FlightService flightService = internalCluster().getInstance(FlightService.class);
flightClientManager = flightService.getFlightClientManager();
}

public void testArrowFlightEndpoint() throws Exception {
/**
 * Smoke test: every node in the cluster must have a Flight client registered
 * and answer a handshake within 5 seconds.
 * NOTE(review): the assertNotNull after getFlightClient(...).get() is redundant —
 * if the holder is empty, get() throws before the assertion runs; consider
 * asserting presence first. Also confirm the intended getFlightClient API:
 * testEarlyCancel uses the same call WITHOUT .get().
 */
public void testArrowFlightEndpoint() {
for (DiscoveryNode node : getClusterState().nodes()) {
FlightService flightService = internalCluster().getInstance(FlightService.class, node.getName());
FlightClientManager flightClientManager = flightService.getFlightClientManager();
OSFlightClient flightClient = flightClientManager.getFlightClient(node.getId()).get();
assertNotNull(flightClient);
// Handshake is a lightweight round-trip proving the endpoint is live.
flightClient.handshake(CallOptions.timeout(5000L, TimeUnit.MILLISECONDS));
}
}

/**
 * Registers a stream on an arbitrary node, then reads it back from every node in
 * the cluster; remote reads work because of the proxy stream producer. Verifies
 * the batch count (10), the per-batch row count (10), and each docID value.
 */
public void testFlightStreamReader() throws Exception {
    for (DiscoveryNode node : getClusterState().nodes()) {
        StreamTicket ticket = getStreamManagerRandomNode().registerStream(getStreamProducer(), null);
        StreamManager localManager = getStreamManager(node.getName());
        try (StreamReader<VectorSchemaRoot> reader = localManager.getStreamReader(ticket)) {
            assertNotNull(reader.getRoot().getVector("docID"));
            int batchesSeen = 0;
            while (reader.next()) {
                IntVector docIds = (IntVector) reader.getRoot().getVector("docID");
                assertEquals(10, docIds.getValueCount());
                // Values are globally sequential: batch b carries 10*b .. 10*b+9.
                for (int row = 0; row < 10; row++) {
                    assertEquals(docIds.toString(), row + (batchesSeen * 10L), docIds.get(row));
                }
                batchesSeen++;
            }
            assertEquals(10, batchesSeen);
        }
    }
}

/**
 * Early-cancellation test: a reader on one node consumes only the first of ten
 * batches from a producer registered on another node, then closes; the
 * server-side producer must observe the cancellation and be closed within the
 * wait window.
 * NOTE(review): getFlightClient(node.getId()) is used here directly as a
 * try-with-resources resource (no .get()), unlike the .get() style used in
 * testArrowFlightEndpoint — confirm the intended API shape, and that closing a
 * manager-owned client here does not invalidate it for later tests.
 */
public void testEarlyCancel() throws Exception {
DiscoveryNode previousNode = null;
for (DiscoveryNode node : getClusterState().nodes()) {
try (OSFlightClient flightClient = flightClientManager.getFlightClient(node.getId())) {
assertNotNull(flightClient);
flightClient.handshake(CallOptions.timeout(5000L, TimeUnit.MILLISECONDS));
// Two distinct nodes are needed: the previous node plays the client side.
if (previousNode == null) {
previousNode = node;
continue;
}
StreamManager streamManagerServer = getStreamManager(node.getName());
TestStreamProducer streamProducer = getStreamProducer();
StreamTicket ticket = streamManagerServer.registerStream(streamProducer, null);
StreamManager streamManagerClient = getStreamManager(previousNode.getName());

// Latch signals reader-thread completion; refs hand its results back here.
CountDownLatch readerComplete = new CountDownLatch(1);
AtomicReference<Exception> readerException = new AtomicReference<>();
AtomicReference<StreamReader<VectorSchemaRoot>> readerRef = new AtomicReference<>();

// Start reader thread
Thread readerThread = new Thread(() -> {
try {
StreamReader<VectorSchemaRoot> reader = streamManagerClient.getStreamReader(ticket);
readerRef.set(reader);
assertNotNull(reader.getRoot());
IntVector docIDVector = (IntVector) reader.getRoot().getVector("docID");
assertNotNull(docIDVector);

// Read first batch
reader.next();
assertEquals(10, docIDVector.getValueCount());
for (int i = 0; i < 10; i++) {
assertEquals(docIDVector.toString(), i, docIDVector.get(i));
}
// Close after only one of ten batches — this is the early cancel.
reader.close();
} catch (Exception e) {
readerException.set(e);
} finally {
readerComplete.countDown();
}
}, "flight-reader-thread");

readerThread.start();
assertTrue("Reader thread did not complete in time", readerComplete.await(1, TimeUnit.SECONDS));

// Surface any failure from the reader thread on the test thread.
if (readerException.get() != null) {
throw readerException.get();
}

StreamReader<VectorSchemaRoot> reader = readerRef.get();

// Reading past close must fail with CANCELLED, not hang or return data.
try {
reader.next();
fail("Expected FlightRuntimeException");
} catch (FlightRuntimeException e) {
assertEquals("CANCELLED", e.status().code().name());
assertEquals("Stream closed before end", e.getMessage());
reader.close();
}

// Wait for close to complete
// Due to https://github.com/grpc/grpc-java/issues/5882, there is a logic in FlightStream.java
// where it exhausts the stream on the server side before it is actually cancelled.
assertTrue(
"Timeout waiting for stream cancellation on server [" + node.getName() + "]",
streamProducer.waitForClose(2, TimeUnit.SECONDS)
);
previousNode = node;
}
}

/**
 * Verifies server-side failure propagation: a producer configured to throw after
 * its first batch must surface to the remote reader as an INTERNAL
 * FlightRuntimeException, with exactly one batch successfully delivered first.
 */
public void testFlightStreamServerError() throws Exception {
    DiscoveryNode clientNode = null;
    for (DiscoveryNode serverNode : getClusterState().nodes()) {
        // First iteration only selects the node that will act as the client side.
        if (clientNode == null) {
            clientNode = serverNode;
            continue;
        }
        StreamManager serverManager = getStreamManager(serverNode.getName());
        TestStreamProducer failingProducer = getStreamProducer();
        failingProducer.setProduceError(true);
        StreamTicket ticket = serverManager.registerStream(failingProducer, null);
        StreamManager clientManager = getStreamManager(clientNode.getName());
        try (StreamReader<VectorSchemaRoot> reader = clientManager.getStreamReader(ticket)) {
            assertNotNull(reader.getRoot().getVector("docID"));
            int delivered = 0;
            try {
                while (reader.next()) {
                    IntVector docIds = (IntVector) reader.getRoot().getVector("docID");
                    assertEquals(10, docIds.getValueCount());
                    delivered++;
                }
                fail("Expected FlightRuntimeException");
            } catch (FlightRuntimeException e) {
                assertEquals("INTERNAL", e.status().code().name());
                assertEquals("There was an error servicing your request.", e.getMessage());
            }
            // The producer throws while building the second batch, so exactly one arrives.
            assertEquals(1, delivered);
        }
        clientNode = serverNode;
    }
}

/**
 * Registers a single stream once, then queries FlightInfo for its descriptor from
 * every node; each node must report 100 records (matching the producer's
 * estimatedRowCount of 100 — presumably FlightInfo records are derived from it).
 * NOTE(review): assertNotNull after getFlightClient(...).get() is redundant —
 * get() throws on an empty holder before the assertion runs.
 */
public void testFlightGetInfo() throws Exception {
StreamTicket ticket = null;
for (DiscoveryNode node : getClusterState().nodes()) {
FlightService flightService = internalCluster().getInstance(FlightService.class, node.getName());
StreamManager streamManager = flightService.getStreamManager();
// Register on the first node only; later iterations reuse the same ticket.
if (ticket == null) {
ticket = streamManager.registerStream(getStreamProducer(), null);
}
FlightClientManager flightClientManager = flightService.getFlightClientManager();
OSFlightClient flightClient = flightClientManager.getFlightClient(node.getId()).get();
assertNotNull(flightClient);
FlightDescriptor flightDescriptor = FlightDescriptor.command(ticket.toBytes());
FlightInfo flightInfo = flightClient.getInfo(flightDescriptor, CallOptions.timeout(5000L, TimeUnit.MILLISECONDS));
assertNotNull(flightInfo);
assertEquals(100, flightInfo.getRecords());
}
}

/** Resolves the StreamManager hosted by the named cluster node. */
private StreamManager getStreamManager(String nodeName) {
    return internalCluster().getInstance(FlightService.class, nodeName).getStreamManager();
}

/** Resolves the StreamManager of an arbitrary (framework-chosen) node. */
private StreamManager getStreamManagerRandomNode() {
    return internalCluster().getInstance(FlightService.class).getStreamManager();
}

/** Builds a fresh producer that emits 10 batches of 10 sequential docID values. */
private TestStreamProducer getStreamProducer() {
    return new TestStreamProducer();
}

/**
 * Test stream producer emitting 10 batches of 10 sequential integers
 * (docID values 0..99) in a single "docID" IntVector, and recording
 * lifecycle events so tests can assert cancellation and cleanup behavior.
 */
private static class TestStreamProducer implements StreamProducer<VectorSchemaRoot, BufferAllocator> {
// Set on cancel AND on close; doubles as the isCancelled() flag, so after a
// normal close() the producer also reports itself cancelled.
// NOTE(review): confirm that is intentional.
volatile boolean isClosed = false;
// Released exactly once when close() runs; waitForClose() blocks on it.
private final CountDownLatch closeLatch = new CountDownLatch(1);
// Relative job timeout returned by getJobDeadline().
TimeValue deadline = TimeValue.timeValueSeconds(5);
// When true, run() throws after flushing the first batch to simulate a server error.
private volatile boolean produceError = false;

public void setProduceError(boolean produceError) {
this.produceError = produceError;
}

TestStreamProducer() {}

// Created in createRoot(); closed in onCancel() and close().
// NOTE(review): both paths call root.close() — verify a cancel followed by
// close() cannot double-close the root's buffers.
VectorSchemaRoot root;

/** Creates the schema root with a single "docID" IntVector. */
@Override
public VectorSchemaRoot createRoot(BufferAllocator allocator) {
IntVector docIDVector = new IntVector("docID", allocator);
FieldVector[] vectors = new FieldVector[] { docIDVector };
root = new VectorSchemaRoot(Arrays.asList(vectors));
return root;
}

/**
 * Produces 100 values in batches of 10: fill rows 0..9, wait for the consumer
 * to drain the batch, then clear and reuse the vector for the next batch.
 */
@Override
public BatchedJob<VectorSchemaRoot> createJob(BufferAllocator allocator) {
return new BatchedJob<>() {
@Override
public void run(VectorSchemaRoot root, FlushSignal flushSignal) {
IntVector docIDVector = (IntVector) root.getVector("docID");
root.setRowCount(10);
for (int i = 0; i < 100; i++) {
docIDVector.setSafe(i % 10, i);
if ((i + 1) % 10 == 0) {
// Block (up to 1s) until the consumer has read this batch, then reset.
flushSignal.awaitConsumption(TimeValue.timeValueMillis(1000));
docIDVector.clear();
root.setRowCount(10);
if (produceError) {
// Simulated failure: fires only after the first batch was consumed.
throw new RuntimeException("Server error while producing batch");
}
}
}
}

@Override
public void onCancel() {
root.close();
isClosed = true;
}

@Override
public boolean isCancelled() {
return isClosed;
}
};
}

@Override
public TimeValue getJobDeadline() {
return deadline;
}

/** Fixed estimate matching the 100 values produced by run(). */
@Override
public int estimatedRowCount() {
return 100;
}

@Override
public String getAction() {
return "";
}

@Override
public void close() {
root.close();
closeLatch.countDown();
isClosed = true;
}

/** Blocks until close() has run or the timeout elapses; true if close happened. */
public boolean waitForClose(long timeout, TimeUnit unit) throws InterruptedException {
return closeLatch.await(timeout, unit);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.arrow.flight.api;
package org.opensearch.arrow.flight.api.flightinfo;

import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* compatible open source license.
*/

package org.opensearch.arrow.flight.api;
package org.opensearch.arrow.flight.api.flightinfo;

import org.opensearch.action.support.nodes.BaseNodeResponse;
import org.opensearch.cluster.node.DiscoveryNode;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* compatible open source license.
*/

package org.opensearch.arrow.flight.api;
package org.opensearch.arrow.flight.api.flightinfo;

import org.opensearch.action.ActionType;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* compatible open source license.
*/

package org.opensearch.arrow.flight.api;
package org.opensearch.arrow.flight.api.flightinfo;

import org.opensearch.action.support.nodes.BaseNodesRequest;
import org.opensearch.core.common.io.stream.StreamInput;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* compatible open source license.
*/

package org.opensearch.arrow.flight.api;
package org.opensearch.arrow.flight.api.flightinfo;

import org.opensearch.action.FailedNodeException;
import org.opensearch.action.support.nodes.BaseNodesResponse;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* compatible open source license.
*/

package org.opensearch.arrow.flight.api;
package org.opensearch.arrow.flight.api.flightinfo;

import org.opensearch.action.FailedNodeException;
import org.opensearch.action.support.ActionFilters;
Expand Down
Loading
Loading