Skip to content

Commit

Permalink
Basic support for unary and serverStream invocations.
Browse files Browse the repository at this point in the history
Signed-off-by: Santiago Pericasgeertsen <[email protected]>
  • Loading branch information
spericas committed Feb 9, 2024
1 parent a22f8fe commit c892010
Show file tree
Hide file tree
Showing 5 changed files with 200 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@

package io.helidon.examples.webserver.protocols;

import java.util.Iterator;
import java.util.Locale;

import com.google.protobuf.StringValue;
import io.grpc.stub.StreamObserver;
import io.helidon.examples.grpc.strings.Strings;
import io.helidon.webclient.grpc.GrpcClient;
import io.helidon.webclient.grpc.GrpcClientMethodDescriptor;
Expand All @@ -27,7 +29,7 @@
import io.helidon.webserver.grpc.GrpcRouting;
import io.helidon.webserver.testing.junit5.ServerTest;
import io.helidon.webserver.testing.junit5.SetUpServer;
import org.junit.jupiter.api.RepeatedTest;
import org.junit.jupiter.api.Test;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
Expand All @@ -44,20 +46,36 @@ private GrpcTest(GrpcClient grpcClient) {
@SetUpServer
public static void setup(WebServerConfig.Builder builder) {
builder.addRouting(GrpcRouting.builder()
.unary(Strings.getDescriptor(),
"StringService",
"Upper",
GrpcTest::blockingGrpcUpper));
.unary(Strings.getDescriptor(),
"StringService",
"Upper",
GrpcTest::upper)
.serverStream(Strings.getDescriptor(),
"StringService",
"Split",
GrpcTest::split));
}

