From d3e7a3fa497a05dc5dd48b62787da8d0516c5cee Mon Sep 17 00:00:00 2001 From: Santiago Pericasgeertsen Date: Fri, 9 Feb 2024 10:10:45 -0500 Subject: [PATCH] Basic support for unary, serverStream, streamClient and bidi invocations. Signed-off-by: Santiago Pericasgeertsen --- .../webserver/protocols/GrpcTest.java | 151 ++++++++++++++- .../webclient/grpc/GrpcClientCall.java | 176 ++++++++++++++---- .../webclient/grpc/GrpcServiceClient.java | 16 +- .../webclient/grpc/GrpcServiceClientImpl.java | 102 ++++++++-- .../webclient/http2/Http2ClientStream.java | 4 +- 5 files changed, 374 insertions(+), 75 deletions(-) diff --git a/examples/webserver/protocols/src/test/java/io/helidon/examples/webserver/protocols/GrpcTest.java b/examples/webserver/protocols/src/test/java/io/helidon/examples/webserver/protocols/GrpcTest.java index 2a77c407df6..2e15cc0339a 100644 --- a/examples/webserver/protocols/src/test/java/io/helidon/examples/webserver/protocols/GrpcTest.java +++ b/examples/webserver/protocols/src/test/java/io/helidon/examples/webserver/protocols/GrpcTest.java @@ -16,9 +16,12 @@ package io.helidon.examples.webserver.protocols; +import java.util.Iterator; +import java.util.List; import java.util.Locale; import com.google.protobuf.StringValue; +import io.grpc.stub.StreamObserver; import io.helidon.examples.grpc.strings.Strings; import io.helidon.webclient.grpc.GrpcClient; import io.helidon.webclient.grpc.GrpcClientMethodDescriptor; @@ -27,7 +30,7 @@ import io.helidon.webserver.grpc.GrpcRouting; import io.helidon.webserver.testing.junit5.ServerTest; import io.helidon.webserver.testing.junit5.SetUpServer; -import org.junit.jupiter.api.RepeatedTest; +import org.junit.jupiter.api.Test; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.MatcherAssert.assertThat; @@ -44,20 +47,94 @@ private GrpcTest(GrpcClient grpcClient) { @SetUpServer public static void setup(WebServerConfig.Builder builder) { builder.addRouting(GrpcRouting.builder() - .unary(Strings.getDescriptor(), - "StringService", - "Upper", - GrpcTest::blockingGrpcUpper)); + .unary(Strings.getDescriptor(), + "StringService", + "Upper", + GrpcTest::upper) + .serverStream(Strings.getDescriptor(), + "StringService", + "Split", + GrpcTest::split) + .clientStream(Strings.getDescriptor(), + "StringService", + "Join", + GrpcTest::join) + .bidi(Strings.getDescriptor(), + "StringService", + "Echo", + GrpcTest::echo)); } - private static Strings.StringMessage blockingGrpcUpper(Strings.StringMessage reqT) { + private static Strings.StringMessage upper(Strings.StringMessage reqT) { return Strings.StringMessage.newBuilder() .setText(reqT.getText().toUpperCase(Locale.ROOT)) .build(); } - @RepeatedTest(3) - void testSimpleCall() { + private static void split(Strings.StringMessage reqT, + StreamObserver streamObserver) { + String[] strings = reqT.getText().split(" "); + for (String string : strings) { + streamObserver.onNext(Strings.StringMessage.newBuilder() + .setText(string) + .build()); + + } + streamObserver.onCompleted(); + } + + private static StreamObserver join(StreamObserver streamObserver) { + return new StreamObserver<>() { + private StringBuilder builder; + + @Override + public void onNext(Strings.StringMessage value) { + if (builder == null) { + builder = new StringBuilder(); + builder.append(value.getText()); + } else { + builder.append(" ").append(value.getText()); + } + } + + @Override + public void onError(Throwable t) { + streamObserver.onError(t); + } + + @Override + public void onCompleted() { + streamObserver.onNext(Strings.StringMessage.newBuilder() + .setText(builder.toString()) + .build()); + streamObserver.onCompleted(); + } + }; + } + + private static StreamObserver echo(StreamObserver streamObserver) { + return new StreamObserver<>() { + private StringBuilder builder; + + @Override + public void onNext(Strings.StringMessage value) { + streamObserver.onNext(value); + } + + @Override + public void onError(Throwable t) { + streamObserver.onError(t); + } + + @Override + public void onCompleted() { + streamObserver.onCompleted(); + } + }; + } + + @Test + void testUpper() { GrpcServiceDescriptor serviceDescriptor = GrpcServiceDescriptor.builder() .serviceName("StringService") @@ -70,7 +147,63 @@ void testSimpleCall() { StringValue r = grpcClient.serviceClient(serviceDescriptor) .unary("Upper", StringValue.of("hello")); - System.out.println("r = " + r.getValue()); assertThat(r.getValue(), is("HELLO")); } + + @Test + void testSplit() { + GrpcServiceDescriptor serviceDescriptor = + GrpcServiceDescriptor.builder() + .serviceName("StringService") + .putMethod("Split", + GrpcClientMethodDescriptor.serverStreaming("StringService", "Split") + .requestType(StringValue.class) + .responseType(StringValue.class) + .build()) + .build(); + + Iterator r = grpcClient.serviceClient(serviceDescriptor) + .serverStream("Split", StringValue.of("hello world")); + assertThat(r.next().getValue(), is("hello")); + assertThat(r.next().getValue(), is("world")); + assertThat(r.hasNext(), is(false)); + } + + @Test + void testJoin() { + GrpcServiceDescriptor serviceDescriptor = + GrpcServiceDescriptor.builder() + .serviceName("StringService") + .putMethod("Join", + GrpcClientMethodDescriptor.clientStreaming("StringService", "Join") + .requestType(StringValue.class) + .responseType(StringValue.class) + .build()) + .build(); + + StringValue r = grpcClient.serviceClient(serviceDescriptor) + .clientStream("Join", List.of(StringValue.of("hello"), + StringValue.of("world")).iterator()); + assertThat(r.getValue(), is("hello world")); + } + + @Test + void testEcho() { + GrpcServiceDescriptor serviceDescriptor = + GrpcServiceDescriptor.builder() + .serviceName("StringService") + .putMethod("Echo", + GrpcClientMethodDescriptor.bidirectional("StringService", "Echo") + .requestType(StringValue.class) + .responseType(StringValue.class) + .build()) + .build(); + + Iterator r = grpcClient.serviceClient(serviceDescriptor) + .bidi("Echo", List.of(StringValue.of("hello"), + StringValue.of("world")).iterator()); + assertThat(r.next().getValue(), is("hello")); + assertThat(r.next().getValue(), is("world")); + assertThat(r.hasNext(), is(false)); + } } diff --git a/webclient/grpc/src/main/java/io/helidon/webclient/grpc/GrpcClientCall.java b/webclient/grpc/src/main/java/io/helidon/webclient/grpc/GrpcClientCall.java index 3cf060f55f2..dc9a37aeaa4 100644 --- a/webclient/grpc/src/main/java/io/helidon/webclient/grpc/GrpcClientCall.java +++ b/webclient/grpc/src/main/java/io/helidon/webclient/grpc/GrpcClientCall.java @@ -19,9 +19,11 @@ import java.io.InputStream; import java.time.Duration; import java.util.Collections; -import java.util.Queue; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import io.grpc.ClientCall; @@ -34,8 +36,10 @@ import io.helidon.http.HeaderNames; import io.helidon.http.HeaderValues; import io.helidon.http.WritableHeaders; +import io.helidon.http.http2.Http2FrameData; import io.helidon.http.http2.Http2Headers; import io.helidon.http.http2.Http2Settings; +import io.helidon.http.http2.Http2StreamState; import io.helidon.webclient.api.ClientConnection; import io.helidon.webclient.api.ClientUri; import io.helidon.webclient.api.ConnectionKey; @@ -47,32 +51,48 @@ import io.helidon.webclient.http2.Http2ClientConnection; import io.helidon.webclient.http2.Http2ClientImpl; import io.helidon.webclient.http2.Http2StreamConfig; +import io.helidon.webclient.http2.StreamTimeoutException; + +import static java.lang.System.Logger.Level.DEBUG; /** * A gRPC client call handler. The typical order of calls will be: * - * start request* sendMessage* halfClose - * - * TODO: memory synchronization across method calls + * start (request | sendMessage)* (halfClose | cancel) * * @param * @param */ class GrpcClientCall extends ClientCall { + private static final System.Logger LOGGER = System.getLogger(GrpcClientCall.class.getName()); + private static final Header GRPC_ACCEPT_ENCODING = HeaderValues.create(HeaderNames.ACCEPT_ENCODING, "gzip"); private static final Header GRPC_CONTENT_TYPE = HeaderValues.create(HeaderNames.CONTENT_TYPE, "application/grpc"); + private static final int WAIT_TIME_MILLIS = 100; + private static final Duration WAIT_TIME_MILLIS_DURATION = Duration.ofMillis(WAIT_TIME_MILLIS); + + private static final BufferData EMPTY_BUFFER_DATA = BufferData.empty(); + + private final ExecutorService executor; private final GrpcClientImpl grpcClient; private final GrpcClientMethodDescriptor method; - private final AtomicInteger messages = new AtomicInteger(); + private final AtomicInteger messageRequest = new AtomicInteger(); private final MethodDescriptor.Marshaller requestMarshaller; private final MethodDescriptor.Marshaller responseMarshaller; - private final Queue messageQueue = new LinkedBlockingQueue<>(); + + private final LinkedBlockingQueue sendingQueue = new LinkedBlockingQueue<>(); + private final LinkedBlockingQueue receivingQueue = new LinkedBlockingQueue<>(); + + private final CountDownLatch startReadBarrier = new CountDownLatch(1); + private final CountDownLatch startWriteBarrier = new CountDownLatch(1); private volatile Http2ClientConnection connection; private volatile GrpcClientStream clientStream; private volatile Listener responseListener; + private volatile Future readStreamFuture; + private volatile Future writeStreamFuture; @SuppressWarnings("unchecked") GrpcClientCall(GrpcClientImpl grpcClient, GrpcClientMethodDescriptor method) { @@ -80,6 +100,7 @@ class GrpcClientCall extends ClientCall { this.method = method; this.requestMarshaller = (MethodDescriptor.Marshaller) method.descriptor().getRequestMarshaller(); this.responseMarshaller = (MethodDescriptor.Marshaller) method.descriptor().getResponseMarshaller(); + this.executor = grpcClient.webClient().executor(); } @Override @@ -114,6 +135,9 @@ public Duration readTimeout() { null, // Http2ClientConfig connection.streamIdSequence()); + // start streaming threads + startStreamingThreads(); + // send HEADERS frame ClientUri clientUri = grpcClient.prototype().baseUri().orElseThrow(); WritableHeaders headers = WritableHeaders.create(); @@ -128,44 +152,20 @@ public Duration readTimeout() { @Override public void request(int numMessages) { - messages.addAndGet(numMessages); - - ExecutorService executor = grpcClient.webClient().executor(); - executor.submit(() -> { - clientStream.readHeaders(); - while (messages.decrementAndGet() > 0) { - BufferData bufferData = clientStream.read(); - bufferData.read(); // compression - bufferData.readUnsignedInt32(); // length prefixed - ResT res = responseMarshaller.parse(new InputStream() { - @Override - public int read() { - return bufferData.available() > 0 ? bufferData.read() : -1; - } - }); - responseListener.onMessage(res); - } - responseListener.onClose(Status.OK, new Metadata()); - clientStream.close(); - connection.close(); - }); + messageRequest.addAndGet(numMessages); + LOGGER.log(DEBUG, () -> "Messages requested " + numMessages); + startReadBarrier.countDown(); } @Override public void cancel(String message, Throwable cause) { - // close the stream/connection via RST_STREAM - messageQueue.clear(); - clientStream.cancel(); - connection.close(); + responseListener.onClose(Status.CANCELLED, new Metadata()); + close(); } @Override public void halfClose() { - // drain the message queue - while (!messageQueue.isEmpty()) { - BufferData msg = messageQueue.poll(); - clientStream.writeData(msg, messageQueue.isEmpty()); - } + sendingQueue.add(EMPTY_BUFFER_DATA); // end marker } @Override @@ -176,7 +176,84 @@ public void sendMessage(ReqT message) { BufferData headerData = BufferData.create(5); headerData.writeInt8(0); // no compression headerData.writeUnsignedInt32(messageData.available()); // length prefixed - messageQueue.add(BufferData.create(headerData, messageData)); + sendingQueue.add(BufferData.create(headerData, messageData)); + startWriteBarrier.countDown(); + } + + private void startStreamingThreads() { + // write streaming thread + writeStreamFuture = executor.submit(() -> { + try { + startWriteBarrier.await(); + LOGGER.log(DEBUG, "[Writing thread] started"); + + while (isRemoteOpen()) { + LOGGER.log(DEBUG, "[Writing thread] polling sending queue"); + BufferData bufferData = sendingQueue.poll(WAIT_TIME_MILLIS, TimeUnit.MILLISECONDS); + if (bufferData != null) { + if (bufferData == EMPTY_BUFFER_DATA) { // end marker + LOGGER.log(DEBUG, "[Writing thread] sending queue end marker found"); + break; + } + boolean endOfStream = (sendingQueue.peek() == EMPTY_BUFFER_DATA); + LOGGER.log(DEBUG, () -> "[Writing thread] writing bufferData " + endOfStream); + clientStream.writeData(bufferData, endOfStream); + } + } + } catch (InterruptedException e) { + // falls through + } + LOGGER.log(DEBUG, "[Writing thread] exiting"); + }); + + // read streaming thread + readStreamFuture = executor.submit(() -> { + try { + startReadBarrier.await(); + LOGGER.log(DEBUG, "[Reading thread] started"); + + // read response headers + clientStream.readHeaders(); + + while (isRemoteOpen()) { + // attempt to send queued messages + drainReceivingQueue(); + + // attempt to read and queue + Http2FrameData frameData; + try { + frameData = clientStream.readOne(WAIT_TIME_MILLIS_DURATION); + } catch (StreamTimeoutException e) { + LOGGER.log(DEBUG, "[Reading thread] read timeout"); + continue; + } + if (frameData != null) { + receivingQueue.add(frameData.data()); + LOGGER.log(DEBUG, "[Reading thread] adding bufferData to receiving queue"); + } + + // trailers received? + if (clientStream.trailers().isDone()) { + drainReceivingQueue(); // one more attempt + break; + } + } + + responseListener.onClose(Status.OK, new Metadata()); + close(); + } catch (InterruptedException e) { + // falls through + } + LOGGER.log(DEBUG, "[Reading thread] exiting"); + }); + } + + private void close() { + readStreamFuture.cancel(true); + writeStreamFuture.cancel(true); + sendingQueue.clear(); + clientStream.cancel(); + connection.close(); } private ClientConnection clientConnection() { @@ -202,4 +279,29 @@ private ClientConnection clientConnection() { connection -> { }).connect(); } + + private boolean isRemoteOpen() { + return clientStream.streamState() != Http2StreamState.HALF_CLOSED_REMOTE + && clientStream.streamState() != Http2StreamState.CLOSED; + } + + private ResT toResponse(BufferData bufferData) { + bufferData.read(); // compression + bufferData.readUnsignedInt32(); // length prefixed + return responseMarshaller.parse(new InputStream() { + @Override + public int read() { + return bufferData.available() > 0 ? bufferData.read() : -1; + } + }); + } + + private void drainReceivingQueue() { + while (messageRequest.get() > 0 && !receivingQueue.isEmpty()) { + messageRequest.getAndDecrement(); + ResT res = toResponse(receivingQueue.remove()); + LOGGER.log(DEBUG, "[Reading thread] sending response to listener"); + responseListener.onMessage(res); + } + } } diff --git a/webclient/grpc/src/main/java/io/helidon/webclient/grpc/GrpcServiceClient.java b/webclient/grpc/src/main/java/io/helidon/webclient/grpc/GrpcServiceClient.java index 0a578892e78..b5f54bdd5ed 100644 --- a/webclient/grpc/src/main/java/io/helidon/webclient/grpc/GrpcServiceClient.java +++ b/webclient/grpc/src/main/java/io/helidon/webclient/grpc/GrpcServiceClient.java @@ -16,7 +16,7 @@ package io.helidon.webclient.grpc; -import java.util.Collection; +import java.util.Iterator; import io.grpc.stub.StreamObserver; @@ -35,17 +35,17 @@ public interface GrpcServiceClient { RespT unary(String methodName, ReqT request); - StreamObserver unary(String methodName, StreamObserver responseObserver); + StreamObserver unary(String methodName, StreamObserver response); - Collection serverStream(String methodName, ReqT request); + Iterator serverStream(String methodName, ReqT request); - void serverStream(String methodName, ReqT request, StreamObserver responseObserver); + void serverStream(String methodName, ReqT request, StreamObserver response); - RespT clientStream(String methodName, Collection request); + RespT clientStream(String methodName, Iterator request); - StreamObserver clientStream(String methodName, StreamObserver responseObserver); + StreamObserver clientStream(String methodName, StreamObserver response); - Collection bidi(String methodName, Collection responseObserver); + Iterator bidi(String methodName, Iterator request); - StreamObserver bidi(String methodName, StreamObserver responseObserver); + StreamObserver bidi(String methodName, StreamObserver response); } diff --git a/webclient/grpc/src/main/java/io/helidon/webclient/grpc/GrpcServiceClientImpl.java b/webclient/grpc/src/main/java/io/helidon/webclient/grpc/GrpcServiceClientImpl.java index 9655c8755c2..34ac48f73bd 100644 --- a/webclient/grpc/src/main/java/io/helidon/webclient/grpc/GrpcServiceClientImpl.java +++ b/webclient/grpc/src/main/java/io/helidon/webclient/grpc/GrpcServiceClientImpl.java @@ -16,13 +16,15 @@ package io.helidon.webclient.grpc; -import java.util.Collection; - -import io.grpc.stub.StreamObserver; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.CompletableFuture; import io.grpc.ClientCall; import io.grpc.MethodDescriptor; import io.grpc.stub.ClientCalls; +import io.grpc.stub.StreamObserver; class GrpcServiceClientImpl implements GrpcServiceClient { private final GrpcServiceDescriptor descriptor; @@ -40,53 +42,115 @@ public String serviceName() { @Override public RespT unary(String methodName, ReqT request) { - ClientCall call = ensureMethod(methodName, MethodDescriptor.MethodType.UNARY); + ClientCall call = ensureMethod(methodName, MethodDescriptor.MethodType.UNARY); return ClientCalls.blockingUnaryCall(call, request); } @Override - public StreamObserver unary(String methodName, StreamObserver responseObserver) { + public StreamObserver unary(String methodName, StreamObserver response) { return null; } @Override - public Collection serverStream(String methodName, ReqT request) { - return null; + public Iterator serverStream(String methodName, ReqT request) { + ClientCall call = ensureMethod(methodName, MethodDescriptor.MethodType.SERVER_STREAMING); + return ClientCalls.blockingServerStreamingCall(call, request); } @Override - public void serverStream(String methodName, ReqT request, StreamObserver responseObserver) { + public void serverStream(String methodName, ReqT request, StreamObserver response) { } @Override - public RespT clientStream(String methodName, Collection request) { - return null; + public RespT clientStream(String methodName, Iterator request) { + ClientCall call = ensureMethod(methodName, MethodDescriptor.MethodType.CLIENT_STREAMING); + CompletableFuture future = new CompletableFuture<>(); + StreamObserver observer = ClientCalls.asyncClientStreamingCall(call, new StreamObserver<>() { + private RespT value; + + @Override + public void onNext(RespT value) { + this.value = value; + } + + @Override + public void onError(Throwable t) { + future.completeExceptionally(t); + } + + @Override + public void onCompleted() { + future.complete(value); + } + }); + + // send client stream + while (request.hasNext()) { + observer.onNext(request.next()); + } + observer.onCompleted(); + + // block waiting for response + try { + return future.get(); + } catch (Exception e) { + throw new RuntimeException(e); + } } @Override - public StreamObserver clientStream(String methodName, StreamObserver responseObserver) { + public StreamObserver clientStream(String methodName, StreamObserver response) { return null; } @Override - public Collection bidi(String methodName, Collection responseObserver) { - return null; + public Iterator bidi(String methodName, Iterator request) { + ClientCall call = ensureMethod(methodName, MethodDescriptor.MethodType.BIDI_STREAMING); + CompletableFuture> future = new CompletableFuture<>(); + StreamObserver observer = ClientCalls.asyncBidiStreamingCall(call, new StreamObserver<>() { + private final List values = new ArrayList<>(); + + @Override + public void onNext(RespT value) { + values.add(value); + } + + @Override + public void onError(Throwable t) { + future.completeExceptionally(t); + } + + @Override + public void onCompleted() { + future.complete(values.iterator()); + } + }); + + // send client stream + while (request.hasNext()) { + observer.onNext(request.next()); + } + observer.onCompleted(); + + // block waiting for response + try { + return future.get(); + } catch (Exception e) { + throw new RuntimeException(e); + } } @Override - public StreamObserver bidi(String methodName, StreamObserver responseObserver) { + public StreamObserver bidi(String methodName, StreamObserver response) { return null; } private ClientCall ensureMethod(String methodName, MethodDescriptor.MethodType methodType) { GrpcClientMethodDescriptor method = descriptor.method(methodName); if (!method.type().equals(methodType)) { - throw new IllegalArgumentException("Method " + methodName + " is of type " + method.type() + ", yet " + methodType + " was requested."); + throw new IllegalArgumentException("Method " + methodName + " is of type " + method.type() + + ", yet " + methodType + " was requested."); } - return createClientCall(method); - } - - private ClientCall createClientCall(GrpcClientMethodDescriptor method) { return new GrpcClientCall<>(grpcClient, method); } } diff --git a/webclient/http2/src/main/java/io/helidon/webclient/http2/Http2ClientStream.java b/webclient/http2/src/main/java/io/helidon/webclient/http2/Http2ClientStream.java index 9834b5ee8f8..9e38315ed86 100644 --- a/webclient/http2/src/main/java/io/helidon/webclient/http2/Http2ClientStream.java +++ b/webclient/http2/src/main/java/io/helidon/webclient/http2/Http2ClientStream.java @@ -182,7 +182,7 @@ void trailers(Http2Headers headers, boolean endOfStream) { trailers.complete(headers.httpHeaders()); } - CompletableFuture trailers() { + public CompletableFuture trailers() { return trailers; } @@ -319,7 +319,7 @@ public SocketContext ctx() { return ctx; } - private Http2FrameData readOne(Duration pollTimeout) { + public Http2FrameData readOne(Duration pollTimeout) { Http2FrameData frameData = buffer.poll(pollTimeout); if (frameData != null) {