diff --git a/src/Grpc.Net.Client/Internal/IGrpcCall.cs b/src/Grpc.Net.Client/Internal/IGrpcCall.cs index d796dd034..c1a79e2c2 100644 --- a/src/Grpc.Net.Client/Internal/IGrpcCall.cs +++ b/src/Grpc.Net.Client/Internal/IGrpcCall.cs @@ -45,5 +45,7 @@ internal interface IGrpcCall : IDisposable void StartDuplexStreaming(); Task WriteClientStreamAsync(Func, Stream, CallOptions, TState, ValueTask> writeFunc, TState state); + + bool Disposed { get; } } } diff --git a/src/Grpc.Net.Client/Internal/Retry/HedgingCall.cs b/src/Grpc.Net.Client/Internal/Retry/HedgingCall.cs index e9c505c7f..01029fa6b 100644 --- a/src/Grpc.Net.Client/Internal/Retry/HedgingCall.cs +++ b/src/Grpc.Net.Client/Internal/Retry/HedgingCall.cs @@ -199,6 +199,13 @@ private async Task StartCall(Action> startCallFunc } } } + + if (CommitedCallTask.IsCompletedSuccessfully() && CommitedCallTask.Result == call) + { + // Wait until the committed call is finished and then clean up the hedging call. + await call.CallTask.ConfigureAwait(false); + Cleanup(); + } } protected override void OnCommitCall(IGrpcCall call) diff --git a/src/Grpc.Net.Client/Internal/Retry/RetryCall.cs b/src/Grpc.Net.Client/Internal/Retry/RetryCall.cs index eec52fc2e..199057d5c 100644 --- a/src/Grpc.Net.Client/Internal/Retry/RetryCall.cs +++ b/src/Grpc.Net.Client/Internal/Retry/RetryCall.cs @@ -245,6 +245,16 @@ private async Task StartRetry(Action> startCallFun } finally { + if (CommitedCallTask.IsCompletedSuccessfully()) + { + if (CommitedCallTask.Result is GrpcCall call) + { + // Wait until the committed call is finished and then clean up the retry call. 
+ await call.CallTask.ConfigureAwait(false); + Cleanup(); + } + } + Log.StoppingRetryWorker(Logger); } } diff --git a/src/Grpc.Net.Client/Internal/Retry/RetryCallBase.cs b/src/Grpc.Net.Client/Internal/Retry/RetryCallBase.cs index 66373c820..72f05dc8f 100644 --- a/src/Grpc.Net.Client/Internal/Retry/RetryCallBase.cs +++ b/src/Grpc.Net.Client/Internal/Retry/RetryCallBase.cs @@ -43,7 +43,9 @@ internal abstract partial class RetryCallBase : IGrpcCall> _commitedCallTcs; private RetryCallBaseClientStreamReader? _retryBaseClientStreamReader; private RetryCallBaseClientStreamWriter? _retryBaseClientStreamWriter; - private CancellationTokenRegistration? _ctsRegistration; + + // Internal for unit testing. + internal CancellationTokenRegistration? _ctsRegistration; protected object Lock { get; } = new object(); protected ILogger Logger { get; } @@ -52,7 +54,6 @@ internal abstract partial class RetryCallBase : IGrpcCall?>? NewActiveCallTcs { get; set; } - protected bool Disposed { get; private set; } public GrpcChannel Channel { get; } public Task> CommitedCallTask => _commitedCallTcs.Task; @@ -60,6 +61,7 @@ internal abstract partial class RetryCallBase : IGrpcCall? ClientStreamWriter => _retryBaseClientStreamWriter ??= new RetryCallBaseClientStreamWriter(this); public WriteOptions? ClientStreamWriteOptions { get; internal set; } public bool ClientStreamComplete { get; set; } + public bool Disposed { get; private set; } protected int AttemptCount { get; private set; } protected List> BufferedMessages { get; } @@ -345,6 +347,16 @@ protected void CommitCall(IGrpcCall call, CommitReason comm NewActiveCallTcs?.SetResult(null); _commitedCallTcs.SetResult(call); + + // If the committed call has finished and cleaned up then it is safe for + // the wrapping retry call to clean up. This is required to unregister + // from the cancellation token and avoid a memory leak. + // + // A committed call that has already cleaned up is likely a StatusGrpcCall. 
+ if (call.Disposed) + { + Cleanup(); + } } } } @@ -406,18 +418,24 @@ protected virtual void Dispose(bool disposing) if (disposing) { - _ctsRegistration?.Dispose(); - CancellationTokenSource.Cancel(); - if (CommitedCallTask.IsCompletedSuccessfully()) { CommitedCallTask.Result.Dispose(); } - ClearRetryBuffer(); + Cleanup(); } } + protected void Cleanup() + { + _ctsRegistration?.Dispose(); + _ctsRegistration = null; + CancellationTokenSource.Cancel(); + + ClearRetryBuffer(); + } + internal bool TryAddToRetryBuffer(ReadOnlyMemory message) { lock (Lock) diff --git a/src/Grpc.Net.Client/Internal/Retry/StatusGrpcCall.cs b/src/Grpc.Net.Client/Internal/Retry/StatusGrpcCall.cs index e2728cec4..e019e31a6 100644 --- a/src/Grpc.Net.Client/Internal/Retry/StatusGrpcCall.cs +++ b/src/Grpc.Net.Client/Internal/Retry/StatusGrpcCall.cs @@ -38,6 +38,7 @@ internal sealed class StatusGrpcCall : IGrpcCall? ClientStreamWriter => _clientStreamWriter ??= new StatusClientStreamWriter(_status); public IAsyncStreamReader? 
ClientStreamReader => _clientStreamReader ??= new StatusStreamReader(_status); + public bool Disposed => true; public StatusGrpcCall(Status status) { diff --git a/test/FunctionalTests/Client/RetryTests.cs b/test/FunctionalTests/Client/RetryTests.cs index c142ac0d1..5d6b052d3 100644 --- a/test/FunctionalTests/Client/RetryTests.cs +++ b/test/FunctionalTests/Client/RetryTests.cs @@ -19,6 +19,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using Google.Protobuf; @@ -356,6 +357,63 @@ Task UnaryFailure(DataMessage request, ServerCallContext context) tcs.SetResult(new DataMessage()); } + [Test] + public async Task ServerStreaming_CancellatonTokenSpecified_TokenUnregisteredAndResourcesReleased() + { + Task FakeServerStreamCall(DataMessage request, IServerStreamWriter responseStream, ServerCallContext context) + { + return Task.CompletedTask; + } + + // Arrange + var method = Fixture.DynamicGrpc.AddServerStreamingMethod(FakeServerStreamCall); + + var serviceConfig = ServiceConfigHelpers.CreateRetryServiceConfig(retryableStatusCodes: new List { StatusCode.DeadlineExceeded }); + var channel = CreateChannel(serviceConfig: serviceConfig); + + var references = new List(); + + // Checking that token register calls don't build up on CTS and create a memory leak. + var cts = new CancellationTokenSource(); + + // Act + // Send calls in a different method so there is no chance that a stack reference + // to a gRPC call is still alive after calls are complete. + await MakeCallsAsync(channel, method, references, cts.Token).DefaultTimeout(); + + // Assert + // There is a race when cleaning up cancellation token registry. + // Retry a few times to ensure GC is run after unregister. 
+ await TestHelpers.AssertIsTrueRetryAsync(() => + { + GC.Collect(); + GC.WaitForPendingFinalizers(); + + for (var i = 0; i < references.Count; i++) + { + if (references[i].IsAlive) + { + return false; + } + } + + return true; + }, "Assert that retry call resources are released."); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static async Task MakeCallsAsync(GrpcChannel channel, Method method, List references, CancellationToken cancellationToken) + { + var client = TestClientFactory.Create(channel, method); + for (int i = 0; i < 10; i++) + { + var call = client.ServerStreamingCall(new DataMessage(), new CallOptions(cancellationToken: cancellationToken)); + references.Add(new WeakReference(call.ResponseStream)); + + Assert.IsFalse(await call.ResponseStream.MoveNext()); + } + } + [TestCase(1)] [TestCase(20)] public async Task Unary_AttemptsGreaterThanDefaultClientLimit_LimitedAttemptsMade(int hedgingDelay) diff --git a/test/Grpc.Net.Client.Tests/Retry/HedgingCallTests.cs b/test/Grpc.Net.Client.Tests/Retry/HedgingCallTests.cs index fb02e1583..6b3807062 100644 --- a/test/Grpc.Net.Client.Tests/Retry/HedgingCallTests.cs +++ b/test/Grpc.Net.Client.Tests/Retry/HedgingCallTests.cs @@ -331,6 +331,7 @@ public async Task AsyncUnaryCall_CancellationDuringBackoff_CanceledStatus() // Act hedgingCall.StartUnary(new HelloRequest()); + Assert.IsNotNull(hedgingCall._ctsRegistration); // Assert await TestHelpers.AssertIsTrueRetryAsync(() => hedgingCall._activeCalls.Count == 0, "Wait for all calls to fail.").DefaultTimeout(); @@ -340,6 +341,37 @@ public async Task AsyncUnaryCall_CancellationDuringBackoff_CanceledStatus() var ex = await ExceptionAssert.ThrowsAsync(() => hedgingCall.GetResponseAsync()).DefaultTimeout(); Assert.AreEqual(StatusCode.Cancelled, ex.StatusCode); Assert.AreEqual("Call canceled by the client.", ex.Status.Detail); + Assert.IsNull(hedgingCall._ctsRegistration); + } + + [Test] + public async Task AsyncUnaryCall_CancellationTokenSuccess_CleanedUp() + 
{ + // Arrange + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var httpClient = ClientTestHelpers.CreateTestClient(async request => + { + await tcs.Task; + + var reply = new HelloReply { Message = "Hello world" }; + var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout(); + return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent); + }); + var cts = new CancellationTokenSource(); + var serviceConfig = ServiceConfigHelpers.CreateHedgingServiceConfig(hedgingDelay: TimeSpan.FromSeconds(10)); + var invoker = HttpClientCallInvokerFactory.Create(httpClient, serviceConfig: serviceConfig); + var hedgingCall = new HedgingCall(CreateHedgingPolicy(serviceConfig.MethodConfigs[0].HedgingPolicy), invoker.Channel, ClientTestHelpers.ServiceMethod, new CallOptions(cancellationToken: cts.Token)); + + // Act + hedgingCall.StartUnary(new HelloRequest()); + Assert.IsNotNull(hedgingCall._ctsRegistration); + tcs.SetResult(null); + + // Assert + await hedgingCall.GetResponseAsync().DefaultTimeout(); + + // There is a race between unregistering and GetResponseAsync returning. + await TestHelpers.AssertIsTrueRetryAsync(() => hedgingCall._ctsRegistration == null, "Hedge call CTS unregistered."); } [Test]