Skip to content

Commit

Permalink
Simple unary method call without using stubs.
Browse files Browse the repository at this point in the history
Signed-off-by: Santiago Pericasgeertsen <[email protected]>
  • Loading branch information
spericas committed Feb 7, 2024
1 parent 286812a commit a22f8fe
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@
import io.helidon.webserver.grpc.GrpcRouting;
import io.helidon.webserver.testing.junit5.ServerTest;
import io.helidon.webserver.testing.junit5.SetUpServer;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.RepeatedTest;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;

@ServerTest
class GrpcTest {
Expand All @@ -53,7 +56,7 @@ private static Strings.StringMessage blockingGrpcUpper(Strings.StringMessage req
.build();
}

@Test
@RepeatedTest(3)
void testSimpleCall() {
GrpcServiceDescriptor serviceDescriptor =
GrpcServiceDescriptor.builder()
Expand All @@ -65,7 +68,9 @@ void testSimpleCall() {
.build())
.build();

String r = grpcClient.serviceClient(serviceDescriptor)
StringValue r = grpcClient.serviceClient(serviceDescriptor)
.unary("Upper", StringValue.of("hello"));
System.out.println("r = " + r.getValue());
assertThat(r.getValue(), is("HELLO"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,32 @@

package io.helidon.webclient.grpc;

import java.io.InputStream;
import java.time.Duration;
import java.util.Collections;
import java.util.Queue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import io.grpc.ClientCall;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.helidon.common.buffers.BufferData;
import io.helidon.common.tls.Tls;
import io.helidon.http.Header;
import io.helidon.http.HeaderNames;
import io.helidon.http.HeaderValues;
import io.helidon.http.WritableHeaders;
import io.helidon.http.http2.Http2Headers;
import io.helidon.http.http2.Http2Settings;
import io.helidon.webclient.api.ClientConnection;
import io.helidon.webclient.api.ClientUri;
import io.helidon.webclient.api.ConnectionKey;
import io.helidon.webclient.api.DefaultDnsResolver;
import io.helidon.webclient.api.DnsAddressLookup;
import io.helidon.webclient.api.Proxy;
import io.helidon.webclient.api.TcpClientConnection;
import io.helidon.webclient.api.WebClient;
import io.helidon.webclient.http2.Http2ClientConnection;
Expand All @@ -41,102 +51,155 @@
/**
* A gRPC client call handler. The typical order of calls will be:
*
* start request* sendMessage* halfClose
* start request* sendMessage* halfClose
*
* TODO: memory synchronization across method calls
*
* @param <ReqT>
* @param <ResT>
*/
class GrpcClientCall<ReqT, ResT> extends ClientCall<ReqT, ResT> {
private static final Header GRPC_ACCEPT_ENCODING = HeaderValues.create(HeaderNames.ACCEPT_ENCODING, "gzip");
private static final Header GRPC_CONTENT_TYPE = HeaderValues.create(HeaderNames.CONTENT_TYPE, "application/grpc");

private final AtomicReference<Listener<ResT>> responseListener = new AtomicReference<>();
private final GrpcClientImpl grpcClient;
private final GrpcClientMethodDescriptor method;
private final AtomicInteger messages = new AtomicInteger();

private final MethodDescriptor.Marshaller<ReqT> requestMarshaller;
private final MethodDescriptor.Marshaller<ResT> responseMarshaller;
private final Queue<BufferData> messageQueue = new LinkedBlockingQueue<>();

private volatile Http2ClientConnection connection;
private volatile GrpcClientStream clientStream;
private volatile Listener<ResT> responseListener;

@SuppressWarnings("unchecked")
GrpcClientCall(GrpcClientImpl grpcClient, GrpcClientMethodDescriptor method) {
this.grpcClient = grpcClient;
this.method = method;
this.requestMarshaller = (MethodDescriptor.Marshaller<ReqT>) method.descriptor().getRequestMarshaller();
this.responseMarshaller = (MethodDescriptor.Marshaller<ResT>) method.descriptor().getResponseMarshaller();
}

@Override
public void start(Listener<ResT> responseListener, Metadata headers) {
if (this.responseListener.compareAndSet(null, responseListener)) {
// obtain HTTP2 connection
Http2ClientConnection connection = Http2ClientConnection.create(
(Http2ClientImpl) grpcClient.http2Client(), clientConnection(), true);

// create HTTP2 stream from connection
GrpcClientStream clientStream = new GrpcClientStream(
connection,
Http2Settings.create(), // Http2Settings
null, // SocketContext
new Http2StreamConfig() {
@Override
public boolean priorKnowledge() {
return true;
}

@Override
public int priority() {
return 0;
}

@Override
public Duration readTimeout() {
return grpcClient.prototype().readTimeout().orElse(null);
}
},
null, // Http2ClientConfig
connection.streamIdSequence());

// send HEADERS frame
} else {
throw new IllegalStateException("Response listener was already set");
}
public void start(Listener<ResT> responseListener, Metadata metadata) {
this.responseListener = responseListener;

// obtain HTTP2 connection
connection = Http2ClientConnection.create((Http2ClientImpl) grpcClient.http2Client(),
clientConnection(), true);

// create HTTP2 stream from connection
clientStream = new GrpcClientStream(
connection,
Http2Settings.create(), // Http2Settings
null, // SocketContext
new Http2StreamConfig() {
@Override
public boolean priorKnowledge() {
return true;
}

@Override
public int priority() {
return 0;
}

@Override
public Duration readTimeout() {
return grpcClient.prototype().readTimeout().orElse(Duration.ofSeconds(60));
}
},
null, // Http2ClientConfig
connection.streamIdSequence());

// send HEADERS frame
ClientUri clientUri = grpcClient.prototype().baseUri().orElseThrow();
WritableHeaders<?> headers = WritableHeaders.create();
headers.add(Http2Headers.AUTHORITY_NAME, clientUri.authority());
headers.add(Http2Headers.METHOD_NAME, "POST");
headers.add(Http2Headers.PATH_NAME, "/" + method.descriptor().getFullMethodName());
headers.add(Http2Headers.SCHEME_NAME, "http");
headers.add(GRPC_CONTENT_TYPE);
headers.add(GRPC_ACCEPT_ENCODING);
clientStream.writeHeaders(Http2Headers.create(headers), false);
}

@Override
public void request(int numMessages) {
messages.addAndGet(numMessages);

ExecutorService executor = grpcClient.webClient().executor();
executor.submit(() -> {
clientStream.readHeaders();
while (messages.decrementAndGet() > 0) {
BufferData bufferData = clientStream.read();
bufferData.read(); // compression
bufferData.readUnsignedInt32(); // length prefixed
ResT res = responseMarshaller.parse(new InputStream() {
@Override
public int read() {
return bufferData.available() > 0 ? bufferData.read() : -1;
}
});
responseListener.onMessage(res);
}
responseListener.onClose(Status.OK, new Metadata());
clientStream.close();
connection.close();
});
}

@Override
public void cancel(String message, Throwable cause) {
// close the stream/connection via RST_STREAM
// can be closed even if halfClosed
messageQueue.clear();
clientStream.cancel();
connection.close();
}

@Override
public void halfClose() {
// close the stream/connection
// GOAWAY frame
// drain the message queue
while (!messageQueue.isEmpty()) {
BufferData msg = messageQueue.poll();
clientStream.writeData(msg, messageQueue.isEmpty());
}
}

@Override
public void sendMessage(ReqT message) {
// send a DATA frame
// queue a message
BufferData messageData = BufferData.growing(512);
messageData.readFrom(requestMarshaller.stream(message));
BufferData headerData = BufferData.create(5);
headerData.writeInt8(0); // no compression
headerData.writeUnsignedInt32(messageData.available()); // length prefixed
messageQueue.add(BufferData.create(headerData, messageData));
}

private ClientConnection clientConnection() {
GrpcClientConfig clientConfig = grpcClient.prototype();
ClientUri clientUri = clientConfig.baseUri().orElseThrow();
WebClient webClient = grpcClient.webClient();

Tls tls = Tls.builder().enabled(false).build();
ConnectionKey connectionKey = new ConnectionKey(
clientUri.scheme(),
clientUri.host(),
clientUri.port(),
clientConfig.readTimeout().orElse(null),
null,
clientConfig.readTimeout().orElse(Duration.ZERO),
tls,
DefaultDnsResolver.create(),
DnsAddressLookup.defaultLookup(),
null);
Proxy.noProxy());

return TcpClientConnection.create(webClient,
connectionKey,
Collections.emptyList(),
connection -> false,
connection -> {}).connect();
connection -> {
}).connect();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@

package io.helidon.webclient.grpc;

import io.helidon.common.buffers.BufferData;
import io.helidon.common.socket.SocketContext;
import io.helidon.http.http2.Http2FrameHeader;
import io.helidon.http.http2.Http2Headers;
import io.helidon.http.http2.Http2Settings;
import io.helidon.webclient.http2.Http2ClientConfig;
import io.helidon.webclient.http2.Http2ClientConnection;
Expand All @@ -34,4 +37,24 @@ class GrpcClientStream extends Http2ClientStream {
LockingStreamIdSequence streamIdSeq) {
super(connection, serverSettings, ctx, http2StreamConfig, http2ClientConfig, streamIdSeq);
}

@Override
public void headers(Http2Headers headers, boolean endOfStream) {
super.headers(headers, endOfStream);
}

@Override
public void data(Http2FrameHeader header, BufferData data, boolean endOfStream) {
super.data(header, data, endOfStream);
}

@Override
public void cancel() {
super.cancel();
}

@Override
public void close() {
super.close();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ void updateLastStreamId(int lastStreamId) {
this.lastStreamId = lastStreamId;
}

void close() {
public void close() {
this.goAway(0, Http2ErrorCode.NO_ERROR, "Closing connection");
if (state.getAndSet(State.CLOSED) != State.CLOSED) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ boolean hasEntity() {
return hasEntity;
}

void cancel() {
public void cancel() {
if (NON_CANCELABLE.contains(state)) {
return;
}
Expand All @@ -206,7 +206,7 @@ void cancel() {
}
}

void close() {
public void close() {
connection.removeStream(streamId);
}

Expand All @@ -227,7 +227,7 @@ BufferData read(int i) {
return read();
}

BufferData read() {
public BufferData read() {
while (state == Http2StreamState.HALF_CLOSED_LOCAL && readState != ReadState.END && hasEntity) {
Http2FrameData frameData = readOne(timeout);
if (frameData != null) {
Expand Down Expand Up @@ -258,7 +258,7 @@ Status waitFor100Continue() {
return null;
}

void writeHeaders(Http2Headers http2Headers, boolean endOfStream) {
public void writeHeaders(Http2Headers http2Headers, boolean endOfStream) {
this.state = Http2StreamState.checkAndGetState(this.state, Http2FrameType.HEADERS, true, endOfStream, true);
this.readState = readState.check(http2Headers.httpHeaders().contains(HeaderValues.EXPECT_100)
? ReadState.CONTINUE_100_HEADERS
Expand Down Expand Up @@ -294,7 +294,7 @@ void writeHeaders(Http2Headers http2Headers, boolean endOfStream) {
}
}

void writeData(BufferData entityBytes, boolean endOfStream) {
public void writeData(BufferData entityBytes, boolean endOfStream) {
Http2FrameHeader frameHeader = Http2FrameHeader.create(entityBytes.available(),
Http2FrameTypes.DATA,
Http2Flag.DataFlags.create(endOfStream
Expand All @@ -305,7 +305,7 @@ void writeData(BufferData entityBytes, boolean endOfStream) {
splitAndWrite(frameData);
}

Http2Headers readHeaders() {
public Http2Headers readHeaders() {
while (readState == ReadState.HEADERS) {
Http2FrameData frameData = readOne(timeout);
if (frameData != null) {
Expand All @@ -315,7 +315,7 @@ Http2Headers readHeaders() {
return currentHeaders;
}

SocketContext ctx() {
public SocketContext ctx() {
return ctx;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public RouterImpl build() {
public Router.Builder addRouting(io.helidon.common.Builder<?, ? extends Routing> routing) {
var previous = this.routings.put(routing.getClass(), routing);
if (previous != null) {
Thread.dumpStack();
// Thread.dumpStack();
LOGGER.log(System.Logger.Level.WARNING, "Second routing of the same type is registered. "
+ "The first instance will be ignored. Type: " + routing.getClass().getName());
}
Expand Down

0 comments on commit a22f8fe

Please sign in to comment.