diff --git a/src/DotNetty.Handlers/Tls/TlsHandler.cs b/src/DotNetty.Handlers/Tls/TlsHandler.cs index 063aa2db9..9976d1611 100644 --- a/src/DotNetty.Handlers/Tls/TlsHandler.cs +++ b/src/DotNetty.Handlers/Tls/TlsHandler.cs @@ -26,7 +26,8 @@ public sealed class TlsHandler : ByteToMessageDecoder const int UnencryptedWriteBatchSize = 14 * 1024; static readonly Exception ChannelClosedException = new IOException("Channel is closed"); - static readonly Action HandshakeCompletionCallback = new Action(HandleHandshakeCompleted); + static readonly Action HandshakeCompletionCallback = new Action(HandleHandshakeCompleted); + static readonly Action, object> UnwrapCompletedCallback = new Action, object>(UnwrapCompleted); readonly SslStream sslStream; readonly MediationStream mediationStream; @@ -38,7 +39,9 @@ public sealed class TlsHandler : ByteToMessageDecoder BatchingPendingWriteQueue pendingUnencryptedWrites; Task lastContextWriteTask; bool firedChannelRead; + volatile FlushMode flushMode = FlushMode.ForceFlush; IByteBuffer pendingSslStreamReadBuffer; + int pendingSslStreamReadLength; Task pendingSslStreamReadFuture; public TlsHandler(TlsSettings settings) @@ -115,9 +118,16 @@ bool IgnoreException(Exception t) return false; } - static void HandleHandshakeCompleted(Task task, object state) + static void HandleHandshakeCompleted(object context, object state) { var self = (TlsHandler)state; + var capturedContext = self.capturedContext; + if (!capturedContext.Executor.InEventLoop) + { + capturedContext.Executor.Execute(HandshakeCompletionCallback, context, state); + return; + } + var task = (Task)context; switch (task.Status) { case TaskStatus.RanToCompletion: @@ -136,8 +146,7 @@ static void HandleHandshakeCompleted(Task task, object state) if (oldState.Has(TlsHandlerState.FlushedBeforeHandshake)) { - self.Wrap(self.capturedContext); - self.capturedContext.Flush(); + self.WrapAndFlush(self.capturedContext); } break; } @@ -319,7 +328,7 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng try { ArraySegment inputIoBuffer = packet.GetIoBuffer(offset, length); - this.mediationStream.SetSource(inputIoBuffer.Array, inputIoBuffer.Offset); + this.mediationStream.SetSource(inputIoBuffer.Array, inputIoBuffer.Offset, ctx.Allocator); int packetIndex = 0; @@ -342,10 +351,11 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng Contract.Assert(this.pendingSslStreamReadBuffer != null); outputBuffer = this.pendingSslStreamReadBuffer; - outputBufferLength = outputBuffer.WritableBytes; + outputBufferLength = this.pendingSslStreamReadLength; this.pendingSslStreamReadFuture = null; this.pendingSslStreamReadBuffer = null; + this.pendingSslStreamReadLength = 0; } else { @@ -358,86 +368,78 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng int currentPacketLength = packetLengths[packetIndex]; this.mediationStream.ExpandSource(currentPacketLength); - if (currentReadFuture != null) + while (true) { - // there was a read pending already, so we make sure we completed that first - - if (!currentReadFuture.IsCompleted) + int totalRead = 0; + if (currentReadFuture != null) { - // we did feed the whole current packet to SslStream yet it did not produce any result -> move to the next packet in input + // there was a read pending already, so we make sure we completed that first - continue; - } + if (!currentReadFuture.IsCompleted) + { + // we did feed the whole current packet to SslStream yet it did not produce any result -> move to the next packet in input - int read = currentReadFuture.Result; + break; + } - if (read == 0) - { - //Stream closed - return; - } + int read = currentReadFuture.Result; + totalRead += read; - // Now output the result of previous read and decide whether to do an extra read on the same source or move forward - AddBufferToOutput(outputBuffer, read, output); + if (read == 0) + { + //Stream closed + return; + } - currentReadFuture = null; - outputBuffer = null; - if (this.mediationStream.SourceReadableBytes == 0) - { - // we just made a frame available for reading but there was already pending read so SslStream read it out to make further progress there + // Now output the result of previous read and decide whether to do an extra read on the same source or move forward + AddBufferToOutput(outputBuffer, read, output); - if (read < outputBufferLength) + currentReadFuture = null; + outputBuffer = null; + if (this.mediationStream.TotalReadableBytes == 0) { - // SslStream returned non-full buffer and there's no more input to go through -> - // typically it means SslStream is done reading current frame so we skip - continue; + // we just made a frame available for reading but there was already pending read so SslStream read it out to make further progress there + + if (read < outputBufferLength) + { + // SslStream returned non-full buffer and there's no more input to go through -> + // typically it means SslStream is done reading current frame so we skip + break; + } + + // we've read out `read` bytes out of current packet to fulfil previously outstanding read + outputBufferLength = currentPacketLength - totalRead; + if (outputBufferLength <= 0) + { + // after feeding to SslStream current frame it read out more bytes than current packet size + outputBufferLength = FallbackReadBufferSize; + } } - - // we've read out `read` bytes out of current packet to fulfil previously outstanding read - outputBufferLength = currentPacketLength - read; - if (outputBufferLength <= 0) + else { - // after feeding to SslStream current frame it read out more bytes than current packet size - outputBufferLength = FallbackReadBufferSize; + // SslStream did not get to reading current frame so it completed previous read sync + // and the next read will likely read out the new frame + outputBufferLength = currentPacketLength; } } else { - // SslStream did not get to reading current frame so it completed previous read sync - // and the next read will likely read out the new frame + // there was no pending read before so we estimate buffer of `currentPacketLength` bytes to be sufficient outputBufferLength = currentPacketLength; } - } - else - { - // there was no pending read before so we estimate buffer of `currentPacketLength` bytes to be sufficient - outputBufferLength = currentPacketLength; - } - outputBuffer = ctx.Allocator.Buffer(outputBufferLength); - currentReadFuture = this.ReadFromSslStreamAsync(outputBuffer, outputBufferLength); + outputBuffer = ctx.Allocator.Buffer(outputBufferLength); + currentReadFuture = this.ReadFromSslStreamAsync(outputBuffer, outputBufferLength); + } } - // read out the rest of SslStream's output (if any) at risk of going async - // using FallbackReadBufferSize - buffer size we're ok to have pinned with the SslStream until it's done reading - while (true) + if (currentReadFuture != null) { - if (currentReadFuture != null) - { - if (!currentReadFuture.IsCompleted) - { - break; - } - int read = currentReadFuture.Result; - AddBufferToOutput(outputBuffer, read, output); - } - outputBuffer = ctx.Allocator.Buffer(FallbackReadBufferSize); - currentReadFuture = this.ReadFromSslStreamAsync(outputBuffer, FallbackReadBufferSize); + pending = true; + this.pendingSslStreamReadBuffer = outputBuffer; + this.pendingSslStreamReadFuture = currentReadFuture; + this.pendingSslStreamReadLength = outputBufferLength; } - - pending = true; - this.pendingSslStreamReadBuffer = outputBuffer; - this.pendingSslStreamReadFuture = currentReadFuture; } catch (Exception ex) { @@ -446,7 +448,7 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng } finally { - this.mediationStream.ResetSource(); + this.mediationStream.ResetSource(ctx.Allocator); if (!pending && outputBuffer != null) { if (outputBuffer.IsReadable()) @@ -458,6 +460,91 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng outputBuffer.SafeRelease(); } } + + if (pending) + { + //Can't use ExecuteSynchronously here for it may change the order of output if task is already completed here. + this.pendingSslStreamReadFuture?.ContinueWith(UnwrapCompletedCallback, this, TaskContinuationOptions.None); + } + } + } + + static void UnwrapCompleted(Task task, object state) + { + // Mono(with legacy provider) finish ReadAsync in async, + // so extra check is needed to receive data in async + var self = (TlsHandler)state; + Debug.Assert(self.capturedContext.Executor.InEventLoop); + + //Ignore task completed in Unwrap + if (task == self.pendingSslStreamReadFuture) + { + IByteBuffer buf = self.pendingSslStreamReadBuffer; + int outputBufferLength = self.pendingSslStreamReadLength; + + self.pendingSslStreamReadFuture = null; + self.pendingSslStreamReadBuffer = null; + self.pendingSslStreamReadLength = 0; + + while (true) + { + switch (task.Status) + { + case TaskStatus.RanToCompletion: + { + //The logic is the same as the one in Unwrap() + var read = task.Result; + //Stream Closed + if (read == 0) + return; + self.capturedContext.FireChannelRead(buf.SetWriterIndex(buf.WriterIndex + read)); + + if (self.mediationStream.TotalReadableBytes == 0) + { + self.capturedContext.FireChannelReadComplete(); + self.mediationStream.ResetSource(self.capturedContext.Allocator); + + if (read < outputBufferLength) + { + // SslStream returned non-full buffer and there's no more input to go through -> + // typically it means SslStream is done reading current frame so we skip + return; + } + } + + outputBufferLength = self.mediationStream.TotalReadableBytes; + if (outputBufferLength <= 0) + outputBufferLength = FallbackReadBufferSize; + + buf = self.capturedContext.Allocator.Buffer(outputBufferLength); + task = self.ReadFromSslStreamAsync(buf, outputBufferLength); + if (task.IsCompleted) + { + continue; + } + + self.pendingSslStreamReadFuture = task; + self.pendingSslStreamReadBuffer = buf; + self.pendingSslStreamReadLength = outputBufferLength; + task.ContinueWith(UnwrapCompletedCallback, self, TaskContinuationOptions.ExecuteSynchronously); + return; + } + + case TaskStatus.Canceled: + case TaskStatus.Faulted: + { + buf.SafeRelease(); + self.HandleFailure(task.Exception); + return; + } + + default: + { + buf.SafeRelease(); + throw new ArgumentOutOfRangeException(nameof(task), "Unexpected task status: " + task.Status); + } + } + } } } @@ -530,6 +617,12 @@ public override void Flush(IChannelHandlerContext context) return; } + this.WrapAndFlush(context); + } + + void WrapAndFlush(IChannelHandlerContext context) + { + this.flushMode = FlushMode.NoFlush; try { this.Wrap(context); @@ -537,7 +630,20 @@ public override void Flush(IChannelHandlerContext context) finally { // We may have written some parts of data before an exception was thrown so ensure we always flush. - context.Flush(); + if (this.flushMode == FlushMode.NoFlush) + { + this.flushMode = FlushMode.ForceFlush; + context.Flush(); + } + else + { + context.Executor.Execute((state) => { + var self = (TlsHandler)state; + + self.flushMode = FlushMode.ForceFlush; + self.capturedContext.Flush(); + }, this); + } } } @@ -595,6 +701,12 @@ void Wrap(IChannelHandlerContext context) void FinishWrap(byte[] buffer, int offset, int count) { + // In Mono(with btls provider) on linux, and maybe also for apple provider, Write is called in another thread, + // so it will run after the call to Flush. + if (this.flushMode == FlushMode.NoFlush && !this.capturedContext.Executor.InEventLoop) + { + this.flushMode = FlushMode.PendingFlush; + } IByteBuffer output; if (count == 0) { @@ -606,7 +718,7 @@ void FinishWrap(byte[] buffer, int offset, int count) output.WriteBytes(buffer, offset, count); } - this.lastContextWriteTask = this.capturedContext.WriteAsync(output); + this.lastContextWriteTask = (this.flushMode == FlushMode.ForceFlush) ? this.capturedContext.WriteAndFlushAsync(output) : this.capturedContext.WriteAsync(output); } Task FinishWrapNonAppDataAsync(byte[] buffer, int offset, int count) @@ -665,9 +777,26 @@ void NotifyHandshakeFailure(Exception cause) } } + enum FlushMode : byte + { + /// + /// Do nothing with Flush. + /// + NoFlush = 0, + /// + /// An Flush is or will be posted to IEventExecutor. + /// + PendingFlush = 1, + /// + /// Force FinishWrap to call Flush. + /// + ForceFlush = 2, + } + sealed class MediationStream : Stream { readonly TlsHandler owner; + IByteBuffer ownBuffer; byte[] input; int inputStartOffset; int inputOffset; @@ -688,66 +817,112 @@ public MediationStream(TlsHandler owner) this.owner = owner; } + public int TotalReadableBytes => (this.ownBuffer?.ReadableBytes ?? 0) + SourceReadableBytes; + public int SourceReadableBytes => this.inputLength - this.inputOffset; - public void SetSource(byte[] source, int offset) + public void SetSource(byte[] source, int offset, IByteBufferAllocator alloc) { - this.input = source; - this.inputStartOffset = offset; - this.inputOffset = 0; - this.inputLength = 0; + lock (this) + { + ResetSource(alloc); + + this.input = source; + this.inputStartOffset = offset; + this.inputOffset = 0; + this.inputLength = 0; + } } - public void ResetSource() + public void ResetSource(IByteBufferAllocator alloc) { - this.input = null; - this.inputLength = 0; + //Mono will run BeginRead in async and it's running with ResetSource at the same time + //net5.0 can also hit this with `leftLen > 0` while `!this.EnsureAuthenticated()` + lock (this) + { + int leftLen = this.SourceReadableBytes; + IByteBuffer buf = this.ownBuffer; + + if (leftLen > 0) + { + if (buf != null) + { + buf.DiscardSomeReadBytes(); + } + else + { + buf = alloc.Buffer(leftLen); + this.ownBuffer = buf; + } + buf.WriteBytes(this.input, this.inputStartOffset + this.inputOffset, leftLen); + } + else if (buf != null) + { + if (!buf.IsReadable()) + { + buf.SafeRelease(); + this.ownBuffer = null; + } + else + { + buf.DiscardSomeReadBytes(); + } + } + + this.input = null; + this.inputStartOffset = 0; + this.inputOffset = 0; + this.inputLength = 0; + } } public void ExpandSource(int count) { Contract.Assert(this.input != null); - this.inputLength += count; - - ArraySegment sslBuffer = this.sslOwnedBuffer; - if (sslBuffer.Array == null) + lock (this) { - // there is no pending read operation - keep for future - return; - } - this.sslOwnedBuffer = default(ArraySegment); + this.inputLength += count; -#if NETSTANDARD1_3 - this.readByteCount = this.ReadFromInput(sslBuffer.Array, sslBuffer.Offset, sslBuffer.Count); - // hack: this tricks SslStream's continuation to run synchronously instead of dispatching to TP. Remove once Begin/EndRead are available. - new Task( - ms => + ArraySegment sslBuffer = this.sslOwnedBuffer; + if (sslBuffer.Array == null) { - var self = (MediationStream)ms; - TaskCompletionSource p = self.readCompletionSource; - self.readCompletionSource = null; - p.TrySetResult(self.readByteCount); - }, - this) - .RunSynchronously(TaskScheduler.Default); + // there is no pending read operation - keep for future + return; + } + this.sslOwnedBuffer = default(ArraySegment); + +#if NETSTANDARD1_3 + this.readByteCount = this.ReadFromInput(sslBuffer.Array, sslBuffer.Offset, sslBuffer.Count); + // hack: this tricks SslStream's continuation to run synchronously instead of dispatching to TP. Remove once Begin/EndRead are available. + new Task( + ms => + { + var self = (MediationStream)ms; + TaskCompletionSource p = self.readCompletionSource; + self.readCompletionSource = null; + p.TrySetResult(self.readByteCount); + }, + this) + .RunSynchronously(TaskScheduler.Default); #else - int read = this.ReadFromInput(sslBuffer.Array, sslBuffer.Offset, sslBuffer.Count); + int read = this.ReadFromInput(sslBuffer.Array, sslBuffer.Offset, sslBuffer.Count); - TaskCompletionSource promise = this.readCompletionSource; - this.readCompletionSource = null; - promise.TrySetResult(read); + TaskCompletionSource promise = this.readCompletionSource; + this.readCompletionSource = null; + promise.TrySetResult(read); - AsyncCallback callback = this.readCallback; - this.readCallback = null; - callback?.Invoke(promise.Task); + AsyncCallback callback = this.readCallback; + this.readCallback = null; + callback?.Invoke(promise.Task); #endif + } } #if NETSTANDARD1_3 public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - if (this.SourceReadableBytes > 0) + if (this.TotalReadableBytes > 0) { // we have the bytes available upfront - write out synchronously int read = this.ReadFromInput(buffer, offset, count); @@ -763,7 +938,7 @@ public override Task ReadAsync(byte[] buffer, int offset, int count, Cancel #else public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) { - if (this.SourceReadableBytes > 0) + if (this.TotalReadableBytes > 0) { // we have the bytes available upfront - write out synchronously int read = this.ReadFromInput(buffer, offset, count); @@ -791,15 +966,7 @@ public override int EndRead(IAsyncResult asyncResult) Debug.Assert(this.readCompletionSource == null || this.readCompletionSource.Task == asyncResult); Contract.Assert(!((Task)asyncResult).IsCanceled); - try - { - return ((Task)asyncResult).Result; - } - catch (AggregateException ex) - { - ExceptionDispatchInfo.Capture(ex.InnerException).Throw(); - throw; // unreachable - } + return ((Task)asyncResult).GetAwaiter().GetResult(); } IAsyncResult PrepareSyncReadResult(int readBytes, object state) @@ -883,15 +1050,7 @@ public override void EndWrite(IAsyncResult asyncResult) return; } - try - { - ((Task)asyncResult).Wait(); - } - catch (AggregateException ex) - { - ExceptionDispatchInfo.Capture(ex.InnerException).Throw(); - throw; - } + ((Task)asyncResult).GetAwaiter().GetResult(); } #endif @@ -899,12 +1058,45 @@ int ReadFromInput(byte[] destination, int destinationOffset, int destinationCapa { Contract.Assert(destination != null); - byte[] source = this.input; - int readableBytes = this.SourceReadableBytes; - int length = Math.Min(readableBytes, destinationCapacity); - Buffer.BlockCopy(source, this.inputStartOffset + this.inputOffset, destination, destinationOffset, length); - this.inputOffset += length; - return length; + lock (this) + { + int length = 0; + do + { + int readableBytes; + IByteBuffer buf = this.ownBuffer; + if (buf != null) + { + readableBytes = buf.ReadableBytes; + if (readableBytes > 0) + { + readableBytes = Math.Min(readableBytes, destinationCapacity); + buf.ReadBytes(destination, destinationOffset, readableBytes); + length += readableBytes; + destinationCapacity -= readableBytes; + + if (destinationCapacity == 0) + break; + } + } + + byte[] source = this.input; + if (source != null) + { + readableBytes = this.SourceReadableBytes; + if (readableBytes > 0) + { + readableBytes = Math.Min(readableBytes, destinationCapacity); + Buffer.BlockCopy(source, this.inputStartOffset + this.inputOffset, destination, destinationOffset, readableBytes); + length += readableBytes; + destinationCapacity -= readableBytes; + + this.inputOffset += readableBytes; + } + } + } while (false); + return length; + } } public override void Flush() @@ -938,10 +1130,7 @@ public override void SetLength(long value) throw new NotSupportedException(); } - public override int Read(byte[] buffer, int offset, int count) - { - throw new NotSupportedException(); - } + public override int Read(byte[] buffer, int offset, int count) => this.ReadAsync(buffer, offset, count).Result; public override bool CanRead => true; diff --git a/test/DotNetty.Handlers.Tests/SniHandlerTest.cs b/test/DotNetty.Handlers.Tests/SniHandlerTest.cs index 0f4d09328..d73f46433 100644 --- a/test/DotNetty.Handlers.Tests/SniHandlerTest.cs +++ b/test/DotNetty.Handlers.Tests/SniHandlerTest.cs @@ -94,7 +94,10 @@ public async Task TlsRead(int[] frameLengths, bool isClient, IWriteStrategy writ await Task.WhenAll(writeTasks).WithTimeout(TimeSpan.FromSeconds(5)); IByteBuffer finalReadBuffer = Unpooled.Buffer(16 * 1024); await ReadOutboundAsync(async () => ch.ReadInbound(), expectedBuffer.ReadableBytes, finalReadBuffer, TestTimeout); - Assert.True(ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer), $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}"); + if (!ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer)) + { + Assert.True(false, $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}"); + } if (!isClient) { @@ -171,7 +174,10 @@ await ReadOutboundAsync( return Unpooled.WrappedBuffer(readBuffer, 0, read); }, expectedBuffer.ReadableBytes, finalReadBuffer, TestTimeout); - Assert.True(ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer), $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}"); + if (!ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer)) + { + Assert.True(false, $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}"); + } if (!isClient) { diff --git a/test/DotNetty.Handlers.Tests/TlsHandlerTest.cs b/test/DotNetty.Handlers.Tests/TlsHandlerTest.cs index ba7748c59..1a99f3f91 100644 --- a/test/DotNetty.Handlers.Tests/TlsHandlerTest.cs +++ b/test/DotNetty.Handlers.Tests/TlsHandlerTest.cs @@ -102,10 +102,9 @@ public async Task TlsRead(int[] frameLengths, bool isClient, IWriteStrategy writ await Task.WhenAll(writeTasks).WithTimeout(TimeSpan.FromSeconds(5)); IByteBuffer finalReadBuffer = Unpooled.Buffer(16 * 1024); await ReadOutboundAsync(async () => ch.ReadInbound(), expectedBuffer.ReadableBytes, finalReadBuffer, TestTimeout); - bool isEqual = ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer); - if (!isEqual) + if (!ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer)) { - Assert.True(isEqual, $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}"); + Assert.True(false, $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}"); } driverStream.Dispose(); Assert.False(ch.Finish()); @@ -192,10 +191,9 @@ await ReadOutboundAsync( return Unpooled.WrappedBuffer(readBuffer, 0, read); }, expectedBuffer.ReadableBytes, finalReadBuffer, TestTimeout); - bool isEqual = ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer); - if (!isEqual) + if (!ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer)) { - Assert.True(isEqual, $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}"); + Assert.True(false, $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}"); } driverStream.Dispose(); Assert.False(ch.Finish());