From e57e24b8681e8931970d57d7c5e81491ba20f0fc Mon Sep 17 00:00:00 2001
From: Larry Booker <larrybooker@deephaven.io>
Date: Mon, 28 Feb 2022 13:04:38 -0800
Subject: [PATCH] Increased compatibility with `Flight` reference
 implementation of `DoExchange` (#1964)

* added flight reference impl compatibility for DoExchange

* updated to require magic bytes in FlightData.FlightDescriptor.CMD
---
 java-client/flight-examples/build.gradle      |   4 +
 .../deephaven/client/examples/DoExchange.java |  92 +++++++++++
 .../deephaven/client/impl/FlightSession.java  |  18 ++-
 .../server/arrow/ArrowFlightUtil.java         |  51 +++++-
 .../test/FlightMessageRoundTripTest.java      | 152 ++++++++++++++++--
 5 files changed, 290 insertions(+), 27 deletions(-)
 create mode 100644 java-client/flight-examples/src/main/java/io/deephaven/client/examples/DoExchange.java

diff --git a/java-client/flight-examples/build.gradle b/java-client/flight-examples/build.gradle
index fd3f6d662ea..692d0066aba 100644
--- a/java-client/flight-examples/build.gradle
+++ b/java-client/flight-examples/build.gradle
@@ -7,6 +7,8 @@ dependencies {
     implementation project(':java-client-flight-dagger')
     implementation project(':java-client-example-utilities')
 
+    implementation "io.deephaven.barrage:barrage-format:0.4.0"
+
     Classpaths.inheritJUnitPlatform(project)
     Classpaths.inheritAssertJ(project)
     testImplementation 'org.junit.jupiter:junit-jupiter'
@@ -43,6 +45,8 @@ applicationDistribution.into('bin') {
     from(createApplication('aggregate-all', 'io.deephaven.client.examples.AggregateAllExample'))
     from(createApplication('agg-by', 'io.deephaven.client.examples.AggByExample'))
 
+    from(createApplication('do-exchange', 'io.deephaven.client.examples.DoExchange'))
+
     from(createApplication('do-put-new', 'io.deephaven.client.examples.DoPutNew'))
     from(createApplication('do-put-spray', 'io.deephaven.client.examples.DoPutSpray'))
     from(createApplication('do-put-table', 'io.deephaven.client.examples.DoPutTable'))
diff --git a/java-client/flight-examples/src/main/java/io/deephaven/client/examples/DoExchange.java b/java-client/flight-examples/src/main/java/io/deephaven/client/examples/DoExchange.java
new file mode 100644
index 00000000000..ac87477b609
--- /dev/null
+++ b/java-client/flight-examples/src/main/java/io/deephaven/client/examples/DoExchange.java
@@ -0,0 +1,92 @@
+package io.deephaven.client.examples;
+
+import com.google.flatbuffers.FlatBufferBuilder;
+import com.google.protobuf.ByteString;
+import io.deephaven.client.impl.FlightSession;
+import io.deephaven.proto.util.ScopeTicketHelper;
+import org.apache.arrow.flight.FlightClient;
+import org.apache.arrow.flight.FlightDescriptor;
+import org.apache.arrow.memory.ArrowBuf;
+import org.apache.arrow.memory.RootAllocator;
+import picocli.CommandLine;
+import picocli.CommandLine.ArgGroup;
+import picocli.CommandLine.Command;
+
+import io.deephaven.barrage.flatbuf.*;
+
+@Command(name = "do-exchange", mixinStandardHelpOptions = true,
+        description = "Start a DoExchange session with the server", version = "0.1.0")
+class DoExchange extends FlightExampleBase {
+
+    @ArgGroup(exclusive = true, multiplicity = "1")
+    Ticket ticket;
+
+    @Override
+    protected void execute(FlightSession flight) throws Exception {
+
+        // need to provide the MAGIC bytes as the FlightDescriptor.cmd in the initial message
+        byte[] cmd = new byte[] {100, 112, 104, 110}; // equivalent to '0x6E687064' (ASCII "dphn")
+
+        FlightDescriptor fd = FlightDescriptor.command(cmd);
+
+        // create the bi-directional reader/writer
+        try (FlightClient.ExchangeReaderWriter erw = flight.startExchange(fd);
+                final RootAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) {
+
+            /////////////////////////////////////////////////////////////
+            // create a BarrageSnapshotRequest for ticket 's/timetable'
+            /////////////////////////////////////////////////////////////
+
+            // inner metadata for the snapshot request
+            final FlatBufferBuilder metadata = new FlatBufferBuilder();
+
+            int optOffset =
+                    BarrageSnapshotOptions.createBarrageSnapshotOptions(metadata, ColumnConversionMode.Stringify,
+                            false, 1000);
+
+            final int ticOffset =
+                    BarrageSnapshotRequest.createTicketVector(metadata,
+                            ScopeTicketHelper.nameToBytes(ticket.scopeField.variable));
+            BarrageSnapshotRequest.startBarrageSnapshotRequest(metadata);
+            BarrageSnapshotRequest.addColumns(metadata, 0);
+            BarrageSnapshotRequest.addViewport(metadata, 0);
+            BarrageSnapshotRequest.addSnapshotOptions(metadata, optOffset);
+            BarrageSnapshotRequest.addTicket(metadata, ticOffset);
+            metadata.finish(BarrageSnapshotRequest.endBarrageSnapshotRequest(metadata));
+
+            // outer metadata to ID the message type and provide the MAGIC bytes
+            final FlatBufferBuilder wrapper = new FlatBufferBuilder();
+            final int innerOffset = wrapper.createByteVector(metadata.dataBuffer());
+            wrapper.finish(BarrageMessageWrapper.createBarrageMessageWrapper(
+                    wrapper,
+                    0x6E687064, // the numerical representation of the ASCII "dphn".
+                    BarrageMessageType.BarrageSnapshotRequest,
+                    innerOffset));
+
+            // extract the bytes and package them in an ArrowBuf for transmission
+            cmd = wrapper.sizedByteArray();
+            ArrowBuf data = allocator.buffer(cmd.length);
+            data.writeBytes(cmd);
+
+            // `putMetadata()` makes the GRPC call
+            erw.getWriter().putMetadata(data);
+
+            // snapshot requests do not need to stay open on the client side
+            erw.getWriter().completed();
+
+            // read everything from the server
+            while (erw.getReader().next()) {
+                // NOP
+            }
+
+            // print the table data
+            System.out.println(erw.getReader().getSchema().toString());
+            System.out.println(erw.getReader().getRoot().contentToTSVString());
+        }
+    }
+
+    public static void main(String[] args) {
+        int execute = new CommandLine(new DoExchange()).execute(args);
+        System.exit(execute);
+    }
+}
diff --git a/java-client/flight/src/main/java/io/deephaven/client/impl/FlightSession.java b/java-client/flight/src/main/java/io/deephaven/client/impl/FlightSession.java
index 46ed4649c62..7934b1f9aef 100644
--- a/java-client/flight/src/main/java/io/deephaven/client/impl/FlightSession.java
+++ b/java-client/flight/src/main/java/io/deephaven/client/impl/FlightSession.java
@@ -5,11 +5,7 @@
 import io.deephaven.proto.flight.util.SchemaHelper;
 import io.deephaven.qst.table.NewTable;
 import io.grpc.ManagedChannel;
-import org.apache.arrow.flight.Criteria;
-import org.apache.arrow.flight.FlightClient;
-import org.apache.arrow.flight.FlightGrpcUtilsExtension;
-import org.apache.arrow.flight.FlightInfo;
-import org.apache.arrow.flight.FlightStream;
+import org.apache.arrow.flight.*;
 import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.vector.types.pojo.Schema;
 
@@ -79,6 +75,18 @@ public FlightStream stream(HasTicketId ticketId) {
         return FlightClientHelper.get(client, ticketId);
     }
 
+    /**
+     * Creates a new server side DoExchange session.
+     *
+     * @param descriptor the FlightDescriptor object to include on the first FlightData message (other fields will
+     *        remain null)
+     * @param options the GRPC otions to apply to this call
+     * @return the bi-directional ReaderWriter object
+     */
+    public FlightClient.ExchangeReaderWriter startExchange(FlightDescriptor descriptor, CallOption... options) {
+        return client.doExchange(descriptor, options);
+    }
+
     /**
      * Creates a new server side exported table backed by the server semantics of DoPut with a {@link NewTable} payload.
      *
diff --git a/server/src/main/java/io/deephaven/server/arrow/ArrowFlightUtil.java b/server/src/main/java/io/deephaven/server/arrow/ArrowFlightUtil.java
index 5491beaeffb..dd00550de85 100644
--- a/server/src/main/java/io/deephaven/server/arrow/ArrowFlightUtil.java
+++ b/server/src/main/java/io/deephaven/server/arrow/ArrowFlightUtil.java
@@ -9,6 +9,7 @@
 import gnu.trove.list.array.TLongArrayList;
 import io.deephaven.UncheckedDeephavenException;
 import io.deephaven.barrage.flatbuf.BarrageMessageType;
+import io.deephaven.barrage.flatbuf.BarrageMessageWrapper;
 import io.deephaven.barrage.flatbuf.BarrageSnapshotRequest;
 import io.deephaven.barrage.flatbuf.BarrageSubscriptionRequest;
 import io.deephaven.chunk.ChunkType;
@@ -49,6 +50,7 @@
 import java.io.IOException;
 import java.io.InputStream;
 import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
 import java.util.ArrayDeque;
 import java.util.BitSet;
 import java.util.Iterator;
@@ -296,6 +298,8 @@ public interface Factory {
 
         private boolean isClosed = false;
 
+        private boolean isFirstMsg = true;
+
         private final TicketRouter ticketRouter;
         private final BarrageMessageProducer.Operation.Factory<BarrageStreamGenerator.View> operationFactory;
         private final BarrageMessageProducer.Adapter<BarrageSubscriptionRequest, BarrageSubscriptionOptions> optionsAdapter;
@@ -336,14 +340,19 @@ public void onNext(final InputStream request) {
             GrpcUtil.rpcWrapper(log, listener, () -> {
                 BarrageProtoUtil.MessageInfo message = BarrageProtoUtil.parseProtoMessage(request);
                 synchronized (this) {
-                    if (message.app_metadata == null
-                            || message.app_metadata.magic() != BarrageUtil.FLATBUFFER_MAGIC) {
-                        log.warn().append(myPrefix).append("received a message without app_metadata").endl();
+
+                    // `FlightData` messages from Barrage clients will provide app_metadata describing the request but
+                    // official Flight implementations may force a NULL metadata field in the first message. In that
+                    // case, identify a valid Barrage connection by verifying the `FlightDescriptor.CMD` field contains
+                    // the `Barrage` magic bytes
+
+                    if (requestHandler != null) {
+                        // rely on the handler to verify message type
+                        requestHandler.handleMessage(message);
                         return;
                     }
 
-                    // handle the different message types that can come over DoExchange
-                    if (requestHandler == null) {
+                    if (message.app_metadata != null) {
                         // handle the different message types that can come over DoExchange
                         switch (message.app_metadata.msgType()) {
                             case BarrageMessageType.BarrageSubscriptionRequest:
@@ -356,9 +365,37 @@ public void onNext(final InputStream request) {
                                 throw GrpcUtil.statusRuntimeException(Code.INVALID_ARGUMENT,
                                         myPrefix + "received a message with unhandled BarrageMessageType");
                         }
+                        requestHandler.handleMessage(message);
+                        return;
+                    }
+
+                    // handle the possible error cases
+                    if (!isFirstMsg) {
+                        // only the first messages is allowed to have null metadata
+                        throw GrpcUtil.statusRuntimeException(Code.INVALID_ARGUMENT,
+                                myPrefix + "failed to receive Barrage request metadata");
+                    }
+
+                    isFirstMsg = false;
+
+                    // The magic value is '0x6E687064'. It is the numerical representation of the ASCII "dphn".
+                    int size = message.descriptor.getCmd().size();
+                    if (size == 4) {
+                        ByteBuffer bb = message.descriptor.getCmd().asReadOnlyByteBuffer();
+
+                        // set the order to little-endian (FlatBuffers default)
+                        bb.order(ByteOrder.LITTLE_ENDIAN);
+
+                        // read and compare the value to the "magic" bytes
+                        long value = (long) bb.getInt(0) & 0xFFFFFFFFL;
+                        if (value != BarrageUtil.FLATBUFFER_MAGIC) {
+                            throw GrpcUtil.statusRuntimeException(Code.INVALID_ARGUMENT,
+                                    myPrefix + "expected BarrageMessageWrapper magic bytes in FlightDescriptor.cmd");
+                        }
+                    } else {
+                        throw GrpcUtil.statusRuntimeException(Code.INVALID_ARGUMENT,
+                                myPrefix + "expected BarrageMessageWrapper magic bytes in FlightDescriptor.cmd");
                     }
-                    // rely on the handler to verify message type
-                    requestHandler.handleMessage(message);
                 }
             });
         }
diff --git a/server/test/src/main/java/io/deephaven/server/test/FlightMessageRoundTripTest.java b/server/test/src/main/java/io/deephaven/server/test/FlightMessageRoundTripTest.java
index 318b7c4ee8d..f82b2448c9e 100644
--- a/server/test/src/main/java/io/deephaven/server/test/FlightMessageRoundTripTest.java
+++ b/server/test/src/main/java/io/deephaven/server/test/FlightMessageRoundTripTest.java
@@ -1,8 +1,12 @@
 package io.deephaven.server.test;
 
+import com.google.flatbuffers.DoubleVector;
+import com.google.flatbuffers.FlatBufferBuilder;
+import com.google.flatbuffers.LongVector;
 import dagger.Module;
 import dagger.Provides;
 import dagger.multibindings.IntoSet;
+import io.deephaven.barrage.flatbuf.*;
 import io.deephaven.base.verify.Assert;
 import io.deephaven.engine.liveness.LivenessScopeStack;
 import io.deephaven.engine.table.Table;
@@ -30,19 +34,11 @@
 import io.grpc.ManagedChannel;
 import io.grpc.ManagedChannelBuilder;
 import io.grpc.ServerInterceptor;
-import org.apache.arrow.flight.AsyncPutListener;
-import org.apache.arrow.flight.CallHeaders;
-import org.apache.arrow.flight.CallStatus;
-import org.apache.arrow.flight.Criteria;
-import org.apache.arrow.flight.FlightClient;
-import org.apache.arrow.flight.FlightClientMiddleware;
-import org.apache.arrow.flight.FlightDescriptor;
-import org.apache.arrow.flight.FlightInfo;
-import org.apache.arrow.flight.FlightStream;
-import org.apache.arrow.flight.Location;
-import org.apache.arrow.flight.Ticket;
+import org.apache.arrow.flight.*;
 import org.apache.arrow.flight.impl.Flight;
+import org.apache.arrow.memory.ArrowBuf;
 import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.FieldVector;
 import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.types.pojo.ArrowType;
 import org.apache.arrow.vector.types.pojo.Field;
@@ -64,10 +60,7 @@
 import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
 
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.*;
 
 /**
  * Deliberately much lower in scope (and running time) than BarrageMessageRoundTripTest, the only purpose of this test
@@ -387,6 +380,135 @@ public void testGetSchema() {
         }
     }
 
+    @Test
+    public void testDoExchangeSnapshot() {
+        final String staticTableName = "flightInfoTest";
+        final Table table = TableTools.emptyTable(10).update("I = i", "J = i + 0.01");
+
+        try (final SafeCloseable ignored = LivenessScopeStack.open(scriptSession, false)) {
+            // stuff table into the scope
+            scriptSession.setVariable(staticTableName, table);
+
+            // build up a snapshot request
+            byte[] magic = new byte[] {100, 112, 104, 110}; // equivalent to '0x6E687064' (ASCII "dphn")
+
+            FlightDescriptor fd = FlightDescriptor.command(magic);
+
+            try (FlightClient.ExchangeReaderWriter erw = client.doExchange(fd);
+                    final RootAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) {
+
+                final FlatBufferBuilder metadata = new FlatBufferBuilder();
+
+                int optOffset =
+                        BarrageSnapshotOptions.createBarrageSnapshotOptions(metadata, ColumnConversionMode.Stringify,
+                                false, 1000);
+
+                final int ticOffset =
+                        BarrageSnapshotRequest.createTicketVector(metadata,
+                                ScopeTicketHelper.nameToBytes(staticTableName));
+                BarrageSnapshotRequest.startBarrageSnapshotRequest(metadata);
+                BarrageSnapshotRequest.addColumns(metadata, 0);
+                BarrageSnapshotRequest.addViewport(metadata, 0);
+                BarrageSnapshotRequest.addSnapshotOptions(metadata, optOffset);
+                BarrageSnapshotRequest.addTicket(metadata, ticOffset);
+                metadata.finish(BarrageSnapshotRequest.endBarrageSnapshotRequest(metadata));
+
+                final FlatBufferBuilder wrapper = new FlatBufferBuilder();
+                final int innerOffset = wrapper.createByteVector(metadata.dataBuffer());
+                wrapper.finish(BarrageMessageWrapper.createBarrageMessageWrapper(
+                        wrapper,
+                        0x6E687064, // the numerical representation of the ASCII "dphn".
+                        BarrageMessageType.BarrageSnapshotRequest,
+                        innerOffset));
+
+                // extract the bytes and package them in an ArrowBuf for transmission
+                byte[] msg = wrapper.sizedByteArray();
+                ArrowBuf data = allocator.buffer(msg.length);
+                data.writeBytes(msg);
+
+                erw.getWriter().putMetadata(data);
+                erw.getWriter().completed();
+
+                // read everything from the server (expecting schema message and one data message)
+                int numMessages = 0;
+                while (erw.getReader().next()) {
+                    ++numMessages;
+                }
+                assertEquals(1, numMessages); // only one data message
+
+                // at this point should have the data, verify it matches the created table
+                assertEquals(erw.getReader().getRoot().getRowCount(), table.size());
+
+                // check the values against the source table
+                org.apache.arrow.vector.IntVector iv =
+                        (org.apache.arrow.vector.IntVector) erw.getReader().getRoot().getVector(0);
+                for (int i = 0; i < table.size(); i++) {
+                    assertEquals("int match:", table.getColumn(0).get(i), iv.get(i));
+                }
+                org.apache.arrow.vector.Float8Vector dv =
+                        (org.apache.arrow.vector.Float8Vector) erw.getReader().getRoot().getVector(1);
+                for (int i = 0; i < table.size(); i++) {
+                    assertEquals("double match: ", table.getColumn(1).get(i), dv.get(i));
+                }
+            } catch (Exception e) {
+                e.printStackTrace();
+            }
+        }
+    }
+
+    @Test
+    public void testDoExchangeProtocol() {
+        final String staticTableName = "flightInfoTest";
+        final Table table = TableTools.emptyTable(10).update("I = i", "J = i + 0.01");
+
+        try (final SafeCloseable ignored = LivenessScopeStack.open(scriptSession, false)) {
+            // stuff table into the scope
+            scriptSession.setVariable(staticTableName, table);
+
+            // build up a snapshot request incorrectly
+            byte[] empty = new byte[0];
+
+            FlightDescriptor fd = FlightDescriptor.command(empty);
+
+            try (FlightClient.ExchangeReaderWriter erw = client.doExchange(fd)) {
+
+                Exception exception = assertThrows(FlightRuntimeException.class, () -> {
+                    erw.getReader().next();
+                });
+
+                String expectedMessage = "expected BarrageMessageWrapper magic bytes in FlightDescriptor.cmd";
+                String actualMessage = exception.getMessage();
+
+                assertTrue(actualMessage.contains(expectedMessage));
+            }
+
+            byte[] magic = new byte[] {100, 112, 104, 110}; // equivalent to '0x6E687064' (ASCII "dphn")
+            fd = FlightDescriptor.command(magic);
+            try (FlightClient.ExchangeReaderWriter erw = client.doExchange(fd);
+                    final RootAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) {
+
+                byte[] msg = new byte[0];
+                ArrowBuf data = allocator.buffer(msg.length);
+                data.writeBytes(msg);
+
+                erw.getWriter().putMetadata(data);
+                erw.getWriter().completed();
+
+                Exception exception = assertThrows(FlightRuntimeException.class, () -> {
+                    erw.getReader().next();
+                });
+
+                String expectedMessage = "failed to receive Barrage request metadata";
+                String actualMessage = exception.getMessage();
+
+                assertTrue(actualMessage.contains(expectedMessage));
+            }
+
+        } catch (Exception e) {
+            e.printStackTrace();
+        }
+    }
+
     private static FlightDescriptor arrowFlightDescriptorForName(String name) {
         return FlightDescriptor.path(ScopeTicketHelper.nameToPath(name));
     }