diff --git a/pkg/remote/trans/nphttp2/client_conn.go b/pkg/remote/trans/nphttp2/client_conn.go index e855d30bfa..28c129c325 100644 --- a/pkg/remote/trans/nphttp2/client_conn.go +++ b/pkg/remote/trans/nphttp2/client_conn.go @@ -27,6 +27,7 @@ import ( "github.com/bytedance/gopkg/lang/dirtmake" + "github.com/cloudwego/kitex/pkg/gofunc" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc" @@ -102,6 +103,31 @@ func newClientConn(ctx context.Context, tr grpc.ClientTransport, addr string) (* if err != nil { return nil, err } + // gRPC unary do not need to monitor the stream ctx and transport ctx + // since it must invoke stream.Recv which would inspect the stream.ctx + if isStreaming { + gofunc.GoFunc(ctx, func() { + sCtx := s.Context() + select { + // For these scenarios, stream.ctx would be canceled: + // 1. user invoke cancel() + // 2. parent stream is done + case <-sCtx.Done(): + tr.CloseStream(s, sCtx.Err()) + return + // when http2Client.closeStream is called, stream.Done() would be closed. + // Important: http2Client.closeStream would not lead to stream.ctx canceled. + case <-s.Done(): + // since stream is closed, we just exit without doing anything + return + // For now, t.ctx would not be canceled. + // Pls check pkg/remote/trans/nphttp2/conn_pool for details. + case <-tr.Error(): + tr.CloseStream(s, grpc.ErrConnClosing) + return + } + }) + } return &clientConn{ tr: tr, s: s, diff --git a/pkg/remote/trans/nphttp2/client_conn_test.go b/pkg/remote/trans/nphttp2/client_conn_test.go index be9a7d4560..0a475049ab 100644 --- a/pkg/remote/trans/nphttp2/client_conn_test.go +++ b/pkg/remote/trans/nphttp2/client_conn_test.go @@ -17,7 +17,10 @@ package nphttp2 import ( + "context" + "errors" "testing" + "time" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/rpcinfo" @@ -96,3 +99,46 @@ func Test_fullMethodName(t *testing.T) { test.Assert(t, got == "/pkg.svc/method") }) } + +func Test_streamingCancel(t *testing.T) { + t.Run("unary method", func(t *testing.T) { + // unary method + pool := newMockConnPool() + ctx, cancel := context.WithCancel(context.Background()) + ctx = rpcinfo.NewCtxWithRPCInfo(ctx, newMockRPCInfo(false)) + conn, err := pool.Get(ctx, "tcp", mockAddr0, newMockConnOption()) + test.Assert(t, err == nil, err) + cliConn, ok := conn.(*clientConn) + test.Assert(t, ok) + cancel() + st := cliConn.s + test.Assert(t, errors.Is(st.Context().Err(), context.Canceled)) + // Wait a while in case the monitor goroutine is actually created. + time.Sleep(10 * time.Millisecond) + select { + case <-st.Done(): + t.Fatal("stream.Done() should not be closed") + default: + } + }) + + t.Run("non-unary method", func(t *testing.T) { + // non-unary method + pool := newMockConnPool() + ctx, cancel := context.WithCancel(context.Background()) + ctx = rpcinfo.NewCtxWithRPCInfo(ctx, newMockRPCInfo(true)) + conn, err := pool.Get(ctx, "tcp", mockAddr1, newMockConnOption()) + test.Assert(t, err == nil, err) + cliConn, ok := conn.(*clientConn) + test.Assert(t, ok) + cancel() + st := cliConn.s + test.Assert(t, errors.Is(st.Context().Err(), context.Canceled)) + timer := time.NewTimer(10 * time.Second) + select { + case <-st.Done(): + case <-timer.C: + t.Fatal("stream.Done() should be closed quickly") + } + }) +} diff --git a/pkg/remote/trans/nphttp2/client_handler_test.go b/pkg/remote/trans/nphttp2/client_handler_test.go index f30f9e8efc..8fd16407ac 100644 --- a/pkg/remote/trans/nphttp2/client_handler_test.go +++ b/pkg/remote/trans/nphttp2/client_handler_test.go @@ -40,7 +40,7 @@ func TestClientHandler(t *testing.T) { return remote.NewProtocolInfo(transport.PurePayload, serviceinfo.Protobuf) } msg.RPCInfoFunc = func() rpcinfo.RPCInfo { - return newMockRPCInfo() + return newMockRPCInfo(false) } conn, err := opt.ConnPool.Get(ctx, "tcp", mockAddr0, remote.ConnOption{Dialer: opt.Dialer, ConnectTimeout: time.Second}) test.Assert(t, err == nil, err) diff --git a/pkg/remote/trans/nphttp2/mocks_test.go b/pkg/remote/trans/nphttp2/mocks_test.go index f77195dcd2..7a15397cec 100644 --- a/pkg/remote/trans/nphttp2/mocks_test.go +++ b/pkg/remote/trans/nphttp2/mocks_test.go @@ -308,7 +308,7 @@ func newMockServerOption() *remote.ServerOption { MaxConnectionIdleTime: 0, ReadWriteTimeout: 0, InitOrResetRPCInfoFunc: func(ri rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo { - return newMockRPCInfo() + return newMockRPCInfo(false) }, TracerCtl: &rpcinfo.TraceController{}, GRPCCfg: &grpc.ServerConfig{ @@ -361,10 +361,10 @@ func newMockDialerWithDialFunc(dialFunc func(network, address string, timeout ti } func newMockCtxWithRPCInfo() context.Context { - return rpcinfo.NewCtxWithRPCInfo(context.Background(), newMockRPCInfo()) + return rpcinfo.NewCtxWithRPCInfo(context.Background(), newMockRPCInfo(false)) } -func newMockRPCInfo() rpcinfo.RPCInfo { +func newMockRPCInfo(isStreaming bool) rpcinfo.RPCInfo { method := "method" c := rpcinfo.NewEndpointInfo("", method, nil, nil) endpointTags := map[string]string{} @@ -372,6 +372,9 @@ func newMockRPCInfo() rpcinfo.RPCInfo { s := rpcinfo.NewEndpointInfo("", method, nil, endpointTags) ink := rpcinfo.NewInvocation("", method) cfg := rpcinfo.NewRPCConfig() + if isStreaming { + cfg.(rpcinfo.MutableRPCConfig).SetInteractionMode(rpcinfo.Streaming) + } ri := rpcinfo.NewRPCInfo(c, s, ink, cfg, rpcinfo.NewRPCStats()) return ri } diff --git a/pkg/remote/trans/nphttp2/server_conn_test.go b/pkg/remote/trans/nphttp2/server_conn_test.go index 1b0d11560f..1163098a0e 100644 --- a/pkg/remote/trans/nphttp2/server_conn_test.go +++ b/pkg/remote/trans/nphttp2/server_conn_test.go @@ -147,7 +147,7 @@ func TestGetServerConn(t *testing.T) { SvcSearcher: mock_remote.NewDefaultSvcSearcher(), GRPCCfg: grpc.DefaultServerConfig(), InitOrResetRPCInfoFunc: func(info rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo { - return newMockRPCInfo() + return newMockRPCInfo(false) }, TracerCtl: &rpcinfo.TraceController{}, }) diff --git a/pkg/remote/trans/nphttp2/server_handler_test.go b/pkg/remote/trans/nphttp2/server_handler_test.go index 4d38c58447..d2493c98a8 100644 --- a/pkg/remote/trans/nphttp2/server_handler_test.go +++ b/pkg/remote/trans/nphttp2/server_handler_test.go @@ -84,7 +84,7 @@ func TestServerHandler(t *testing.T) { return remote.NewProtocolInfo(transport.PurePayload, serviceinfo.Protobuf) } msg.RPCInfoFunc = func() rpcinfo.RPCInfo { - return newMockRPCInfo() + return newMockRPCInfo(false) } npConn := newMockNpConn(mockAddr0) npConn.mockSettingFrame()