diff --git a/libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamProducer.java b/libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamProducer.java index 6ca5b8944319b..ea2ba1b114348 100644 --- a/libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamProducer.java +++ b/libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamProducer.java @@ -9,6 +9,7 @@ package org.opensearch.arrow.spi; import org.opensearch.common.annotation.ExperimentalApi; +import org.opensearch.common.unit.TimeValue; import org.opensearch.core.tasks.TaskId; import java.io.Closeable; @@ -95,6 +96,14 @@ public interface StreamProducer extends Closeable { */ BatchedJob createJob(Allocator allocator); + /** + * Returns the deadline for the job execution. + * After this deadline, the job should be considered expired. + * + * @return TimeValue representing the job's deadline + */ + TimeValue getJobDeadline(); + /** * Provides an estimate of the total number of rows that will be produced. * @@ -150,6 +159,6 @@ interface FlushSignal { * * @param timeout Maximum milliseconds to wait */ - void awaitConsumption(int timeout); + void awaitConsumption(TimeValue timeout); } } diff --git a/plugins/arrow-flight-rpc/build.gradle b/plugins/arrow-flight-rpc/build.gradle index c949520d1d4fc..07129bcb062f5 100644 --- a/plugins/arrow-flight-rpc/build.gradle +++ b/plugins/arrow-flight-rpc/build.gradle @@ -70,6 +70,10 @@ dependencies { } } +tasks.internalClusterTest { + jvmArgs += ["--add-opens", "java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED"] +} + tasks.named('test').configure { jacoco { excludes = ['org/apache/arrow/flight/**'] diff --git a/plugins/arrow-flight-rpc/src/internalClusterTest/java/org/opensearch/arrow/flight/ArrowFlightServerIT.java b/plugins/arrow-flight-rpc/src/internalClusterTest/java/org/opensearch/arrow/flight/ArrowFlightServerIT.java index 63c816287b1ae..f3aae77bfabf0 100644 --- 
a/plugins/arrow-flight-rpc/src/internalClusterTest/java/org/opensearch/arrow/flight/ArrowFlightServerIT.java +++ b/plugins/arrow-flight-rpc/src/internalClusterTest/java/org/opensearch/arrow/flight/ArrowFlightServerIT.java @@ -9,26 +9,39 @@ package org.opensearch.arrow.flight; import org.apache.arrow.flight.CallOptions; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightRuntimeException; import org.apache.arrow.flight.OSFlightClient; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VectorSchemaRoot; import org.opensearch.arrow.flight.bootstrap.FlightClientManager; import org.opensearch.arrow.flight.bootstrap.FlightService; import org.opensearch.arrow.flight.bootstrap.FlightStreamPlugin; +import org.opensearch.arrow.spi.StreamManager; +import org.opensearch.arrow.spi.StreamProducer; +import org.opensearch.arrow.spi.StreamReader; +import org.opensearch.arrow.spi.StreamTicket; import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.FeatureFlags; import org.opensearch.plugins.Plugin; import org.opensearch.test.FeatureFlagSetter; import org.opensearch.test.OpenSearchIntegTestCase; import org.junit.BeforeClass; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; -@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.SUITE, numDataNodes = 5) +@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.SUITE, numDataNodes = 3) public class ArrowFlightServerIT extends OpenSearchIntegTestCase { - private FlightClientManager flightClientManager; - @BeforeClass public static void setupFeatureFlags() { 
FeatureFlagSetter.set(FeatureFlags.ARROW_STREAMS_SETTING.getKey()); @@ -44,16 +57,251 @@ public void setUp() throws Exception { super.setUp(); ensureGreen(); Thread.sleep(1000); - FlightService flightService = internalCluster().getInstance(FlightService.class); - flightClientManager = flightService.getFlightClientManager(); } - public void testArrowFlightEndpoint() throws Exception { + public void testArrowFlightEndpoint() { + for (DiscoveryNode node : getClusterState().nodes()) { + FlightService flightService = internalCluster().getInstance(FlightService.class, node.getName()); + FlightClientManager flightClientManager = flightService.getFlightClientManager(); + OSFlightClient flightClient = flightClientManager.getFlightClient(node.getId()).get(); + assertNotNull(flightClient); + flightClient.handshake(CallOptions.timeout(5000L, TimeUnit.MILLISECONDS)); + } + } + + public void testFlightStreamReader() throws Exception { + for (DiscoveryNode node : getClusterState().nodes()) { + StreamManager streamManagerRandomNode = getStreamManagerRandomNode(); + StreamTicket ticket = streamManagerRandomNode.registerStream(getStreamProducer(), null); + StreamManager streamManagerCurrentNode = getStreamManager(node.getName()); + // reader should be accessible from any node in the cluster due to the use of ProxyStreamProducer + try (StreamReader reader = streamManagerCurrentNode.getStreamReader(ticket)) { + int totalBatches = 0; + assertNotNull(reader.getRoot().getVector("docID")); + while (reader.next()) { + IntVector docIDVector = (IntVector) reader.getRoot().getVector("docID"); + assertEquals(10, docIDVector.getValueCount()); + for (int i = 0; i < 10; i++) { + assertEquals(docIDVector.toString(), i + (totalBatches * 10L), docIDVector.get(i)); + } + totalBatches++; + } + assertEquals(10, totalBatches); + } + } + } + + public void testEarlyCancel() throws Exception { + DiscoveryNode previousNode = null; for (DiscoveryNode node : getClusterState().nodes()) { - try (OSFlightClient 
flightClient = flightClientManager.getFlightClient(node.getId())) { - assertNotNull(flightClient); - flightClient.handshake(CallOptions.timeout(5000L, TimeUnit.MILLISECONDS)); + if (previousNode == null) { + previousNode = node; + continue; } + StreamManager streamManagerServer = getStreamManager(node.getName()); + TestStreamProducer streamProducer = getStreamProducer(); + StreamTicket ticket = streamManagerServer.registerStream(streamProducer, null); + StreamManager streamManagerClient = getStreamManager(previousNode.getName()); + + CountDownLatch readerComplete = new CountDownLatch(1); + AtomicReference readerException = new AtomicReference<>(); + AtomicReference> readerRef = new AtomicReference<>(); + + // Start reader thread + Thread readerThread = new Thread(() -> { + try { + StreamReader reader = streamManagerClient.getStreamReader(ticket); + readerRef.set(reader); + assertNotNull(reader.getRoot()); + IntVector docIDVector = (IntVector) reader.getRoot().getVector("docID"); + assertNotNull(docIDVector); + + // Read first batch + reader.next(); + assertEquals(10, docIDVector.getValueCount()); + for (int i = 0; i < 10; i++) { + assertEquals(docIDVector.toString(), i, docIDVector.get(i)); + } + reader.close(); + } catch (Exception e) { + readerException.set(e); + } finally { + readerComplete.countDown(); + } + }, "flight-reader-thread"); + + readerThread.start(); + assertTrue("Reader thread did not complete in time", readerComplete.await(1, TimeUnit.SECONDS)); + + if (readerException.get() != null) { + throw readerException.get(); + } + + StreamReader reader = readerRef.get(); + + try { + reader.next(); + fail("Expected FlightRuntimeException"); + } catch (FlightRuntimeException e) { + assertEquals("CANCELLED", e.status().code().name()); + assertEquals("Stream closed before end", e.getMessage()); + reader.close(); + } + + // Wait for close to complete + // Due to https://github.com/grpc/grpc-java/issues/5882, there is a logic in FlightStream.java + // where it 
exhausts the stream on the server side before it is actually cancelled. + assertTrue( + "Timeout waiting for stream cancellation on server [" + node.getName() + "]", + streamProducer.waitForClose(2, TimeUnit.SECONDS) + ); + previousNode = node; + } + } + + public void testFlightStreamServerError() throws Exception { + DiscoveryNode previousNode = null; + for (DiscoveryNode node : getClusterState().nodes()) { + if (previousNode == null) { + previousNode = node; + continue; + } + StreamManager streamManagerServer = getStreamManager(node.getName()); + TestStreamProducer streamProducer = getStreamProducer(); + streamProducer.setProduceError(true); + StreamTicket ticket = streamManagerServer.registerStream(streamProducer, null); + StreamManager streamManagerClient = getStreamManager(previousNode.getName()); + try (StreamReader reader = streamManagerClient.getStreamReader(ticket)) { + int totalBatches = 0; + assertNotNull(reader.getRoot().getVector("docID")); + try { + while (reader.next()) { + IntVector docIDVector = (IntVector) reader.getRoot().getVector("docID"); + assertEquals(10, docIDVector.getValueCount()); + totalBatches++; + } + fail("Expected FlightRuntimeException"); + } catch (FlightRuntimeException e) { + assertEquals("INTERNAL", e.status().code().name()); + assertEquals("There was an error servicing your request.", e.getMessage()); + } + assertEquals(1, totalBatches); + } + previousNode = node; + } + } + + public void testFlightGetInfo() throws Exception { + StreamTicket ticket = null; + for (DiscoveryNode node : getClusterState().nodes()) { + FlightService flightService = internalCluster().getInstance(FlightService.class, node.getName()); + StreamManager streamManager = flightService.getStreamManager(); + if (ticket == null) { + ticket = streamManager.registerStream(getStreamProducer(), null); + } + FlightClientManager flightClientManager = flightService.getFlightClientManager(); + OSFlightClient flightClient = 
flightClientManager.getFlightClient(node.getId()).get(); + assertNotNull(flightClient); + FlightDescriptor flightDescriptor = FlightDescriptor.command(ticket.toBytes()); + FlightInfo flightInfo = flightClient.getInfo(flightDescriptor, CallOptions.timeout(5000L, TimeUnit.MILLISECONDS)); + assertNotNull(flightInfo); + assertEquals(100, flightInfo.getRecords()); + } + } + + private StreamManager getStreamManager(String nodeName) { + FlightService flightService = internalCluster().getInstance(FlightService.class, nodeName); + return flightService.getStreamManager(); + } + + private StreamManager getStreamManagerRandomNode() { + FlightService flightService = internalCluster().getInstance(FlightService.class); + return flightService.getStreamManager(); + } + + private TestStreamProducer getStreamProducer() { + return new TestStreamProducer(); + } + + private static class TestStreamProducer implements StreamProducer { + volatile boolean isClosed = false; + private final CountDownLatch closeLatch = new CountDownLatch(1); + TimeValue deadline = TimeValue.timeValueSeconds(5); + private volatile boolean produceError = false; + + public void setProduceError(boolean produceError) { + this.produceError = produceError; + } + + TestStreamProducer() {} + + VectorSchemaRoot root; + + @Override + public VectorSchemaRoot createRoot(BufferAllocator allocator) { + IntVector docIDVector = new IntVector("docID", allocator); + FieldVector[] vectors = new FieldVector[] { docIDVector }; + root = new VectorSchemaRoot(Arrays.asList(vectors)); + return root; + } + + @Override + public BatchedJob createJob(BufferAllocator allocator) { + return new BatchedJob<>() { + @Override + public void run(VectorSchemaRoot root, FlushSignal flushSignal) { + IntVector docIDVector = (IntVector) root.getVector("docID"); + root.setRowCount(10); + for (int i = 0; i < 100; i++) { + docIDVector.setSafe(i % 10, i); + if ((i + 1) % 10 == 0) { + flushSignal.awaitConsumption(TimeValue.timeValueMillis(1000)); + 
docIDVector.clear(); + root.setRowCount(10); + if (produceError) { + throw new RuntimeException("Server error while producing batch"); + } + } + } + } + + @Override + public void onCancel() { + root.close(); + isClosed = true; + } + + @Override + public boolean isCancelled() { + return isClosed; + } + }; + } + + @Override + public TimeValue getJobDeadline() { + return deadline; + } + + @Override + public int estimatedRowCount() { + return 100; + } + + @Override + public String getAction() { + return ""; + } + + @Override + public void close() { + root.close(); + closeLatch.countDown(); + isClosed = true; + } + + public boolean waitForClose(long timeout, TimeUnit unit) throws InterruptedException { + return closeLatch.await(timeout, unit); } } } diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/FlightServerInfoAction.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/FlightServerInfoAction.java similarity index 97% rename from plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/FlightServerInfoAction.java rename to plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/FlightServerInfoAction.java index 529bee72c708d..c988090081266 100644 --- a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/FlightServerInfoAction.java +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/FlightServerInfoAction.java @@ -5,7 +5,7 @@ * this file be licensed under the Apache-2.0 license or a * compatible open source license. 
*/ -package org.opensearch.arrow.flight.api; +package org.opensearch.arrow.flight.api.flightinfo; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/NodeFlightInfo.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/NodeFlightInfo.java similarity index 98% rename from plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/NodeFlightInfo.java rename to plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/NodeFlightInfo.java index e804b0c518523..23163bfac8c2e 100644 --- a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/NodeFlightInfo.java +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/NodeFlightInfo.java @@ -6,7 +6,7 @@ * compatible open source license. */ -package org.opensearch.arrow.flight.api; +package org.opensearch.arrow.flight.api.flightinfo; import org.opensearch.action.support.nodes.BaseNodeResponse; import org.opensearch.cluster.node.DiscoveryNode; diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/NodesFlightInfoAction.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/NodesFlightInfoAction.java similarity index 93% rename from plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/NodesFlightInfoAction.java rename to plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/NodesFlightInfoAction.java index 3148c58a1509d..3c3a9965459cb 100644 --- a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/NodesFlightInfoAction.java +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/NodesFlightInfoAction.java @@ -6,7 +6,7 @@ * compatible open source license. 
*/ -package org.opensearch.arrow.flight.api; +package org.opensearch.arrow.flight.api.flightinfo; import org.opensearch.action.ActionType; diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/NodesFlightInfoRequest.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/NodesFlightInfoRequest.java similarity index 97% rename from plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/NodesFlightInfoRequest.java rename to plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/NodesFlightInfoRequest.java index 1b707f461819c..43bf38a096b57 100644 --- a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/NodesFlightInfoRequest.java +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/NodesFlightInfoRequest.java @@ -6,7 +6,7 @@ * compatible open source license. */ -package org.opensearch.arrow.flight.api; +package org.opensearch.arrow.flight.api.flightinfo; import org.opensearch.action.support.nodes.BaseNodesRequest; import org.opensearch.core.common.io.stream.StreamInput; diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/NodesFlightInfoResponse.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/NodesFlightInfoResponse.java similarity index 98% rename from plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/NodesFlightInfoResponse.java rename to plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/NodesFlightInfoResponse.java index 721cd631924bd..805aa188ce37a 100644 --- a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/NodesFlightInfoResponse.java +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/NodesFlightInfoResponse.java @@ -6,7 +6,7 @@ * compatible open source license. 
*/ -package org.opensearch.arrow.flight.api; +package org.opensearch.arrow.flight.api.flightinfo; import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.nodes.BaseNodesResponse; diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/TransportNodesFlightInfoAction.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/TransportNodesFlightInfoAction.java similarity index 98% rename from plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/TransportNodesFlightInfoAction.java rename to plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/TransportNodesFlightInfoAction.java index d4722e20d1f84..51f4cc05b8001 100644 --- a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/TransportNodesFlightInfoAction.java +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/TransportNodesFlightInfoAction.java @@ -6,7 +6,7 @@ * compatible open source license. 
*/ -package org.opensearch.arrow.flight.api; +package org.opensearch.arrow.flight.api.flightinfo; import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.ActionFilters; diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/package-info.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/package-info.java similarity index 83% rename from plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/package-info.java rename to plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/package-info.java index d89ec87f9a51e..19dde32f32e8f 100644 --- a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/package-info.java +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/package-info.java @@ -9,4 +9,4 @@ /** * Action to retrieve flight info from nodes */ -package org.opensearch.arrow.flight.api; +package org.opensearch.arrow.flight.api.flightinfo; diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/FlightClientManager.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/FlightClientManager.java index db02b9681e5c5..c0ccc9e9f55e8 100644 --- a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/FlightClientManager.java +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/FlightClientManager.java @@ -14,10 +14,10 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.Version; -import org.opensearch.arrow.flight.api.NodeFlightInfo; -import org.opensearch.arrow.flight.api.NodesFlightInfoAction; -import org.opensearch.arrow.flight.api.NodesFlightInfoRequest; -import org.opensearch.arrow.flight.api.NodesFlightInfoResponse; +import org.opensearch.arrow.flight.api.flightinfo.NodeFlightInfo; +import 
org.opensearch.arrow.flight.api.flightinfo.NodesFlightInfoAction; +import org.opensearch.arrow.flight.api.flightinfo.NodesFlightInfoRequest; +import org.opensearch.arrow.flight.api.flightinfo.NodesFlightInfoResponse; import org.opensearch.arrow.flight.bootstrap.tls.SslContextProvider; import org.opensearch.cluster.ClusterChangedEvent; import org.opensearch.cluster.ClusterStateListener; @@ -31,12 +31,15 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.client.Client; +import java.util.Collection; import java.util.Map; import java.util.Objects; -import java.util.Set; +import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; import io.netty.channel.EventLoopGroup; @@ -58,6 +61,7 @@ public class FlightClientManager implements ClusterStateListener, AutoCloseable private final ClientConfiguration clientConfig; private final Map flightClients = new ConcurrentHashMap<>(); private final Client client; + private static final long CLIENT_BUILD_TIMEOUT_MS = TimeUnit.MINUTES.toMillis(1); /** * Creates a new FlightClientManager instance. @@ -89,60 +93,154 @@ public FlightClientManager( clusterService.addListener(this); } + /** + * Returns the location of a Flight client for a given node ID. + * + * @param nodeId The ID of the node for which to retrieve the location + * @return The Location of the Flight client for the specified node, or null if no location is available yet (an asynchronous client build is requested in that case) + */ + public Location getFlightClientLocation(String nodeId) { + ClientHolder clientHolder = flightClients.get(nodeId); + if (clientHolder != null && clientHolder.location != null) { + return clientHolder.location; + } + buildClientAsync(nodeId); + return null; + } + /** * Returns a Flight client for a given node ID. 
* * @param nodeId The ID of the node for which to retrieve the Flight client * @return An OpenSearchFlightClient instance for the specified node */ - public OSFlightClient getFlightClient(String nodeId) { - ClientHolder clientHolder = flightClients.getOrDefault(nodeId, null); - return clientHolder != null ? clientHolder.flightClient : null; + public Optional getFlightClient(String nodeId) { + if (nodeId == null || nodeId.isEmpty()) { + throw new IllegalArgumentException("Node ID cannot be null or empty"); + } + + ClientHolder holder = flightClients.get(nodeId); + + if (holder == null) { + buildClientAsync(nodeId); + return Optional.empty(); + } + + if (holder.state == BuildState.COMPLETE) { + return Optional.ofNullable(holder.flightClient); + } + + if (holder.isStale()) { + logger.warn("Detected stale building state for node [{}], triggering rebuild", nodeId); + if (flightClients.remove(nodeId, holder)) { + try { + holder.close(); + } catch (Exception e) { + logger.warn("Error closing stale client holder for node [{}]. {}", nodeId, e.getMessage()); + } + buildClientAsync(nodeId); + } + } + + return Optional.empty(); } /** - * Returns the location of a Flight client for a given node ID. - * - * @param nodeId The ID of the node for which to retrieve the location - * @return The Location of the Flight client for the specified node + * Represents the state and metadata of a Flight client */ - public Location getFlightClientLocation(String nodeId) { - ClientHolder clientHolder = flightClients.getOrDefault(nodeId, null); - return clientHolder != null ? 
clientHolder.location : null; + private record ClientHolder(OSFlightClient flightClient, Location location, long buildStartTime, BuildState state) + implements + AutoCloseable { + + private static ClientHolder building() { + return new ClientHolder(null, null, System.currentTimeMillis(), BuildState.BUILDING); + } + + private static ClientHolder complete(OSFlightClient client, Location location) { + return new ClientHolder(client, location, System.currentTimeMillis(), BuildState.COMPLETE); + } + + boolean isStale() { + return state == BuildState.BUILDING && (System.currentTimeMillis() - buildStartTime) > CLIENT_BUILD_TIMEOUT_MS; + } + + /** + * Closes the client holder and logs the operation + * @param nodeId The ID of the node this holder belongs to + * @param reason The reason for closing + */ + public void close(String nodeId, String reason) { + try { + if (flightClient != null) { + flightClient.close(); + } + if (state == BuildState.BUILDING) { + logger.info("Cleaned up building state for node [{}]: {}", nodeId, reason); + } else { + logger.info("Closed client for node [{}]: {}", nodeId, reason); + } + } catch (Exception e) { + logger.error("Failed to close client for node [{}] ({}): {}", nodeId, reason, e.getMessage()); + } + } + + @Override + public void close() throws Exception { + if (flightClient != null) { + flightClient.close(); + } + } + } + + private enum BuildState { + BUILDING, + COMPLETE } /** - * Builds a client for a given nodeId in asynchronous manner - * @param nodeId nodeId of the node to build client for + * Initiates async build of a flight client for the given node */ - public void buildClientAsync(String nodeId) { + void buildClientAsync(String nodeId) { + // Try to put a building placeholder + ClientHolder placeholder = ClientHolder.building(); + if (flightClients.putIfAbsent(nodeId, placeholder) != null) { + return; // Another thread is already handling this node + } + CompletableFuture locationFuture = new CompletableFuture<>(); 
locationFuture.thenAccept(location -> { - DiscoveryNode node = getNodeFromClusterState(nodeId); - buildClientAndAddToPool(location, node); + try { + DiscoveryNode node = getNodeFromClusterState(nodeId); + if (!isValidNode(node)) { + logger.warn("Node [{}] is not valid for client creation", nodeId); + flightClients.remove(nodeId, placeholder); + return; + } + + OSFlightClient flightClient = buildClient(location); + ClientHolder newHolder = ClientHolder.complete(flightClient, location); + + if (!flightClients.replace(nodeId, placeholder, newHolder)) { + // Something changed while we were building + logger.warn("Failed to store new client for node [{}], state changed during build", nodeId); + flightClient.close(); + } + } catch (Exception e) { + logger.error("Failed to build Flight client for node [{}]. {}", nodeId, e); + flightClients.remove(nodeId, placeholder); + throw new RuntimeException(e); + } }).exceptionally(throwable -> { - logger.error("Failed to get Flight server location for node: [{}] {}", nodeId, throwable); - throw new RuntimeException(throwable); + flightClients.remove(nodeId, placeholder); + logger.error("Failed to get Flight server location for node [{}] {}", nodeId, throwable); + throw new CompletionException(throwable); }); - requestNodeLocation(nodeId, locationFuture); - } - Map getClients() { - return flightClients; + requestNodeLocation(nodeId, locationFuture); } - private void buildClientAndAddToPool(Location location, DiscoveryNode node) { - if (!isValidNode(node)) { - logger.warn( - "Unable to build FlightClient for node [{}] with role [{}] on version [{}]", - node.getId(), - node.getRoles(), - node.getVersion() - ); - return; - } - OSFlightClient flightClient = buildClient(location); - flightClients.put(node.getId(), new ClientHolder(location, flightClient)); + Collection getClients() { + return flightClients.values(); } private void requestNodeLocation(String nodeId, CompletableFuture future) { @@ -194,15 +292,12 @@ private DiscoveryNode 
getNodeFromClusterState(String nodeId) { @Override public void close() throws Exception { for (ClientHolder clientHolder : flightClients.values()) { - clientHolder.flightClient.close(); + clientHolder.close(); } flightClients.clear(); grpcExecutor.shutdown(); } - private record ClientHolder(Location location, OSFlightClient flightClient) { - } - /** * Returns the ID of the local node in the cluster. * @@ -219,23 +314,79 @@ public String getLocalNodeId() { */ @Override public void clusterChanged(ClusterChangedEvent event) { - if (event.nodesChanged()) { - DiscoveryNodes nodes = event.state().nodes(); - flightClients.keySet().removeIf(nodeId -> !nodes.nodeExists(nodeId)); - for (DiscoveryNode node : nodes) { - if (!flightClients.containsKey(node.getId()) && isValidNode(node)) { - buildClientAsync(node.getId()); + if (!event.nodesChanged()) { + return; + } + + final DiscoveryNodes nodes = event.state().nodes(); + + cleanupStaleBuilding(); + removeStaleClients(nodes); + updateExistingClients(nodes); + } + + private void removeStaleClients(DiscoveryNodes nodes) { + flightClients.entrySet().removeIf(entry -> { + String nodeId = entry.getKey(); + ClientHolder holder = entry.getValue(); + + if (!nodes.nodeExists(nodeId)) { + holder.close(nodeId, "node no longer exists"); + return true; + } + + if (holder.state == BuildState.BUILDING && holder.isStale()) { + holder.close(nodeId, "client build state is stale"); + return true; + } + + return false; + }); + } + + /** + * Updates clients for existing nodes based on their validity + */ + private void updateExistingClients(DiscoveryNodes nodes) { + for (DiscoveryNode node : nodes) { + String nodeId = node.getId(); + + if (isValidNode(node)) { + ClientHolder existingHolder = flightClients.get(nodeId); + + if (existingHolder == null) { + buildClientAsync(nodeId); + } else if (existingHolder.state == BuildState.BUILDING && existingHolder.isStale()) { + if (flightClients.remove(nodeId, existingHolder)) { + 
existingHolder.close(nodeId, "rebuilding stale client"); + buildClientAsync(nodeId); + } + } + } else { + ClientHolder holder = flightClients.remove(nodeId); + if (holder != null) { + holder.close(nodeId, "node is no longer valid"); } } } } - private static boolean isValidNode(DiscoveryNode node) { - return node != null && !node.getVersion().before(MIN_SUPPORTED_VERSION) && FeatureFlags.isEnabled(ARROW_STREAMS_SETTING); + /** + * Cleans up any clients that are in a stale BUILDING state + */ + private void cleanupStaleBuilding() { + flightClients.entrySet().removeIf(entry -> { + ClientHolder holder = entry.getValue(); + if (holder.state == BuildState.BUILDING && holder.isStale()) { + holder.close(entry.getKey(), "cleaning up stale building state"); + return true; + } + return false; + }); } - private Set getCurrentClusterNodes() { - return Objects.requireNonNull(clientConfig.clusterService).state().nodes().getNodes().keySet(); + private static boolean isValidNode(DiscoveryNode node) { + return node != null && !node.getVersion().before(MIN_SUPPORTED_VERSION) && FeatureFlags.isEnabled(ARROW_STREAMS_SETTING); } @VisibleForTesting @@ -245,18 +396,5 @@ Map getFlightClients() { private record ClientConfiguration(BufferAllocator allocator, ClusterService clusterService, SslContextProvider sslContextProvider, EventLoopGroup workerELG, ExecutorService grpcExecutor) { - private ClientConfiguration( - BufferAllocator allocator, - ClusterService clusterService, - @Nullable SslContextProvider sslContextProvider, - EventLoopGroup workerELG, - ExecutorService grpcExecutor - ) { - this.allocator = allocator; - this.clusterService = clusterService; - this.sslContextProvider = sslContextProvider; - this.workerELG = workerELG; - this.grpcExecutor = grpcExecutor; - } } } diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/FlightService.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/FlightService.java index 
7735fc3df73e0..1d75135eacc1c 100644 --- a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/FlightService.java +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/FlightService.java @@ -8,7 +8,6 @@ package org.opensearch.arrow.flight.bootstrap; -import org.apache.arrow.flight.NoOpFlightProducer; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.util.AutoCloseables; @@ -17,6 +16,8 @@ import org.apache.logging.log4j.Logger; import org.opensearch.arrow.flight.bootstrap.tls.DefaultSslContextProvider; import org.opensearch.arrow.flight.bootstrap.tls.SslContextProvider; +import org.opensearch.arrow.flight.impl.BaseFlightProducer; +import org.opensearch.arrow.flight.impl.FlightStreamManager; import org.opensearch.arrow.spi.StreamManager; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.network.NetworkService; @@ -104,7 +105,7 @@ protected void doStart() { client ); initializeStreamManager(clientManager); - serverComponents.setFlightProducer(new NoOpFlightProducer()); + serverComponents.setFlightProducer(new BaseFlightProducer(clientManager, (FlightStreamManager) streamManager, allocator)); serverComponents.start(); } catch (Exception e) { @@ -149,7 +150,10 @@ SslContextProvider getSslContextProvider() { @Override protected void doStop() { try { - AutoCloseables.close(serverComponents, streamManager, clientManager, allocator); + AutoCloseables.close(serverComponents); + AutoCloseables.close(streamManager); + AutoCloseables.close(clientManager); + AutoCloseables.close(allocator); } catch (Exception e) { throw new RuntimeException(e); } @@ -165,6 +169,7 @@ protected void doClose() { } private void initializeStreamManager(FlightClientManager clientManager) { - streamManager = null; + streamManager = new FlightStreamManager(() -> allocator); + ((FlightStreamManager) streamManager).setClientManager(clientManager); } } 
diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/FlightStreamPlugin.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/FlightStreamPlugin.java index 582b4b15162c5..63136a32f6d37 100644 --- a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/FlightStreamPlugin.java +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/FlightStreamPlugin.java @@ -8,9 +8,9 @@ package org.opensearch.arrow.flight.bootstrap; -import org.opensearch.arrow.flight.api.FlightServerInfoAction; -import org.opensearch.arrow.flight.api.NodesFlightInfoAction; -import org.opensearch.arrow.flight.api.TransportNodesFlightInfoAction; +import org.opensearch.arrow.flight.api.flightinfo.FlightServerInfoAction; +import org.opensearch.arrow.flight.api.flightinfo.NodesFlightInfoAction; +import org.opensearch.arrow.flight.api.flightinfo.TransportNodesFlightInfoAction; import org.opensearch.arrow.spi.StreamManager; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.node.DiscoveryNode; diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/ServerComponents.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/ServerComponents.java index 179cd1050d931..6beb61d811b8d 100644 --- a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/ServerComponents.java +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/ServerComponents.java @@ -268,7 +268,7 @@ private boolean startFlightServer(TransportAddress transportAddress) { return true; } catch (Exception e) { String errorMsg = "Failed to start Arrow Flight server at " + serverLocation; - logger.debug(errorMsg, e); + logger.error(errorMsg, e); return false; } } diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/BaseBackpressureStrategy.java 
b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/BaseBackpressureStrategy.java new file mode 100644 index 0000000000000..d06d2eda5f23f --- /dev/null +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/BaseBackpressureStrategy.java @@ -0,0 +1,44 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.arrow.flight.impl; + +import org.apache.arrow.flight.BackpressureStrategy; + +/** + * Base class for backpressure strategy. + */ +public class BaseBackpressureStrategy extends BackpressureStrategy.CallbackBackpressureStrategy { + private final Runnable readyCallback; + private final Runnable cancelCallback; + + /** + * Constructor for BaseBackpressureStrategy. + * + * @param readyCallback Callback to execute when the listener is ready. + * @param cancelCallback Callback to execute when the listener is cancelled. + */ + BaseBackpressureStrategy(Runnable readyCallback, Runnable cancelCallback) { + this.readyCallback = readyCallback; + this.cancelCallback = cancelCallback; + } + + /** Callback to execute when the listener is ready. */ + protected void readyCallback() { + if (readyCallback != null) { + readyCallback.run(); + } + } + + /** Callback to execute when the listener is cancelled. 
*/ + protected void cancelCallback() { + if (cancelCallback != null) { + cancelCallback.run(); + } + } +} diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/BaseFlightProducer.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/BaseFlightProducer.java new file mode 100644 index 0000000000000..9752c5b09d3cc --- /dev/null +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/BaseFlightProducer.java @@ -0,0 +1,175 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.arrow.flight.impl; + +import org.apache.arrow.flight.BackpressureStrategy; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.NoOpFlightProducer; +import org.apache.arrow.flight.OSFlightClient; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.arrow.flight.bootstrap.FlightClientManager; +import org.opensearch.arrow.spi.StreamProducer; +import org.opensearch.arrow.spi.StreamTicket; + +import java.util.Collections; +import java.util.Optional; + +/** + * BaseFlightProducer extends NoOpFlightProducer to provide stream management functionality + * for Arrow Flight in OpenSearch. This class handles the retrieval and streaming of data + * based on provided tickets, managing backpressure, and coordinating between the stream + * provider and the server stream listener. 
+ */ +public class BaseFlightProducer extends NoOpFlightProducer { + private final FlightClientManager flightClientManager; + private final FlightStreamManager streamManager; + private final BufferAllocator allocator; + private static final Logger logger = LogManager.getLogger(BaseFlightProducer.class); + + /** + * Constructs a new BaseFlightProducer. + * + * @param flightClientManager The FlightClientManager to handle client connections. + * @param streamManager The StreamManager to handle stream operations, including + * retrieving and removing streams based on tickets. + * @param allocator The BufferAllocator for memory management in Arrow operations. + */ + public BaseFlightProducer(FlightClientManager flightClientManager, FlightStreamManager streamManager, BufferAllocator allocator) { + this.flightClientManager = flightClientManager; + this.streamManager = streamManager; + this.allocator = allocator; + } + + /** + * Handles the retrieval and streaming of data based on the provided ticket. + * This method orchestrates the entire process of setting up the stream, + * managing backpressure, and handling data flow to the client. 
+ * + * @param context The call context (unused in this implementation) + * @param ticket The ticket containing stream information + * @param listener The server stream listener to handle the data flow + */ + @Override + public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) { + StreamTicket streamTicket = streamManager.getStreamTicketFactory().fromBytes(ticket.getBytes()); + Optional streamProducerHolder = Optional.empty(); + try { + if (streamTicket.getNodeId().equals(flightClientManager.getLocalNodeId())) { + streamProducerHolder = streamManager.removeStreamProducer(streamTicket); + } else { + Optional remoteClient = flightClientManager.getFlightClient(streamTicket.getNodeId()); + if (remoteClient.isEmpty()) { + listener.error( + CallStatus.UNAVAILABLE.withDescription("Either server is not up yet or node does not support Streams.").cause() + ); + return; + } + StreamProducer proxyProvider = new ProxyStreamProducer( + new FlightStreamReader(remoteClient.get().getStream(ticket)) + ); + streamProducerHolder = Optional.of(FlightStreamManager.StreamProducerHolder.create(proxyProvider, allocator)); + } + if (streamProducerHolder.isEmpty()) { + listener.error(CallStatus.NOT_FOUND.withDescription("Stream not found").toRuntimeException()); + return; + } + try (StreamProducer producer = streamProducerHolder.get().producer()) { + StreamProducer.BatchedJob batchedJob = producer.createJob(allocator); + if (context.isCancelled()) { + batchedJob.onCancel(); + listener.error(CallStatus.CANCELLED.cause()); + return; + } + listener.setOnCancelHandler(batchedJob::onCancel); + BackpressureStrategy backpressureStrategy = new BaseBackpressureStrategy(null, batchedJob::onCancel); + backpressureStrategy.register(listener); + StreamProducer.FlushSignal flushSignal = (timeout) -> { + BackpressureStrategy.WaitResult result = backpressureStrategy.waitForListener(timeout.millis()); + if (result.equals(BackpressureStrategy.WaitResult.READY)) { + 
listener.putNext(); + } else if (result.equals(BackpressureStrategy.WaitResult.TIMEOUT)) { + listener.error(CallStatus.TIMED_OUT.cause()); + throw new RuntimeException("Stream deadline exceeded for consumption"); + } else if (result.equals(BackpressureStrategy.WaitResult.CANCELLED)) { + batchedJob.onCancel(); + listener.error(CallStatus.CANCELLED.cause()); + throw new RuntimeException("Stream cancelled by client"); + } else if (result.equals(BackpressureStrategy.WaitResult.OTHER)) { + batchedJob.onCancel(); + listener.error(CallStatus.INTERNAL.toRuntimeException()); + throw new RuntimeException("Error while waiting for client: " + result); + } else { + batchedJob.onCancel(); + listener.error(CallStatus.INTERNAL.toRuntimeException()); + throw new RuntimeException("Error while waiting for client: " + result); + } + }; + try (VectorSchemaRoot root = streamProducerHolder.get().getRoot()) { + listener.start(root); + batchedJob.run(root, flushSignal); + } + listener.completed(); + } + } catch (Exception e) { + listener.error(CallStatus.INTERNAL.withDescription(e.getMessage()).withCause(e).cause()); + logger.error(e); + throw new RuntimeException(e); + } + } + + /** + * Retrieves FlightInfo for the given FlightDescriptor, handling both local and remote cases. 
+ * + * @param context The call context + * @param descriptor The FlightDescriptor containing stream information + * @return FlightInfo for the requested stream + */ + @Override + public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) { + // TODO: this api should only be used internally + StreamTicket streamTicket = streamManager.getStreamTicketFactory().fromBytes(descriptor.getCommand()); + if (streamTicket.getNodeId().equals(flightClientManager.getLocalNodeId())) { + Optional streamProducerHolder = streamManager.getStreamProducer(streamTicket); + if (streamProducerHolder.isEmpty()) { + throw CallStatus.NOT_FOUND.withDescription("FlightInfo not found").toRuntimeException(); + } + Location location = flightClientManager.getFlightClientLocation(streamTicket.getNodeId()); + if (location == null) { + throw CallStatus.UNAVAILABLE.withDescription("Internal error while determining location information from ticket.") + .toRuntimeException(); + } + FlightEndpoint endpoint = new FlightEndpoint(new Ticket(descriptor.getCommand()), location); + FlightInfo.Builder infoBuilder; + try { + infoBuilder = FlightInfo.builder( + streamProducerHolder.get().getRoot().getSchema(), + descriptor, + Collections.singletonList(endpoint) + ).setRecords(streamProducerHolder.get().producer().estimatedRowCount()); + } catch (Exception e) { + throw CallStatus.INTERNAL.withDescription("Internal error while creating VectorSchemaRoot.").toRuntimeException(); + } + return infoBuilder.build(); + } else { + Optional remoteClient = flightClientManager.getFlightClient(streamTicket.getNodeId()); + if (remoteClient.isEmpty()) { + throw CallStatus.UNAVAILABLE.withDescription("Client doesn't support Stream").toRuntimeException(); + } + return remoteClient.get().getInfo(descriptor); + } + } +} diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/FlightStreamManager.java 
b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/FlightStreamManager.java new file mode 100644 index 0000000000000..bbcdaf049e035 --- /dev/null +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/FlightStreamManager.java @@ -0,0 +1,201 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.arrow.flight.impl; + +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.OSFlightClient; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.arrow.flight.bootstrap.FlightClientManager; +import org.opensearch.arrow.spi.StreamManager; +import org.opensearch.arrow.spi.StreamProducer; +import org.opensearch.arrow.spi.StreamReader; +import org.opensearch.arrow.spi.StreamTicket; +import org.opensearch.arrow.spi.StreamTicketFactory; +import org.opensearch.common.SetOnce; +import org.opensearch.common.cache.Cache; +import org.opensearch.common.cache.CacheBuilder; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.tasks.TaskId; + +import java.io.IOException; +import java.util.Objects; +import java.util.Optional; +import java.util.function.Supplier; + +/** + * FlightStreamManager is a concrete implementation of StreamManager that provides + * an abstraction layer for managing Arrow Flight streams in OpenSearch. + * It encapsulates the details of Flight client operations, allowing consumers to + * work with streams without direct exposure to Flight internals. 
+ */ +public class FlightStreamManager implements StreamManager { + private static final Logger logger = LogManager.getLogger(FlightStreamManager.class); + + private FlightStreamTicketFactory ticketFactory; + private FlightClientManager clientManager; + private final Supplier allocatorSupplier; + private final Cache streamProducers; + // TODO read from setting + private static final TimeValue DEFAULT_CACHE_EXPIRE = TimeValue.timeValueMinutes(10); + private static final int MAX_WEIGHT = 1000; + + /** + * Holds a StreamProducer along with its metadata and resources + */ + public record StreamProducerHolder(StreamProducer producer, BufferAllocator allocator, + long creationTime, SetOnce root) { + public StreamProducerHolder { + Objects.requireNonNull(producer, "StreamProducer cannot be null"); + Objects.requireNonNull(allocator, "BufferAllocator cannot be null"); + } + + static StreamProducerHolder create(StreamProducer producer, BufferAllocator allocator) { + return new StreamProducerHolder(producer, allocator, System.currentTimeMillis(), new SetOnce<>()); + } + + boolean isExpired() { + return System.currentTimeMillis() - creationTime > producer.getJobDeadline().getMillis(); + } + + /** + * Gets the VectorSchemaRoot associated with the StreamProducer. + * If the root is not set, it creates a new one using the provided BufferAllocator. + * Later calls return the stored root without asking the producer for another one. + */ + public VectorSchemaRoot getRoot() { + // Only create a root while none is stored; an unconditional createRoot() would + // build a fresh root on every call and drop it unclosed once trySet() refuses it. + if (root.get() == null) { + root.trySet(producer.createRoot(allocator)); + } + return root.get(); + } + } + + /** + * Constructs a new FlightStreamManager. + * @param allocatorSupplier The supplier for BufferAllocator instances used for memory management. + * This parameter is required to be non-null. + + */ + public FlightStreamManager(Supplier allocatorSupplier) { + this.allocatorSupplier = allocatorSupplier; + this.streamProducers = CacheBuilder.builder() + .setExpireAfterWrite(DEFAULT_CACHE_EXPIRE) + .setMaximumWeight(MAX_WEIGHT) + .build(); + } + + /** + * Sets the FlightClientManager for this FlightStreamManager.
+ * @param clientManager The FlightClientManager instance to use for Flight client operations. + * This parameter is required to be non-null. + */ + public void setClientManager(FlightClientManager clientManager) { + this.clientManager = clientManager; + this.ticketFactory = new FlightStreamTicketFactory(clientManager::getLocalNodeId); + } + + /** + * Registers a new stream producer with the StreamManager. + * @param provider The StreamProducer instance to register. + * @param parentTaskId The parent task ID associated with the stream. + * @return A StreamTicket representing the registered stream. + */ + @Override + @SuppressWarnings("unchecked") + public StreamTicket registerStream(StreamProducer provider, TaskId parentTaskId) { + Objects.requireNonNull(provider, "StreamProducer cannot be null"); + StreamTicket ticket = ticketFactory.newTicket(); + streamProducers.put( + ticket.getTicketId(), + StreamProducerHolder.create((StreamProducer) provider, allocatorSupplier.get()) + ); + return ticket; + } + + /** + * Retrieves a StreamReader for the given StreamTicket. + * @param ticket The StreamTicket representing the stream to retrieve. + * @return A StreamReader instance for the specified stream. + */ + @Override + @SuppressWarnings("unchecked") + public StreamReader getStreamReader(StreamTicket ticket) { + Optional flightClient = clientManager.getFlightClient(ticket.getNodeId()); + if (flightClient.isEmpty()) { + throw new RuntimeException("Flight client not found for node [" + ticket.getNodeId() + "]."); + } + FlightStream stream = flightClient.get().getStream(new Ticket(ticket.toBytes())); + return (StreamReader) new FlightStreamReader(stream); + } + + /** + * Retrieves the StreamTicketFactory used by this StreamManager. + * @return The StreamTicketFactory instance associated with this StreamManager. 
+ */ + @Override + public StreamTicketFactory getStreamTicketFactory() { + return ticketFactory; + } + + /** + * Gets the StreamProducer associated with a ticket if it hasn't expired based on its deadline. + * + * @param ticket The StreamTicket identifying the stream + * @return Optional of StreamProducerHolder containing the producer if found and not expired + */ + public Optional getStreamProducer(StreamTicket ticket) { + Objects.requireNonNull(ticket, "StreamTicket cannot be null"); + StreamProducerHolder holder = streamProducers.get(ticket.getTicketId()); + if (holder != null) { + if (holder.isExpired()) { + removeStreamProducer(ticket); + return Optional.empty(); + } + return Optional.of(holder); + } + return Optional.empty(); + } + + /** + * Gets and removes the StreamProducer associated with a ticket. + * Ensure that close is called on the StreamProducer after use. + * @param ticket The StreamTicket identifying the stream + * @return Optional of StreamProducerHolder containing the producer if found + */ + public Optional removeStreamProducer(StreamTicket ticket) { + Objects.requireNonNull(ticket, "StreamTicket cannot be null"); + + String ticketId = ticket.getTicketId(); + StreamProducerHolder holder = streamProducers.get(ticketId); + + if (holder != null) { + streamProducers.invalidate(ticketId); + return Optional.of(holder); + } + return Optional.empty(); + } + + /** + * Closes the StreamManager and cancels all associated streams. + * This method should be called when the StreamManager is no longer needed to clean up resources. + * It is recommended to implement this method to cancel all threads and clear the streamManager queue. 
+ */ + @Override + public void close() throws Exception { + streamProducers.values().forEach(holder -> { + try { + holder.producer().close(); + } catch (IOException e) { + logger.error("Error closing stream producer, this may cause memory leaks.", e); + } + }); + streamProducers.invalidateAll(); + } +} diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/FlightStreamReader.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/FlightStreamReader.java new file mode 100644 index 0000000000000..d9e366dca30e2 --- /dev/null +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/FlightStreamReader.java @@ -0,0 +1,61 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.arrow.flight.impl; + +import org.apache.arrow.flight.FlightRuntimeException; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.opensearch.ExceptionsHelper; +import org.opensearch.arrow.spi.StreamReader; + +/** + * FlightStreamReader is a wrapper class that adapts the FlightStream interface + * to the StreamReader interface. + */ +public class FlightStreamReader implements StreamReader { + + private final FlightStream flightStream; + + /** + * Constructs a FlightStreamReader with the given FlightStream. + * + * @param flightStream The FlightStream to be adapted. + */ + public FlightStreamReader(FlightStream flightStream) { + this.flightStream = flightStream; + } + + /** + * Moves the flightStream to the next batch of data. + * @return true if there is a next batch of data, false otherwise. 
+ * @throws FlightRuntimeException if an error occurs while advancing to the next batch like early termination of stream + */ + @Override + public boolean next() throws FlightRuntimeException { + return flightStream.next(); + } + + /** + * Returns the VectorSchemaRoot containing the current batch of data. + * @return The VectorSchemaRoot containing the current batch of data. + * @throws FlightRuntimeException if an error occurs while retrieving the root like early termination of stream + */ + @Override + public VectorSchemaRoot getRoot() throws FlightRuntimeException { + return flightStream.getRoot(); + } + + /** + * Closes the flightStream. + */ + @Override + public void close() { + ExceptionsHelper.catchAsRuntimeException(flightStream::close); + } +} diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/FlightStreamTicket.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/FlightStreamTicket.java new file mode 100644 index 0000000000000..baa9e79fec6a1 --- /dev/null +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/FlightStreamTicket.java @@ -0,0 +1,111 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.arrow.flight.impl; + +import org.opensearch.arrow.spi.StreamTicket; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.Objects; + +class FlightStreamTicket implements StreamTicket { + private static final int MAX_TOTAL_SIZE = 4096; + private static final int MAX_ID_LENGTH = 256; + + private final String ticketID; + private final String nodeID; + + public FlightStreamTicket(String ticketID, String nodeID) { + this.ticketID = ticketID; + this.nodeID = nodeID; + } + + @Override + public String getTicketId() { + return ticketID; + } + + @Override + public String getNodeId() { + return nodeID; + } + + @Override + public byte[] toBytes() { + byte[] ticketIDBytes = ticketID.getBytes(StandardCharsets.UTF_8); + byte[] nodeIDBytes = nodeID.getBytes(StandardCharsets.UTF_8); + + if (ticketIDBytes.length > Short.MAX_VALUE || nodeIDBytes.length > Short.MAX_VALUE) { + throw new IllegalArgumentException("Field lengths exceed the maximum allowed size."); + } + ByteBuffer buffer = ByteBuffer.allocate(2 + ticketIDBytes.length + 2 + nodeIDBytes.length); + buffer.putShort((short) ticketIDBytes.length); + buffer.putShort((short) nodeIDBytes.length); + buffer.put(ticketIDBytes); + buffer.put(nodeIDBytes); + return Base64.getEncoder().encode(buffer.array()); + } + + static StreamTicket fromBytes(byte[] bytes) { + if (bytes == null || bytes.length < 4) { + throw new IllegalArgumentException("Invalid byte array input."); + } + + if (bytes.length > MAX_TOTAL_SIZE) { + throw new IllegalArgumentException("Input exceeds maximum allowed size"); + } + + ByteBuffer buffer = ByteBuffer.wrap(Base64.getDecoder().decode(bytes)); + + // The 4-byte minimum above applies to the Base64 text; the decoded payload can be shorter, + // so re-check before reading the two 2-byte length headers to avoid BufferUnderflowException. + if (buffer.remaining() < 4) { + throw new IllegalArgumentException("Malformed byte array. Not enough data for length headers."); + } + + short ticketIDLength = buffer.getShort(); + if (ticketIDLength < 0 || ticketIDLength > MAX_ID_LENGTH) { + throw new IllegalArgumentException("Invalid ticketID length: " + ticketIDLength); + } + + short nodeIDLength = buffer.getShort(); + if (nodeIDLength < 0 || nodeIDLength >
MAX_ID_LENGTH) { + throw new IllegalArgumentException("Invalid nodeID length: " + nodeIDLength); + } + + byte[] ticketIDBytes = new byte[ticketIDLength]; + if (buffer.remaining() < ticketIDLength) { + throw new IllegalArgumentException("Malformed byte array. Not enough data for TicketId."); + } + buffer.get(ticketIDBytes); + + byte[] nodeIDBytes = new byte[nodeIDLength]; + if (buffer.remaining() < nodeIDLength) { + throw new IllegalArgumentException("Malformed byte array. Not enough data for NodeId."); + } + buffer.get(nodeIDBytes); + + String ticketID = new String(ticketIDBytes, StandardCharsets.UTF_8); + String nodeID = new String(nodeIDBytes, StandardCharsets.UTF_8); + return new FlightStreamTicket(ticketID, nodeID); + } + + @Override + public int hashCode() { + return Objects.hash(ticketID, nodeID); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null || getClass() != obj.getClass()) return false; + FlightStreamTicket that = (FlightStreamTicket) obj; + return Objects.equals(ticketID, that.ticketID) && Objects.equals(nodeID, that.nodeID); + } + + @Override + public String toString() { + return "FlightStreamTicket{ticketID='" + ticketID + "', nodeID='" + nodeID + "'}"; + } +} diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/FlightStreamTicketFactory.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/FlightStreamTicketFactory.java new file mode 100644 index 0000000000000..473eb92cf2db3 --- /dev/null +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/FlightStreamTicketFactory.java @@ -0,0 +1,60 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license.
+ */ + +package org.opensearch.arrow.flight.impl; + +import org.opensearch.arrow.spi.StreamTicket; +import org.opensearch.arrow.spi.StreamTicketFactory; +import org.opensearch.common.annotation.ExperimentalApi; + +import java.util.UUID; +import java.util.function.Supplier; + +/** + * StreamTicketFactory implementation that issues FlightStreamTicket instances, each carrying a random UUID ticket ID and the node ID obtained from the supplier. + */ +@ExperimentalApi +public class FlightStreamTicketFactory implements StreamTicketFactory { + + private final Supplier nodeId; + + /** + * Constructs a new FlightStreamTicketFactory instance. + * + * @param nodeId A Supplier that provides the node ID for the StreamTicket; it is invoked on every newTicket() call + */ + public FlightStreamTicketFactory(Supplier nodeId) { + this.nodeId = nodeId; + } + + /** + * Creates a new StreamTicket with a unique (random UUID) ticket ID and the node ID currently returned by the supplier. + * + * @return A new StreamTicket instance + */ + @Override + public StreamTicket newTicket() { + return new FlightStreamTicket(generateUniqueTicket(), nodeId.get()); + } + + /** + * Deserializes a StreamTicket from its byte representation by delegating to FlightStreamTicket.fromBytes. + * + * @param bytes The byte array containing the serialized ticket data + * @return A StreamTicket instance reconstructed from the byte array + * @throws IllegalArgumentException if bytes is null or invalid + */ + @Override + public StreamTicket fromBytes(byte[] bytes) { + return FlightStreamTicket.fromBytes(bytes); + } + + private String generateUniqueTicket() { + return UUID.randomUUID().toString(); + } +} diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/ProxyStreamProducer.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/ProxyStreamProducer.java new file mode 100644 index 0000000000000..a97f4697571e0 --- /dev/null +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/ProxyStreamProducer.java @@ -0,0 +1,123 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible
open source license. + */ + +package org.opensearch.arrow.flight.impl; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.opensearch.ExceptionsHelper; +import org.opensearch.arrow.spi.StreamProducer; +import org.opensearch.arrow.spi.StreamReader; +import org.opensearch.arrow.spi.StreamTicket; +import org.opensearch.common.unit.TimeValue; + +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * ProxyStreamProducer acts as a forward proxy for a remote FlightStream. + * It creates a BatchedJob to handle the streaming of data from the remote FlightStream. + * This is useful when the stream is not present locally and needs to be fetched from the node + * identified by {@link StreamTicket#getNodeId()}, where it is present. + */ +public class ProxyStreamProducer implements StreamProducer { + + private final StreamReader remoteStream; + + /** + * Constructs a new ProxyStreamProducer instance. + * + * @param remoteStream The StreamReader over the remote FlightStream to be proxied. + */ + public ProxyStreamProducer(StreamReader remoteStream) { + this.remoteStream = remoteStream; + } + + /** + * Returns the VectorSchemaRoot of the remote stream; no new root is created here. + * @param allocator Unused; the root is taken from the remote stream reader + * @return The VectorSchemaRoot returned by the remote stream reader + */ + @Override + public VectorSchemaRoot createRoot(BufferAllocator allocator) { + return remoteStream.getRoot(); + } + + /** + * Creates a BatchedJob (a ProxyBatchedJob) bound to the remote stream reader. + * @param allocator Unused; the job is constructed from the remote stream reader only + */ + @Override + public BatchedJob createJob(BufferAllocator allocator) { + return new ProxyBatchedJob(remoteStream); + } + + /** + * Returns the deadline for the remote FlightStream. + * Since the stream is not present locally, the deadline is set to -1.
It piggybacks on remote stream expiration + * @return The deadline for the remote FlightStream + */ + @Override + public TimeValue getJobDeadline() { + return TimeValue.MINUS_ONE; + } + + /** + * Provides an estimate of the total number of rows that will be produced. + */ + @Override + public int estimatedRowCount() { + // TODO get it from remote flight stream + return -1; + } + + /** + * Task action name + */ + @Override + public String getAction() { + // TODO get it from remote flight stream + return ""; + } + + /** + * Closes the remote FlightStream. + */ + @Override + public void close() { + ExceptionsHelper.catchAsRuntimeException(remoteStream::close); + } + + static class ProxyBatchedJob implements BatchedJob { + + private final StreamReader remoteStream; + private final AtomicBoolean isCancelled = new AtomicBoolean(false); + + ProxyBatchedJob(StreamReader remoteStream) { + this.remoteStream = remoteStream; + } + + @Override + public void run(VectorSchemaRoot root, FlushSignal flushSignal) { + while (!isCancelled.get() && remoteStream.next()) { + flushSignal.awaitConsumption(TimeValue.timeValueMillis(1000)); + } + } + + @Override + public void onCancel() { + isCancelled.set(true); + } + + @Override + public boolean isCancelled() { + // Proxy stream don't have any business logic to set this flag, + // they piggyback on remote stream getting cancelled. 
+ return isCancelled.get(); + } + } +} diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/package-info.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/package-info.java new file mode 100644 index 0000000000000..90ca54b44a55d --- /dev/null +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/package-info.java @@ -0,0 +1,12 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** + * Core components and implementations for OpenSearch Flight service, including base producers and consumers. + */ +package org.opensearch.arrow.flight.impl; diff --git a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/FlightStreamPluginTests.java b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/FlightStreamPluginTests.java index 6f93d792f9db4..a3f0d1ca99b25 100644 --- a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/FlightStreamPluginTests.java +++ b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/FlightStreamPluginTests.java @@ -8,8 +8,8 @@ package org.opensearch.arrow.flight; -import org.opensearch.arrow.flight.api.FlightServerInfoAction; -import org.opensearch.arrow.flight.api.NodesFlightInfoAction; +import org.opensearch.arrow.flight.api.flightinfo.FlightServerInfoAction; +import org.opensearch.arrow.flight.api.flightinfo.NodesFlightInfoAction; import org.opensearch.arrow.flight.bootstrap.FlightService; import org.opensearch.arrow.flight.bootstrap.FlightStreamPlugin; import org.opensearch.arrow.spi.StreamManager; diff --git a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/FlightServerInfoActionTests.java b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/flightinfo/FlightServerInfoActionTests.java similarity index 
98% rename from plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/FlightServerInfoActionTests.java rename to plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/flightinfo/FlightServerInfoActionTests.java index 6cb75d4a93dbe..d3115fc745475 100644 --- a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/FlightServerInfoActionTests.java +++ b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/flightinfo/FlightServerInfoActionTests.java @@ -6,7 +6,7 @@ * compatible open source license. */ -package org.opensearch.arrow.flight.api; +package org.opensearch.arrow.flight.api.flightinfo; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; diff --git a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/NodeFlightInfoTests.java b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/flightinfo/NodeFlightInfoTests.java similarity index 99% rename from plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/NodeFlightInfoTests.java rename to plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/flightinfo/NodeFlightInfoTests.java index 2f8d7deb06f3f..59e695313c16e 100644 --- a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/NodeFlightInfoTests.java +++ b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/flightinfo/NodeFlightInfoTests.java @@ -6,7 +6,7 @@ * compatible open source license. 
*/ -package org.opensearch.arrow.flight.api; +package org.opensearch.arrow.flight.api.flightinfo; import org.opensearch.Version; import org.opensearch.cluster.node.DiscoveryNode; diff --git a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/NodesFlightInfoRequestTests.java b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/flightinfo/NodesFlightInfoRequestTests.java similarity index 96% rename from plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/NodesFlightInfoRequestTests.java rename to plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/flightinfo/NodesFlightInfoRequestTests.java index 756177423fe6f..ef8f88b78c3ee 100644 --- a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/NodesFlightInfoRequestTests.java +++ b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/flightinfo/NodesFlightInfoRequestTests.java @@ -6,7 +6,7 @@ * compatible open source license. */ -package org.opensearch.arrow.flight.api; +package org.opensearch.arrow.flight.api.flightinfo; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; diff --git a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/NodesFlightInfoResponseTests.java b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/flightinfo/NodesFlightInfoResponseTests.java similarity index 99% rename from plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/NodesFlightInfoResponseTests.java rename to plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/flightinfo/NodesFlightInfoResponseTests.java index 49a6cc6bacf40..707a222fe381f 100644 --- a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/NodesFlightInfoResponseTests.java +++ b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/flightinfo/NodesFlightInfoResponseTests.java @@ -6,7 +6,7 
@@ * compatible open source license. */ -package org.opensearch.arrow.flight.api; +package org.opensearch.arrow.flight.api.flightinfo; import org.opensearch.Version; import org.opensearch.action.FailedNodeException; diff --git a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/TransportNodesFlightInfoActionTests.java b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/flightinfo/TransportNodesFlightInfoActionTests.java similarity index 90% rename from plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/TransportNodesFlightInfoActionTests.java rename to plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/flightinfo/TransportNodesFlightInfoActionTests.java index 615c3905b135a..ca4c4bf0c28c8 100644 --- a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/TransportNodesFlightInfoActionTests.java +++ b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/flightinfo/TransportNodesFlightInfoActionTests.java @@ -1,10 +1,18 @@ -package org.opensearch.arrow.flight.api;/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.arrow.flight.api.flightinfo;/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ import org.opensearch.Version; import org.opensearch.action.FailedNodeException; diff --git a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/bootstrap/FlightClientManagerTests.java b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/bootstrap/FlightClientManagerTests.java index e4abf4cf6038c..49f41a22b2ced 100644 --- a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/bootstrap/FlightClientManagerTests.java +++ b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/bootstrap/FlightClientManagerTests.java @@ -12,10 +12,10 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.opensearch.Version; -import org.opensearch.arrow.flight.api.NodeFlightInfo; -import org.opensearch.arrow.flight.api.NodesFlightInfoAction; -import org.opensearch.arrow.flight.api.NodesFlightInfoRequest; -import org.opensearch.arrow.flight.api.NodesFlightInfoResponse; +import org.opensearch.arrow.flight.api.flightinfo.NodeFlightInfo; +import org.opensearch.arrow.flight.api.flightinfo.NodesFlightInfoAction; +import org.opensearch.arrow.flight.api.flightinfo.NodesFlightInfoRequest; +import org.opensearch.arrow.flight.api.flightinfo.NodesFlightInfoResponse; import org.opensearch.arrow.flight.bootstrap.tls.SslContextProvider; import org.opensearch.cluster.ClusterChangedEvent; import org.opensearch.cluster.ClusterName; @@ -44,6 +44,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -107,8 +108,10 @@ public void setUp() throws Exception { clientManager.clusterChanged(event); assertBusy(() -> { assertEquals("Flight client isn't built in time limit", 2, clientManager.getClients().size()); - assertNotNull("local_node should exist", clientManager.getFlightClient("local_node")); - assertNotNull("remote_node should exist", 
clientManager.getFlightClient("remote_node")); + assertTrue("local_node should exist", clientManager.getFlightClient("local_node").isPresent()); + assertNotNull("local_node should exist", clientManager.getFlightClient("local_node").get()); + assertTrue("remote_node should exist", clientManager.getFlightClient("remote_node").isPresent()); + assertNotNull("remote_node should exist", clientManager.getFlightClient("remote_node").get()); }, 2, TimeUnit.SECONDS); } @@ -192,7 +195,7 @@ public void testGetFlightClientLocation() { } public void testGetFlightClientForNonExistentNode() throws Exception { - assertNull(clientManager.getFlightClient("non_existent_node")); + assertTrue(clientManager.getFlightClient("non_existent_node").isEmpty()); } public void testClusterChangedWithNodesChanged() throws Exception { @@ -215,7 +218,7 @@ public void testClusterChangedWithNodesChanged() throws Exception { for (DiscoveryNode node : newState.nodes()) { assertBusy( () -> { assertNotNull("Flight client isn't built in time limit", clientManager.getFlightClient(node.getId())); }, - 2, + 5, TimeUnit.SECONDS ); } @@ -237,7 +240,7 @@ public void testGetLocalNodeId() throws Exception { public void testCloseWithActiveClients() throws Exception { for (DiscoveryNode node : state.nodes()) { - OSFlightClient client = clientManager.getFlightClient(node.getId()); + OSFlightClient client = clientManager.getFlightClient(node.getId()).get(); assertNotNull(client); } @@ -266,7 +269,7 @@ public void testIncompatibleNodeVersion() throws Exception { when(clusterService.state()).thenReturn(oldVersionState); mockFlightInfoResponse(nodes, 0); - assertNull(clientManager.getFlightClient(oldVersionNode.getId())); + assertFalse(clientManager.getFlightClient(oldVersionNode.getId()).isPresent()); } public void testGetFlightClientLocationTimeout() throws Exception { @@ -286,7 +289,7 @@ public void testGetFlightClientLocationTimeout() throws Exception { ClusterChangedEvent event = new ClusterChangedEvent("test", 
newState, ClusterState.EMPTY_STATE); clientManager.clusterChanged(event); - assertNull(clientManager.getFlightClient(nodeId)); + assertFalse(clientManager.getFlightClient(nodeId).isPresent()); } public void testGetFlightClientLocationExecutionError() throws Exception { @@ -313,7 +316,7 @@ public void testGetFlightClientLocationExecutionError() throws Exception { ClusterChangedEvent event = new ClusterChangedEvent("test", newState, ClusterState.EMPTY_STATE); clientManager.clusterChanged(event); - assertNull(clientManager.getFlightClient(nodeId)); + assertFalse(clientManager.getFlightClient(nodeId).isPresent()); } public void testFailedClusterUpdateButSuccessfulDirectRequest() throws Exception { @@ -371,23 +374,23 @@ public void testFailedClusterUpdateButSuccessfulDirectRequest() throws Exception ClusterChangedEvent event = new ClusterChangedEvent("test", newState, ClusterState.EMPTY_STATE); clientManager.clusterChanged(event); - + assertBusy(() -> { assertFalse("first call should be invoked", firstCall.get()); }, 5, TimeUnit.SECONDS); // Verify that the client can still be created successfully on direct request clientManager.buildClientAsync(nodeId); assertBusy( () -> { assertNotNull("Flight client should be created successfully on direct request", clientManager.getFlightClient(nodeId)); }, - 2, + 5, TimeUnit.SECONDS ); - assertFalse("first call should be invoked", firstCall.get()); } private void validateNodes() { for (DiscoveryNode node : state.nodes()) { - OSFlightClient client = clientManager.getFlightClient(node.getId()); - assertNotNull("Flight client should be created for existing node", client); + Optional client = clientManager.getFlightClient(node.getId()); + assertTrue("Flight client should be created for node [" + node.getId() + "].", client.isPresent()); + assertNotNull("Flight client should be created for node [" + node.getId() + "].", client.get()); } } diff --git 
a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/bootstrap/FlightServiceTests.java b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/bootstrap/FlightServiceTests.java index fa20535384557..35badb7b452eb 100644 --- a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/bootstrap/FlightServiceTests.java +++ b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/bootstrap/FlightServiceTests.java @@ -93,7 +93,7 @@ public void testStartAndStop() throws Exception { testService.start(); testService.stop(); testService.start(); - assertNull(testService.getStreamManager()); + assertNotNull(testService.getStreamManager()); } } diff --git a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/impl/BaseFlightProducerTests.java b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/impl/BaseFlightProducerTests.java new file mode 100644 index 0000000000000..479b89127ced8 --- /dev/null +++ b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/impl/BaseFlightProducerTests.java @@ -0,0 +1,463 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */
+
+package org.opensearch.arrow.flight.impl;
+
+import org.apache.arrow.flight.FlightProducer;
+import org.apache.arrow.flight.Ticket;
+import org.apache.arrow.memory.ArrowBuf;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+import org.apache.arrow.vector.ipc.message.IpcOption;
+import org.opensearch.arrow.flight.bootstrap.FlightClientManager;
+import org.opensearch.arrow.spi.StreamProducer;
+import org.opensearch.common.unit.TimeValue;
+import org.opensearch.common.util.FeatureFlags;
+import org.opensearch.test.FeatureFlagSetter;
+import org.opensearch.test.OpenSearchTestCase;
+
+import java.util.Optional;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+/**
+ * Unit tests for {@link BaseFlightProducer#getStream}. The stream manager, stream producer and batched job
+ * are Mockito mocks; the client side of the stream is played by {@link TestServerStreamListener}, whose
+ * readiness and cancellation are toggled from short-lived "client" threads inside each test.
+ */
+public class BaseFlightProducerTests extends OpenSearchTestCase {
+
+    private BaseFlightProducer baseFlightProducer;
+    private FlightStreamManager streamManager;
+    private StreamProducer streamProducer;
+    private StreamProducer.BatchedJob batchedJob;
+    private static final String LOCAL_NODE_ID = "localNodeId";
+    private static final FlightClientManager flightClientManager = mock(FlightClientManager.class);
+    // Ticket addressed to LOCAL_NODE_ID, i.e. the same node id the mocked client manager reports as local.
+    private final Ticket ticket = new Ticket((new FlightStreamTicket("test-ticket", LOCAL_NODE_ID)).toBytes());
+    private BufferAllocator allocator;
+
+    /**
+     * Enables the Arrow streams feature flag and builds a {@link BaseFlightProducer} wired to fresh mocks.
+     */
+    @Override
+    @SuppressWarnings("unchecked")
+    public void setUp() throws Exception {
+        super.setUp();
+        FeatureFlagSetter.set(FeatureFlags.ARROW_STREAMS_SETTING.getKey());
+        streamManager = mock(FlightStreamManager.class);
+        when(streamManager.getStreamTicketFactory()).thenReturn(new FlightStreamTicketFactory(() -> LOCAL_NODE_ID));
+        when(flightClientManager.getLocalNodeId()).thenReturn(LOCAL_NODE_ID);
+        allocator = mock(BufferAllocator.class);
+        streamProducer = mock(StreamProducer.class);
+        batchedJob = mock(StreamProducer.BatchedJob.class);
+        baseFlightProducer = new BaseFlightProducer(flightClientManager, streamManager, allocator);
+    }
+
+    /**
+     * Hand-rolled {@link FlightProducer.ServerStreamListener} used as the client side of the stream.
+     * It counts {@code putNext} calls, remembers the error passed to {@link #error(Throwable)}, and lets a
+     * test flip readiness or cancel, running whichever handler was registered through
+     * {@link #setOnReadyHandler(Runnable)} / {@link #setOnCancelHandler(Runnable)}.
+     */
+    private static class TestServerStreamListener implements FlightProducer.ServerStreamListener {
+        private final CountDownLatch completionLatch = new CountDownLatch(1);
+        private final AtomicInteger putNextCount = new AtomicInteger(0);
+        private final AtomicBoolean isCancelled = new AtomicBoolean(false);
+        private Throwable error;
+        private final AtomicBoolean dataConsumed = new AtomicBoolean(false);
+        private final AtomicBoolean ready = new AtomicBoolean(false);
+        private Runnable onReadyHandler;
+        private Runnable onCancelHandler;
+
+        /** Records one delivered batch; fails if the previous batch was not acknowledged via {@link #resetConsumptionLatch()}. */
+        @Override
+        public void putNext() {
+            assertFalse(dataConsumed.get());
+            putNextCount.incrementAndGet();
+            dataConsumed.set(true);
+        }
+
+        @Override
+        public boolean isReady() {
+            return ready.get();
+        }
+
+        /** Sets the readiness flag and then runs the registered on-ready handler, if any (for either value). */
+        public void setReady(boolean val) {
+            ready.set(val);
+            if (this.onReadyHandler != null) {
+                this.onReadyHandler.run();
+            }
+        }
+
+        @Override
+        public void start(VectorSchemaRoot root) {
+            // No-op for this test
+        }
+
+        @Override
+        public void start(VectorSchemaRoot root, DictionaryProvider dictionaries, IpcOption option) {}
+
+        @Override
+        public void putNext(ArrowBuf metadata) {
+            putNext();
+        }
+
+        @Override
+        public void putMetadata(ArrowBuf metadata) {
+
+        }
+
+        @Override
+        public void completed() {
+            completionLatch.countDown();
+        }
+
+        @Override
+        public void error(Throwable t) {
+            error = t;
+            completionLatch.countDown();
+        }
+
+        @Override
+        public boolean isCancelled() {
+            return isCancelled.get();
+        }
+
+        @Override
+        public void setOnReadyHandler(Runnable handler) {
+            this.onReadyHandler = handler;
+        }
+
+        @Override
+        public void setOnCancelHandler(Runnable handler) {
+            this.onCancelHandler = handler;
+        }
+
+        /** Clears the "batch delivered" flag so the next {@link #putNext()} passes its precondition. */
+        public void resetConsumptionLatch() {
+            dataConsumed.set(false);
+        }
+
+        public boolean getDataConsumed() {
+            return dataConsumed.get();
+        }
+
+        public int getPutNextCount() {
+            return putNextCount.get();
+        }
+
+        public Throwable getError() {
+            return error;
+        }
+
+        /** Marks the stream cancelled and then runs the registered on-cancel handler, if any. */
+        public void cancel() {
+            isCancelled.set(true);
+            if (this.onCancelHandler != null) {
+                this.onCancelHandler.run();
+            }
+        }
+    }
+
+    /** Three flushes against a prompt client: expects three delivered batches, no error, and the root closed. */
+    public void testGetStream_SuccessfulFlow() throws Exception {
+        final VectorSchemaRoot root = mock(VectorSchemaRoot.class);
+        when(streamManager.removeStreamProducer(any(FlightStreamTicket.class))).thenReturn(
+            Optional.of(FlightStreamManager.StreamProducerHolder.create(streamProducer, allocator))
+        );
+        when(streamProducer.createJob(any(BufferAllocator.class))).thenReturn(batchedJob);
+        when(streamProducer.createRoot(any(BufferAllocator.class))).thenReturn(root);
+
+        AtomicInteger flushCount = new AtomicInteger(0);
+        TestServerStreamListener listener = new TestServerStreamListener();
+        doAnswer(invocation -> {
+            StreamProducer.FlushSignal flushSignal = invocation.getArgument(1);
+            for (int i = 0; i < 3; i++) {
+                Thread clientThread = new Thread(() -> {
+                    listener.setReady(false);
+                    listener.setReady(true);
+                });
+                listener.setReady(false);
+                clientThread.start();
+                flushSignal.awaitConsumption(TimeValue.timeValueMillis(100));
+                assertTrue(listener.getDataConsumed());
+                flushCount.incrementAndGet();
+                listener.resetConsumptionLatch();
+            }
+            return null;
+        }).when(batchedJob).run(any(VectorSchemaRoot.class), any(StreamProducer.FlushSignal.class));
+        baseFlightProducer.getStream(mock(FlightProducer.CallContext.class), ticket, listener);
+
+        assertNull(listener.getError());
+        assertEquals(3, listener.getPutNextCount());
+        assertEquals(3, flushCount.get());
+
+        verify(streamManager).removeStreamProducer(any(FlightStreamTicket.class));
+        verify(root).close();
+    }
+
+    /** The client takes 100 ms to become ready but the job waits up to 300 ms: all five batches are expected. */
+    public void testGetStream_WithSlowClient() throws Exception {
+        final VectorSchemaRoot root = mock(VectorSchemaRoot.class);
+
+        when(streamManager.removeStreamProducer(any(FlightStreamTicket.class))).thenReturn(
+            Optional.of(FlightStreamManager.StreamProducerHolder.create(streamProducer, allocator))
+        );
+        when(streamProducer.createJob(any(BufferAllocator.class))).thenReturn(batchedJob);
+        when(streamProducer.createRoot(any(BufferAllocator.class))).thenReturn(root);
+
+        AtomicInteger flushCount = new AtomicInteger(0);
+        TestServerStreamListener listener = new TestServerStreamListener();
+
+        doAnswer(invocation -> {
+            StreamProducer.FlushSignal flushSignal = invocation.getArgument(1);
+            for (int i = 0; i < 5; i++) {
+                Thread clientThread = new Thread(() -> {
+                    try {
+                        listener.setReady(false);
+                        Thread.sleep(100);
+                        listener.setReady(true);
+                    } catch (InterruptedException e) {
+                        throw new RuntimeException(e);
+                    }
+                });
+                listener.setReady(false);
+                clientThread.start();
+                flushSignal.awaitConsumption(TimeValue.timeValueMillis(300)); // waiting for consumption for more than client thread sleep
+                assertTrue(listener.getDataConsumed());
+                flushCount.incrementAndGet();
+                listener.resetConsumptionLatch();
+            }
+            return null;
+        }).when(batchedJob).run(any(), any());
+
+        baseFlightProducer.getStream(mock(FlightProducer.CallContext.class), ticket, listener);
+
+        assertNull(listener.getError());
+        assertEquals(5, listener.getPutNextCount());
+        assertEquals(5, flushCount.get());
+
+        verify(streamManager).removeStreamProducer(any(FlightStreamTicket.class));
+        verify(root).close();
+    }
+
+    /**
+     * The client needs 400 ms to become ready but the job only waits 100 ms: expects the
+     * "Stream deadline exceeded for consumption" error and no delivered batch.
+     */
+    public void testGetStream_WithSlowClientTimeout() throws Exception {
+        final VectorSchemaRoot root = mock(VectorSchemaRoot.class);
+
+        when(streamManager.removeStreamProducer(any(FlightStreamTicket.class))).thenReturn(
+            Optional.of(FlightStreamManager.StreamProducerHolder.create(streamProducer, allocator))
+        );
+        when(streamProducer.createJob(any(BufferAllocator.class))).thenReturn(batchedJob);
+        when(streamProducer.createRoot(any(BufferAllocator.class))).thenReturn(root);
+
+        AtomicInteger flushCount = new AtomicInteger(0);
+        TestServerStreamListener listener = new TestServerStreamListener();
+        doAnswer(invocation -> {
+            StreamProducer.FlushSignal flushSignal = invocation.getArgument(1);
+            for (int i = 0; i < 5; i++) {
+                Thread clientThread = new Thread(() -> {
+                    try {
+                        listener.setReady(false);
+                        Thread.sleep(400);
+                        listener.setReady(true);
+                    } catch (InterruptedException e) {
+                        throw new RuntimeException(e);
+                    }
+                });
+                listener.setReady(false);
+                clientThread.start();
+                flushSignal.awaitConsumption(TimeValue.timeValueMillis(100)); // waiting for consumption for less than client thread sleep
+                assertTrue(listener.getDataConsumed());
+                flushCount.incrementAndGet();
+                listener.resetConsumptionLatch();
+            }
+            return null;
+        }).when(batchedJob).run(any(), any());
+
+        assertThrows(RuntimeException.class, () -> baseFlightProducer.getStream(mock(FlightProducer.CallContext.class), ticket, listener));
+
+        assertNotNull(listener.getError());
+        assertEquals("Stream deadline exceeded for consumption", listener.getError().getMessage());
+        assertEquals(0, listener.getPutNextCount());
+        assertEquals(0, flushCount.get());
+
+        verify(streamManager).removeStreamProducer(any(FlightStreamTicket.class));
+        verify(root).close();
+    }
+
+    /** The client cancels on the fifth flush: expects four delivered batches and the "Stream cancelled by client" error. */
+    public void testGetStream_WithClientCancel() throws Exception {
+        final VectorSchemaRoot root = mock(VectorSchemaRoot.class);
+
+        when(streamManager.removeStreamProducer(any(FlightStreamTicket.class))).thenReturn(
+            Optional.of(FlightStreamManager.StreamProducerHolder.create(streamProducer, allocator))
+        );
+        when(streamProducer.createJob(any(BufferAllocator.class))).thenReturn(batchedJob);
+        when(streamProducer.createRoot(any(BufferAllocator.class))).thenReturn(root);
+
+        AtomicInteger flushCount = new AtomicInteger(0);
+        TestServerStreamListener listener = new TestServerStreamListener();
+        doAnswer(invocation -> {
+            StreamProducer.FlushSignal flushSignal = invocation.getArgument(1);
+            for (int i = 0; i < 5; i++) {
+                int finalI = i;
+                Thread clientThread = new Thread(() -> {
+                    if (finalI == 4) {
+                        listener.cancel();
+                    } else {
+                        listener.setReady(false);
+                        listener.setReady(true);
+                    }
+                });
+                listener.setReady(false);
+                clientThread.start();
+                // the client thread does not sleep here: it either toggles readiness or, on the last iteration, cancels
+                flushSignal.awaitConsumption(TimeValue.timeValueMillis(100));
+                assertTrue(listener.getDataConsumed());
+                flushCount.incrementAndGet();
+                listener.resetConsumptionLatch();
+            }
+            return null;
+        }).when(batchedJob).run(any(), any());
+
+        assertThrows(RuntimeException.class, () -> baseFlightProducer.getStream(mock(FlightProducer.CallContext.class), ticket, listener));
+        assertNotNull(listener.getError());
+        assertEquals("Stream cancelled by client", listener.getError().getMessage());
+        assertEquals(4, listener.getPutNextCount());
+        assertEquals(4, flushCount.get());
+
+        verify(streamManager).removeStreamProducer(any(FlightStreamTicket.class));
+        verify(root).close();
+    }
+
+    /** The client never reports ready: expects the deadline-exceeded error and no delivered batch. */
+    public void testGetStream_WithUnresponsiveClient() throws Exception {
+        final VectorSchemaRoot root = mock(VectorSchemaRoot.class);
+
+        when(streamManager.removeStreamProducer(any(FlightStreamTicket.class))).thenReturn(
+            Optional.of(FlightStreamManager.StreamProducerHolder.create(streamProducer, allocator))
+        );
+        when(streamProducer.createJob(any(BufferAllocator.class))).thenReturn(batchedJob);
+        when(streamProducer.createRoot(any(BufferAllocator.class))).thenReturn(root);
+
+        AtomicInteger flushCount = new AtomicInteger(0);
+        TestServerStreamListener listener = new TestServerStreamListener();
+        doAnswer(invocation -> {
+            StreamProducer.FlushSignal flushSignal = invocation.getArgument(1);
+            for (int i = 0; i < 5; i++) {
+                Thread clientThread = new Thread(() -> {
+                    listener.setReady(false);
+                    // not setting ready to simulate unresponsive behaviour
+                });
+                listener.setReady(false);
+                clientThread.start();
+                // the client never becomes ready, so this wait is expected to run into its 100 ms deadline
+                flushSignal.awaitConsumption(TimeValue.timeValueMillis(100));
+                assertTrue(listener.getDataConsumed());
+                flushCount.incrementAndGet();
+                listener.resetConsumptionLatch();
+            }
+            return null;
+        }).when(batchedJob).run(any(), any());
+
+        assertThrows(RuntimeException.class, () -> baseFlightProducer.getStream(mock(FlightProducer.CallContext.class), ticket, listener));
+
+        assertNotNull(listener.getError());
+        assertEquals("Stream deadline exceeded for consumption", listener.getError().getMessage());
+        assertEquals(0, listener.getPutNextCount());
+        assertEquals(0, flushCount.get());
+
+        verify(streamManager).removeStreamProducer(any(FlightStreamTicket.class));
+        verify(root).close();
+    }
+
+    /** The job pauses 100 ms before each flush while the client is prompt: all five batches are expected. */
+    public void testGetStream_WithServerBackpressure() throws Exception {
+        final VectorSchemaRoot root = mock(VectorSchemaRoot.class);
+
+        when(streamManager.removeStreamProducer(any(FlightStreamTicket.class))).thenReturn(
+            Optional.of(FlightStreamManager.StreamProducerHolder.create(streamProducer, allocator))
+        );
+        when(streamProducer.createJob(any(BufferAllocator.class))).thenReturn(batchedJob);
+        when(streamProducer.createRoot(any(BufferAllocator.class))).thenReturn(root);
+
+        TestServerStreamListener listener = new TestServerStreamListener();
+        AtomicInteger flushCount = new AtomicInteger(0);
+        doAnswer(invocation -> {
+            StreamProducer.FlushSignal flushSignal = invocation.getArgument(1);
+            for (int i = 0; i < 5; i++) {
+                Thread clientThread = new Thread(() -> {
+                    listener.setReady(false);
+                    listener.setReady(true);
+                });
+                listener.setReady(false);
+                clientThread.start();
+                Thread.sleep(100); // simulating writer backpressure
+                flushSignal.awaitConsumption(TimeValue.timeValueMillis(100));
+                assertTrue(listener.getDataConsumed());
+                flushCount.incrementAndGet();
+                listener.resetConsumptionLatch();
+            }
+            return null;
+        }).when(batchedJob).run(any(VectorSchemaRoot.class), any(StreamProducer.FlushSignal.class));
+
+        baseFlightProducer.getStream(mock(FlightProducer.CallContext.class), ticket, listener);
+
+        assertNull(listener.getError());
+        assertEquals(5, listener.getPutNextCount());
+        assertEquals(5, flushCount.get());
+
+        verify(streamManager).removeStreamProducer(any(FlightStreamTicket.class));
+        verify(root).close();
+    }
+
+    /** The job throws on its fifth iteration: expects four delivered batches and the job's error reported to the listener. */
+    public void testGetStream_WithServerError() throws Exception {
+        final VectorSchemaRoot root = mock(VectorSchemaRoot.class);
+
+        when(streamManager.removeStreamProducer(any(FlightStreamTicket.class))).thenReturn(
+            Optional.of(FlightStreamManager.StreamProducerHolder.create(streamProducer, allocator))
+        );
+        when(streamProducer.createJob(any(BufferAllocator.class))).thenReturn(batchedJob);
+        when(streamProducer.createRoot(any(BufferAllocator.class))).thenReturn(root);
+
+        TestServerStreamListener listener = new TestServerStreamListener();
+        AtomicInteger flushCount = new AtomicInteger(0);
+        doAnswer(invocation -> {
+            StreamProducer.FlushSignal flushSignal = invocation.getArgument(1);
+            for (int i = 0; i < 5; i++) {
+                Thread clientThread = new Thread(() -> {
+                    listener.setReady(false);
+                    listener.setReady(true);
+                });
+                listener.setReady(false);
+                clientThread.start();
+                if (i == 4) {
+                    throw new RuntimeException("Server error");
+                }
+                flushSignal.awaitConsumption(TimeValue.timeValueMillis(100));
+                assertTrue(listener.getDataConsumed());
+                flushCount.incrementAndGet();
+                listener.resetConsumptionLatch();
+            }
+            return null;
+        }).when(batchedJob).run(any(VectorSchemaRoot.class), any(StreamProducer.FlushSignal.class));
+
+        assertThrows(RuntimeException.class, () -> baseFlightProducer.getStream(mock(FlightProducer.CallContext.class), ticket, listener));
+
+        assertNotNull(listener.getError());
+        assertEquals("Server error", listener.getError().getMessage());
+        assertEquals(4, listener.getPutNextCount());
+        assertEquals(4, flushCount.get());
+
+        verify(streamManager).removeStreamProducer(any(FlightStreamTicket.class));
+        verify(root).close();
+    }
+
+    /** No producer is registered for the ticket: expects a "Stream not found" error and no delivered batch. */
+    public void testGetStream_StreamNotFound() throws Exception {
+
+        // The lookup verified below is removeStreamProducer(), which the other tests stub with an Optional.
+        // Stub that same method with an empty Optional instead of stubbing an unrelated getter with null.
+        when(streamManager.removeStreamProducer(any(FlightStreamTicket.class))).thenReturn(Optional.empty());
+
+        TestServerStreamListener listener = new TestServerStreamListener();
+
+        baseFlightProducer.getStream(null, ticket, listener);
+
+        assertNotNull(listener.getError());
+        assertTrue(listener.getError().getMessage().contains("Stream not found"));
+        assertEquals(0, listener.getPutNextCount());
+
+        verify(streamManager).removeStreamProducer(any(FlightStreamTicket.class));
+    }
+
+    public void testProxyStreamProviderCreationWithDifferentNodeIDs() {
+        // TODO: proxy stream provider coverage
+    }
+}
diff --git a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/impl/FlightStreamManagerTests.java b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/impl/FlightStreamManagerTests.java
new file mode 100644
index 0000000000000..1bc686fc446c7
--- /dev/null
+++ b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/impl/FlightStreamManagerTests.java
@@ -0,0 +1,69 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.arrow.flight.impl;
+
+import org.apache.arrow.flight.FlightStream;
+import org.apache.arrow.flight.OSFlightClient;
+import org.apache.arrow.flight.Ticket;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.opensearch.arrow.flight.bootstrap.FlightClientManager;
+import org.opensearch.arrow.spi.StreamReader;
+import org.opensearch.arrow.spi.StreamTicket;
+import org.opensearch.test.OpenSearchTestCase;
+
+import java.util.Collections;
+import java.util.Optional;
+
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+/**
+ * Unit tests for {@code FlightStreamManager.getStreamReader}, driven through a mocked
+ * {@link OSFlightClient} that the mocked {@link FlightClientManager} hands out for {@code NODE_ID}.
+ */
+public class FlightStreamManagerTests extends OpenSearchTestCase {
+
+    private static final String NODE_ID = "testNodeId";
+    private static final String TICKET_ID = "testTicketId";
+
+    private OSFlightClient flightClient;
+    private FlightStreamManager flightStreamManager;
+
+    /** Builds a manager whose client manager resolves {@code NODE_ID} to the mocked flight client. */
+    @Override
+    public void setUp() throws Exception {
+        super.setUp();
+        final BufferAllocator mockAllocator = mock(BufferAllocator.class);
+        final FlightClientManager mockClientManager = mock(FlightClientManager.class);
+        flightClient = mock(OSFlightClient.class);
+        when(mockClientManager.getFlightClient(NODE_ID)).thenReturn(Optional.of(flightClient));
+
+        flightStreamManager = new FlightStreamManager(() -> mockAllocator);
+        flightStreamManager.setClientManager(mockClientManager);
+    }
+
+    /** A ticket for the known node yields a reader exposing the root (and schema) of the client's stream. */
+    public void testGetStreamReader() throws Exception {
+        final StreamTicket streamTicket = new FlightStreamTicket(TICKET_ID, NODE_ID);
+        final VectorSchemaRoot schemaRoot = mock(VectorSchemaRoot.class);
+        final FlightStream stream = mock(FlightStream.class);
+        when(schemaRoot.getSchema()).thenReturn(new Schema(Collections.emptyList()));
+        when(stream.getRoot()).thenReturn(schemaRoot);
+        when(flightClient.getStream(new Ticket(streamTicket.toBytes()))).thenReturn(stream);
+
+        final StreamReader reader = flightStreamManager.getStreamReader(streamTicket);
+
+        assertNotNull(reader);
+        assertNotNull(reader.getRoot());
+        assertEquals(new Schema(Collections.emptyList()), reader.getRoot().getSchema());
+        verify(flightClient).getStream(new Ticket(streamTicket.toBytes()));
+    }
+
+    /** A RuntimeException thrown by the client's {@code getStream} surfaces from {@code getStreamReader}. */
+    public void testGetVectorSchemaRootWithException() {
+        final StreamTicket streamTicket = new FlightStreamTicket(TICKET_ID, NODE_ID);
+        final Ticket flightTicket = new Ticket(streamTicket.toBytes());
+        when(flightClient.getStream(flightTicket)).thenThrow(new RuntimeException("Test exception"));
+
+        expectThrows(RuntimeException.class, () -> flightStreamManager.getStreamReader(streamTicket));
+    }
+}
diff --git a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/impl/FlightStreamReaderTests.java b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/impl/FlightStreamReaderTests.java
new file mode 100644
index 0000000000000..20e112dc730f6
--- /dev/null
+++ b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/impl/FlightStreamReaderTests.java
@@ -0,0 +1,89 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.arrow.flight.impl;
+
+import org.apache.arrow.flight.FlightStream;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.FieldType;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.opensearch.arrow.flight.bootstrap.ServerConfig;
+import org.opensearch.common.settings.Settings;
+import org.opensearch.common.util.FeatureFlags;
+import org.opensearch.test.FeatureFlagSetter;
+import org.opensearch.test.OpenSearchTestCase;
+
+import java.util.List;
+
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+/**
+ * Unit tests for {@link FlightStreamReader}, using a mocked {@link FlightStream} that hands out a
+ * real single-field {@link VectorSchemaRoot}.
+ */
+public class FlightStreamReaderTests extends OpenSearchTestCase {
+
+    private FlightStream mockFlightStream;
+
+    private FlightStreamReader iterator;
+    private VectorSchemaRoot root;
+    private BufferAllocator allocator;
+
+    /**
+     * Sets the Arrow streams feature flag, initialises {@link ServerConfig} with empty settings and
+     * builds the reader under test around a mocked stream whose root has one nullable signed
+     * 32-bit integer field named {@code id}.
+     */
+    @Override
+    public void setUp() throws Exception {
+        super.setUp();
+        FeatureFlagSetter.set(FeatureFlags.ARROW_STREAMS_SETTING.getKey());
+        ServerConfig.init(Settings.EMPTY);
+        mockFlightStream = mock(FlightStream.class);
+        allocator = new RootAllocator(100000);
+        Field field = new Field("id", FieldType.nullable(new ArrowType.Int(32, true)), null);
+        Schema schema = new Schema(List.of(field));
+        root = VectorSchemaRoot.create(schema, allocator);
+        when(mockFlightStream.getRoot()).thenReturn(root);
+        iterator = new FlightStreamReader(mockFlightStream);
+    }
+
+    /**
+     * Closes the root and then the allocator created in {@link #setUp()} before delegating to the
+     * parent teardown; the parent teardown still runs if closing either resource throws.
+     */
+    @Override
+    public void tearDown() throws Exception {
+        try {
+            root.close();
+            allocator.close();
+        } finally {
+            super.tearDown();
+        }
+    }
+
+    /** Checks that {@code next()} returns true and delegates once when the stream reports another batch. */
+    public void testNext_ReturnsTrue_WhenFlightStreamHasNext() throws Exception {
+        when(mockFlightStream.next()).thenReturn(true);
+        assertTrue(iterator.next());
+        // Must be verify(...), not an assert statement: an assert would call next() on the mock
+        // again and only test its stubbed return value, so delegation would never be checked.
+        verify(mockFlightStream).next();
+    }
+
+    /** Checks that {@code next()} returns false and delegates once when the stream is exhausted. */
+    public void testNext_ReturnsFalse_WhenFlightStreamHasNoNext() throws Exception {
+        when(mockFlightStream.next()).thenReturn(false);
+        assertFalse(iterator.next());
+        verify(mockFlightStream).next();
+    }
+
+    /** Checks that {@code getRoot()} hands back the root supplied by the underlying stream. */
+    public void testGetRoot_ReturnsRootFromFlightStream() throws Exception {
+        VectorSchemaRoot returnedRoot = iterator.getRoot();
+        assertEquals(root, returnedRoot);
+        verify(mockFlightStream).getRoot();
+    }
+
+    /** Checks that closing the reader closes the underlying stream. */
+    public void testClose_CallsCloseOnFlightStream() throws Exception {
+        iterator.close();
+        verify(mockFlightStream).close();
+    }
+
+    /** Checks that a checked exception from the stream's {@code close()} surfaces as a {@link RuntimeException}. */
+    public void testClose_WrapsExceptionInRuntimeException() throws Exception {
+        doThrow(new Exception("Test exception")).when(mockFlightStream).close();
+        assertThrows(RuntimeException.class, () -> iterator.close());
+        verify(mockFlightStream).close();
+    }
+}
diff --git a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/impl/FlightStreamTicketTests.java b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/impl/FlightStreamTicketTests.java
new file mode 100644
index 0000000000000..819da2826c173
--- /dev/null
+++ b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/impl/FlightStreamTicketTests.java
@@ -0,0 +1,111 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.arrow.flight.impl;
+
+import org.opensearch.arrow.spi.StreamTicket;
+import org.opensearch.test.OpenSearchTestCase;
+
+import java.nio.charset.StandardCharsets;
+import java.util.Base64;
+
+/**
+ * Unit tests for {@link FlightStreamTicket}: accessors, the Base64 byte form produced by
+ * {@code toBytes} and parsed by {@code fromBytes}, rejection of invalid input, and
+ * {@code equals}/{@code hashCode}/{@code toString}.
+ */
+public class FlightStreamTicketTests extends OpenSearchTestCase {
+
+    /** Checks that the getters return the values passed to the constructor. */
+    public void testConstructorAndGetters() {
+        String ticketID = "ticket123";
+        String nodeID = "node456";
+        StreamTicket ticket = new FlightStreamTicket(ticketID, nodeID);
+
+        assertEquals(ticketID, ticket.getTicketId());
+        assertEquals(nodeID, ticket.getNodeId());
+    }
+
+    /** Checks that {@code toBytes} yields non-empty Base64 whose decoded length matches the expected layout. */
+    public void testToBytes() {
+        StreamTicket ticket = new FlightStreamTicket("ticket123", "node456");
+        byte[] bytes = ticket.toBytes();
+
+        assertNotNull(bytes);
+        assertTrue(bytes.length > 0);
+
+        // Decode the Base64 and check the structure
+        byte[] decoded = Base64.getDecoder().decode(bytes);
+        assertEquals(2 + 9 + 2 + 7, decoded.length); // 2 shorts + "ticket123" + "node456"
+    }
+
+    /** Checks that {@code fromBytes} restores the ticket id and node id written by {@code toBytes}. */
+    public void testFromBytes() {
+        StreamTicket original = new FlightStreamTicket("ticket123", "node456");
+        byte[] bytes = original.toBytes();
+
+        StreamTicket reconstructed = FlightStreamTicket.fromBytes(bytes);
+
+        assertEquals(original.getTicketId(), reconstructed.getTicketId());
+        assertEquals(original.getNodeId(), reconstructed.getNodeId());
+    }
+
+    /** Checks that a ticket id one character longer than {@link Short#MAX_VALUE} is rejected by {@code toBytes}. */
+    public void testToBytesWithLongStrings() {
+        String longString = randomAlphaOfLength(Short.MAX_VALUE + 1);
+        StreamTicket ticket = new FlightStreamTicket(longString, "node456");
+
+        IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, ticket::toBytes);
+        assertEquals("Field lengths exceed the maximum allowed size.", exception.getMessage());
+    }
+
+    /** Checks that {@code fromBytes(null)} is rejected with the invalid-input message. */
+    public void testNullInput() {
+        IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> FlightStreamTicket.fromBytes(null));
+        assertEquals("Invalid byte array input.", e.getMessage());
+    }
+
+    /** Checks that an empty byte array is rejected with the invalid-input message. */
+    public void testEmptyInput() {
+        IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> FlightStreamTicket.fromBytes(new byte[0]));
+        assertEquals("Invalid byte array input.", e.getMessage());
+    }
+
+    /** Checks that bytes which are not valid Base64 are rejected; the first illegal byte is the space (hex 20). */
+    public void testMalformedBase64() {
+        byte[] invalidBase64 = "Invalid Base64!@#$".getBytes(StandardCharsets.UTF_8);
+        IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> FlightStreamTicket.fromBytes(invalidBase64));
+        assertEquals("Illegal base64 character 20", e.getMessage());
+    }
+
+    /** Checks that a tampered ticket-id length prefix is rejected rather than read past the payload. */
+    public void testModifiedLengthFields() {
+        StreamTicket original = new FlightStreamTicket("ticket123", "node456");
+        byte[] bytes = original.toBytes();
+        byte[] decoded = Base64.getDecoder().decode(bytes);
+
+        // Overwrite the first two bytes (the ticket id length) with 0xFFFF, reported as -1 in the error
+        decoded[0] = (byte) 0xFF;
+        decoded[1] = (byte) 0xFF;
+
+        byte[] modified = Base64.getEncoder().encode(decoded);
+        IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> FlightStreamTicket.fromBytes(modified));
+        assertEquals("Invalid ticketID length: -1", e.getMessage());
+    }
+
+    /** Checks equality for identical fields and inequality for a different id, {@code null} and a foreign type. */
+    public void testEquals() {
+        StreamTicket ticket1 = new FlightStreamTicket("ticket123", "node456");
+        StreamTicket ticket2 = new FlightStreamTicket("ticket123", "node456");
+        StreamTicket ticket3 = new FlightStreamTicket("ticket789", "node456");
+
+        assertEquals(ticket1, ticket2);
+        assertNotEquals(ticket1, ticket3);
+        // The ticket goes first so the comparison is dispatched to the ticket's own equals();
+        // with null or a String as the first argument that method would never be invoked.
+        assertNotEquals(ticket1, null);
+        assertNotEquals(ticket1, "Not a StreamTicket");
+    }
+
+    /** Checks that two tickets with identical fields share a hash code. */
+    public void testHashCode() {
+        StreamTicket ticket1 = new FlightStreamTicket("ticket123", "node456");
+        StreamTicket ticket2 = new FlightStreamTicket("ticket123", "node456");
+
+        assertEquals(ticket1.hashCode(), ticket2.hashCode());
+    }
+
+    /** Checks the exact {@code toString} rendering. */
+    public void testToString() {
+        StreamTicket ticket = new FlightStreamTicket("ticket123", "node456");
+        String expected = "FlightStreamTicket{ticketID='ticket123', nodeID='node456'}";
+        assertEquals(expected, ticket.toString());
+    }
+}
diff --git
a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/impl/ProxyStreamProducerTests.java b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/impl/ProxyStreamProducerTests.java
new file mode 100644
index 0000000000000..6a6273a601f21
--- /dev/null
+++ b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/impl/ProxyStreamProducerTests.java
@@ -0,0 +1,131 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.arrow.flight.impl;
+
+import org.apache.arrow.flight.FlightStream;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.opensearch.arrow.spi.StreamProducer;
+import org.opensearch.common.unit.TimeValue;
+import org.opensearch.test.OpenSearchTestCase;
+import org.junit.After;
+
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+/**
+ * Unit tests for {@link ProxyStreamProducer} and the batched job it creates, driven by a mocked
+ * remote {@link FlightStream}.
+ */
+public class ProxyStreamProducerTests extends OpenSearchTestCase {
+
+    private FlightStream mockRemoteStream;
+    private BufferAllocator mockAllocator;
+    private ProxyStreamProducer proxyStreamProducer;
+
+    /** Builds the producer under test around a {@link FlightStreamReader} over a mocked remote stream. */
+    @Override
+    public void setUp() throws Exception {
+        super.setUp();
+        mockRemoteStream = mock(FlightStream.class);
+        mockAllocator = mock(BufferAllocator.class);
+        final FlightStreamReader remoteReader = new FlightStreamReader(mockRemoteStream);
+        proxyStreamProducer = new ProxyStreamProducer(remoteReader);
+    }
+
+    /** Checks that {@code createRoot} returns the remote stream's root and asks the stream for it once. */
+    public void testCreateRoot() throws Exception {
+        final VectorSchemaRoot remoteRoot = mock(VectorSchemaRoot.class);
+        when(mockRemoteStream.getRoot()).thenReturn(remoteRoot);
+
+        final VectorSchemaRoot created = proxyStreamProducer.createRoot(mockAllocator);
+
+        assertEquals(remoteRoot, created);
+        verify(mockRemoteStream).getRoot();
+    }
+
+    /** Checks the default action (empty string) and the default estimated row count (-1). */
+    public void testDefaults() {
+        assertEquals("", proxyStreamProducer.getAction());
+        assertEquals(-1, proxyStreamProducer.estimatedRowCount());
+    }
+
+    /** Checks that {@code createJob} yields a non-null {@code ProxyBatchedJob}. */
+    public void testCreateJob() {
+        final StreamProducer.BatchedJob created = proxyStreamProducer.createJob(mockAllocator);
+
+        assertNotNull(created);
+        assertTrue(created instanceof ProxyStreamProducer.ProxyBatchedJob);
+    }
+
+    /** With {@code next()} stubbed to true, true, false: expects three {@code next()} calls and two flush waits. */
+    public void testProxyBatchedJob() throws Exception {
+        final StreamProducer.BatchedJob batchedJob = proxyStreamProducer.createJob(mockAllocator);
+        final VectorSchemaRoot targetRoot = mock(VectorSchemaRoot.class);
+        final StreamProducer.FlushSignal flushSignal = mock(StreamProducer.FlushSignal.class);
+        when(mockRemoteStream.next()).thenReturn(true, true, false);
+
+        batchedJob.run(targetRoot, flushSignal);
+
+        verify(mockRemoteStream, times(3)).next();
+        final TimeValue expectedWait = TimeValue.timeValueMillis(1000);
+        verify(flushSignal, times(2)).awaitConsumption(expectedWait);
+    }
+
+    /** Checks that a {@link RuntimeException} from the stream's {@code next()} leaves {@code run} with its message. */
+    public void testProxyBatchedJobWithException() throws Exception {
+        final StreamProducer.BatchedJob batchedJob = proxyStreamProducer.createJob(mockAllocator);
+        final VectorSchemaRoot targetRoot = mock(VectorSchemaRoot.class);
+        final StreamProducer.FlushSignal flushSignal = mock(StreamProducer.FlushSignal.class);
+        doThrow(new RuntimeException("Test exception")).when(mockRemoteStream).next();
+
+        RuntimeException thrown = null;
+        try {
+            batchedJob.run(targetRoot, flushSignal);
+        } catch (RuntimeException e) {
+            thrown = e;
+        }
+        if (thrown == null) {
+            fail("Expected RuntimeException");
+        }
+        assertEquals("Test exception", thrown.getMessage());
+
+        verify(mockRemoteStream, times(1)).next();
+    }
+
+    /** Checks that a job cancelled before {@code run} calls neither {@code next()} nor the one-second flush wait. */
+    public void testProxyBatchedJobOnCancel() throws Exception {
+        final StreamProducer.BatchedJob batchedJob = proxyStreamProducer.createJob(mockAllocator);
+        final VectorSchemaRoot targetRoot = mock(VectorSchemaRoot.class);
+        final StreamProducer.FlushSignal flushSignal = mock(StreamProducer.FlushSignal.class);
+        when(mockRemoteStream.next()).thenReturn(true, true, false);
+
+        // Cancel first, then run: the verifications below expect the cancelled job to do no work.
+        batchedJob.onCancel();
+        batchedJob.run(targetRoot, flushSignal);
+
+        verify(mockRemoteStream, times(0)).next();
+        verify(flushSignal, times(0)).awaitConsumption(TimeValue.timeValueMillis(1000));
+        assertTrue(batchedJob.isCancelled());
+    }
+
+    /** Closes the producer, when one was created, and then runs the parent teardown. */
+    @After
+    public void tearDown() throws Exception {
+        if (proxyStreamProducer == null) {
+            super.tearDown();
+            return;
+        }
+        proxyStreamProducer.close();
+        super.tearDown();
+    }
+}