diff --git a/internal/conn/conn.go b/internal/conn/conn.go index 66278bb94..b97456927 100644 --- a/internal/conn/conn.go +++ b/internal/conn/conn.go @@ -56,6 +56,7 @@ type conn struct { closed bool state atomic.Uint32 lastUsage *lastUsage + childStreams *xcontext.CancelsGuard onClose []func(*conn) onTransportErrors []func(ctx context.Context, cc Conn, cause error) } @@ -392,7 +393,7 @@ func (c *conn) NewStream( desc *grpc.StreamDesc, method string, opts ...grpc.CallOption, -) (_ grpc.ClientStream, err error) { +) (_ grpc.ClientStream, finalErr error) { var ( onDone = trace.DriverOnConnNewStream( c.config.Trace(), &ctx, @@ -400,15 +401,13 @@ func (c *conn) NewStream( c.endpoint.Copy(), trace.Method(method), ) useWrapping = UseWrapping(ctx) - cc *grpc.ClientConn - s grpc.ClientStream ) defer func() { - onDone(err, c.GetState()) + onDone(finalErr, c.GetState()) }() - cc, err = c.realConn(ctx) + cc, err := c.realConn(ctx) if err != nil { return nil, c.wrapError(err) } @@ -423,7 +422,19 @@ func (c *conn) NewStream( ctx, sentMark := markContext(meta.WithTraceID(ctx, traceID)) - s, err = cc.NewStream(ctx, desc, method, opts...) + ctx, cancel := xcontext.WithCancel(ctx) + defer func() { + if finalErr != nil { + cancel() + } else { + c.childStreams.Remember(&cancel) + } + }() + + s, err := cc.NewStream(ctx, desc, method, append(opts, grpc.OnFinish(func(err error) { + cancel() + c.childStreams.Forget(&cancel) + }))...) if err != nil { if xerrors.IsContextError(err) { return nil, xerrors.WithStackTrace(err) @@ -490,10 +501,16 @@ func withOnTransportError(onTransportError func(ctx context.Context, cc Conn, ca func newConn(e endpoint.Endpoint, config Config, opts ...option) *conn { c := &conn{ - endpoint: e, - config: config, - done: make(chan struct{}), - lastUsage: newLastUsage(nil), + endpoint: e, + config: config, + done: make(chan struct{}), + lastUsage: newLastUsage(nil), + childStreams: xcontext.NewCancelsGuard(), + onClose: []func(*conn){ + func(c *conn) { + c.childStreams.Cancel() + }, + }, } c.state.Store(uint32(Created)) for _, opt := range opts { diff --git a/internal/xcontext/cancels_quard.go b/internal/xcontext/cancels_quard.go new file mode 100644 index 000000000..7fd344696 --- /dev/null +++ b/internal/xcontext/cancels_quard.go @@ -0,0 +1,38 @@ +package xcontext + +import ( + "context" + "sync" +) + +type CancelsGuard struct { + mu sync.Mutex + cancels map[*context.CancelFunc]struct{} +} + +func NewCancelsGuard() *CancelsGuard { + return &CancelsGuard{ + cancels: make(map[*context.CancelFunc]struct{}), + } +} + +func (g *CancelsGuard) Remember(cancel *context.CancelFunc) { + g.mu.Lock() + defer g.mu.Unlock() + g.cancels[cancel] = struct{}{} +} + +func (g *CancelsGuard) Forget(cancel *context.CancelFunc) { + g.mu.Lock() + defer g.mu.Unlock() + delete(g.cancels, cancel) +} + +func (g *CancelsGuard) Cancel() { + g.mu.Lock() + defer g.mu.Unlock() + for cancel := range g.cancels { + (*cancel)() + } + g.cancels = make(map[*context.CancelFunc]struct{}) +} diff --git a/internal/xcontext/cancels_quard_test.go b/internal/xcontext/cancels_quard_test.go new file mode 100644 index 000000000..98c2faf2c --- /dev/null +++ b/internal/xcontext/cancels_quard_test.go @@ -0,0 +1,24 @@ +package xcontext + +import ( + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/net/context" +) + +func TestCancelsGuard(t *testing.T) { + g := NewCancelsGuard() + ctx, cancel1 := context.WithCancel(context.Background()) + g.Remember(&cancel1) + require.Len(t, g.cancels, 1) + g.Forget(&cancel1) + require.Empty(t, g.cancels, 0) + cancel2 := context.CancelFunc(func() { + cancel1() + }) + g.Remember(&cancel2) + require.Len(t, g.cancels, 1) + g.Cancel() + require.Error(t, ctx.Err()) +}