private static Strings.StringMessage blockingGrpcUpper(Strings.StringMessage reqT) {
private static Strings.StringMessage upper(Strings.StringMessage reqT) {
return Strings.StringMessage.newBuilder()
.setText(reqT.getText().toUpperCase(Locale.ROOT))
.build();
}

@RepeatedTest(3)
void testSimpleCall() {
private static void split(Strings.StringMessage reqT,
StreamObserver<Strings.StringMessage> streamObserver) {
String[] strings = reqT.getText().split(" ");
for (String string : strings) {
streamObserver.onNext(Strings.StringMessage.newBuilder()
.setText(string)
.build());

}
streamObserver.onCompleted();
}

@Test
void testUpper() {
GrpcServiceDescriptor serviceDescriptor =
GrpcServiceDescriptor.builder()
.serviceName("StringService")
Expand All @@ -70,7 +88,25 @@ void testSimpleCall() {

StringValue r = grpcClient.serviceClient(serviceDescriptor)
.unary("Upper", StringValue.of("hello"));
System.out.println("r = " + r.getValue());
assertThat(r.getValue(), is("HELLO"));
}

@Test
void testSplit() {
GrpcServiceDescriptor serviceDescriptor =
GrpcServiceDescriptor.builder()
.serviceName("StringService")
.putMethod("Split",
GrpcClientMethodDescriptor.unary("StringService", "Split")
.requestType(StringValue.class)
.responseType(StringValue.class)
.build())
.build();

Iterator<StringValue> iterator = grpcClient.serviceClient(serviceDescriptor)
.serverStream("Split", StringValue.of("hello world"));
assertThat(iterator.next().getValue(), is("hello"));
assertThat(iterator.next().getValue(), is("world"));
assertThat(iterator.hasNext(), is(false));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@
import java.io.InputStream;
import java.time.Duration;
import java.util.Collections;
import java.util.Queue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import io.grpc.ClientCall;
Expand All @@ -34,8 +36,10 @@
import io.helidon.http.HeaderNames;
import io.helidon.http.HeaderValues;
import io.helidon.http.WritableHeaders;
import io.helidon.http.http2.Http2FrameData;
import io.helidon.http.http2.Http2Headers;
import io.helidon.http.http2.Http2Settings;
import io.helidon.http.http2.Http2StreamState;
import io.helidon.webclient.api.ClientConnection;
import io.helidon.webclient.api.ClientUri;
import io.helidon.webclient.api.ConnectionKey;
Expand All @@ -47,39 +51,56 @@
import io.helidon.webclient.http2.Http2ClientConnection;
import io.helidon.webclient.http2.Http2ClientImpl;
import io.helidon.webclient.http2.Http2StreamConfig;
import io.helidon.webclient.http2.StreamTimeoutException;

import static java.lang.System.Logger.Level.DEBUG;

/**
* A gRPC client call handler. The typical order of calls will be:
*
* start request* sendMessage* halfClose
*
* TODO: memory synchronization across method calls
* start (request | sendMessage)* (halfClose | cancel)
*
* @param <ReqT>
* @param <ResT>
*/
class GrpcClientCall<ReqT, ResT> extends ClientCall<ReqT, ResT> {
private static final System.Logger LOGGER = System.getLogger(GrpcClientCall.class.getName());

private static final Header GRPC_ACCEPT_ENCODING = HeaderValues.create(HeaderNames.ACCEPT_ENCODING, "gzip");
private static final Header GRPC_CONTENT_TYPE = HeaderValues.create(HeaderNames.CONTENT_TYPE, "application/grpc");

private static final int WAIT_TIME_MILLIS = 100;
private static final Duration WAIT_TIME_MILLIS_DURATION = Duration.ofMillis(WAIT_TIME_MILLIS);

private static final BufferData EMPTY_BUFFER_DATA = BufferData.empty();

private final ExecutorService executor;
private final GrpcClientImpl grpcClient;
private final GrpcClientMethodDescriptor method;
private final AtomicInteger messages = new AtomicInteger();
private final AtomicInteger messageRequest = new AtomicInteger();

private final MethodDescriptor.Marshaller<ReqT> requestMarshaller;
private final MethodDescriptor.Marshaller<ResT> responseMarshaller;
private final Queue<BufferData> messageQueue = new LinkedBlockingQueue<>();

private final LinkedBlockingQueue<BufferData> sendingQueue = new LinkedBlockingQueue<>();
private final LinkedBlockingQueue<BufferData> receivingQueue = new LinkedBlockingQueue<>();

private final CountDownLatch startReadBarrier = new CountDownLatch(1);
private final CountDownLatch startWriteBarrier = new CountDownLatch(1);

private volatile Http2ClientConnection connection;
private volatile GrpcClientStream clientStream;
private volatile Listener<ResT> responseListener;
private volatile Future<?> readStreamFuture;
private volatile Future<?> writeStreamFuture;

@SuppressWarnings("unchecked")
GrpcClientCall(GrpcClientImpl grpcClient, GrpcClientMethodDescriptor method) {
this.grpcClient = grpcClient;
this.method = method;
this.requestMarshaller = (MethodDescriptor.Marshaller<ReqT>) method.descriptor().getRequestMarshaller();
this.responseMarshaller = (MethodDescriptor.Marshaller<ResT>) method.descriptor().getResponseMarshaller();
this.executor = grpcClient.webClient().executor();
}

@Override
Expand Down Expand Up @@ -114,6 +135,9 @@ public Duration readTimeout() {
null, // Http2ClientConfig
connection.streamIdSequence());

// start streaming threads
startStreamingThreads();

// send HEADERS frame
ClientUri clientUri = grpcClient.prototype().baseUri().orElseThrow();
WritableHeaders<?> headers = WritableHeaders.create();
Expand All @@ -128,44 +152,20 @@ public Duration readTimeout() {

@Override
public void request(int numMessages) {
messages.addAndGet(numMessages);

ExecutorService executor = grpcClient.webClient().executor();
executor.submit(() -> {
clientStream.readHeaders();
while (messages.decrementAndGet() > 0) {
BufferData bufferData = clientStream.read();
bufferData.read(); // compression
bufferData.readUnsignedInt32(); // length prefixed
ResT res = responseMarshaller.parse(new InputStream() {
@Override
public int read() {
return bufferData.available() > 0 ? bufferData.read() : -1;
}
});
responseListener.onMessage(res);
}
responseListener.onClose(Status.OK, new Metadata());
clientStream.close();
connection.close();
});
messageRequest.addAndGet(numMessages);
LOGGER.log(DEBUG, () -> "Messages requested " + numMessages);
startReadBarrier.countDown();
}

@Override
public void cancel(String message, Throwable cause) {
// close the stream/connection via RST_STREAM
messageQueue.clear();
clientStream.cancel();
connection.close();
responseListener.onClose(Status.CANCELLED, new Metadata());
close();
}

@Override
public void halfClose() {
// drain the message queue
while (!messageQueue.isEmpty()) {
BufferData msg = messageQueue.poll();
clientStream.writeData(msg, messageQueue.isEmpty());
}
sendingQueue.add(EMPTY_BUFFER_DATA); // end marker
}

@Override
Expand All @@ -176,7 +176,84 @@ public void sendMessage(ReqT message) {
BufferData headerData = BufferData.create(5);
headerData.writeInt8(0); // no compression
headerData.writeUnsignedInt32(messageData.available()); // length prefixed
messageQueue.add(BufferData.create(headerData, messageData));
sendingQueue.add(BufferData.create(headerData, messageData));
startWriteBarrier.countDown();
}

private void startStreamingThreads() {
// write streaming thread
writeStreamFuture = executor.submit(() -> {
try {
startWriteBarrier.await();
LOGGER.log(DEBUG, "[Writing thread] started");

while (isRemoteOpen()) {
LOGGER.log(DEBUG, "[Writing thread] polling sending queue");
BufferData bufferData = sendingQueue.poll(WAIT_TIME_MILLIS, TimeUnit.MILLISECONDS);
if (bufferData != null) {
if (bufferData == EMPTY_BUFFER_DATA) { // end marker
LOGGER.log(DEBUG, "[Writing thread] sending queue end marker found");
break;
}
boolean endOfStream = (sendingQueue.peek() == EMPTY_BUFFER_DATA);
LOGGER.log(DEBUG, () -> "[Writing thread] writing bufferData " + endOfStream);
clientStream.writeData(bufferData, endOfStream);
}
}
} catch (InterruptedException e) {
// falls through
}
LOGGER.log(DEBUG, "[Writing thread] exiting");
});

// read streaming thread
readStreamFuture = executor.submit(() -> {
try {
startReadBarrier.await();
LOGGER.log(DEBUG, "[Reading thread] started");

// read response headers
clientStream.readHeaders();

while (isRemoteOpen()) {
// attempt to send queued messages
drainReceivingQueue();

// attempt to read and queue
Http2FrameData frameData;
try {
frameData = clientStream.readOne(WAIT_TIME_MILLIS_DURATION);
} catch (StreamTimeoutException e) {
LOGGER.log(DEBUG, "[Reading thread] read timeout");
continue;
}
if (frameData != null) {
receivingQueue.add(frameData.data());
LOGGER.log(DEBUG, "[Reading thread] adding bufferData to receiving queue");
}

// trailers received?
if (clientStream.trailers().isDone()) {
drainReceivingQueue(); // one more attempt
break;
}
}

responseListener.onClose(Status.OK, new Metadata());
close();
} catch (InterruptedException e) {
// falls through
}
LOGGER.log(DEBUG, "[Reading thread] exiting");
});
}

private void close() {
readStreamFuture.cancel(true);
writeStreamFuture.cancel(true);
sendingQueue.clear();
clientStream.cancel();
connection.close();
}

private ClientConnection clientConnection() {
Expand All @@ -202,4 +279,29 @@ private ClientConnection clientConnection() {
connection -> {
}).connect();
}

private boolean isRemoteOpen() {
return clientStream.streamState() != Http2StreamState.HALF_CLOSED_REMOTE
&& clientStream.streamState() != Http2StreamState.CLOSED;
}

private ResT toResponse(BufferData bufferData) {
bufferData.read(); // compression
bufferData.readUnsignedInt32(); // length prefixed
return responseMarshaller.parse(new InputStream() {
@Override
public int read() {
return bufferData.available() > 0 ? bufferData.read() : -1;
}
});
}

private void drainReceivingQueue() {
while (messageRequest.get() > 0 && !receivingQueue.isEmpty()) {
messageRequest.getAndDecrement();
ResT res = toResponse(receivingQueue.remove());
LOGGER.log(DEBUG, "[Reading thread] sending response to listener");
responseListener.onMessage(res);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

package io.helidon.webclient.grpc;

import java.util.Collection;
import java.util.Iterator;

import io.grpc.stub.StreamObserver;

Expand All @@ -37,15 +37,15 @@ public interface GrpcServiceClient {

<ReqT, RespT> StreamObserver<ReqT> unary(String methodName, StreamObserver<RespT> responseObserver);

<ReqT, RespT> Collection<RespT> serverStream(String methodName, ReqT request);
<ReqT, RespT> Iterator<RespT> serverStream(String methodName, ReqT request);

<ReqT, RespT> void serverStream(String methodName, ReqT request, StreamObserver<RespT> responseObserver);

<ReqT, RespT> RespT clientStream(String methodName, Collection<ReqT> request);
<ReqT, RespT> RespT clientStream(String methodName, Iterator<ReqT> request);

<ReqT, RespT> StreamObserver<ReqT> clientStream(String methodName, StreamObserver<RespT> responseObserver);

<ReqT, RespT> Collection<RespT> bidi(String methodName, Collection<ReqT> responseObserver);
<ReqT, RespT> Iterator<RespT> bidi(String methodName, Iterator<ReqT> responseObserver);

<ReqT, RespT> StreamObserver<ReqT> bidi(String methodName, StreamObserver<RespT> responseObserver);
}
Loading

0 comments on commit c892010

Please sign in to comment.