From 9c4eb9396113ebc01ea5325378771ec7e7c4c9a2 Mon Sep 17 00:00:00 2001 From: SilverFox Date: Tue, 14 Aug 2018 02:38:06 +0800 Subject: [PATCH] Pickup some commits not related to Mono from #374 (#410) * Add some missing method to LoggingHandler * Avoid to alloc an huge error message when the test not failed. * Update the unittest * Update Microsoft.NET.Test.Sdk from 15.0.0 to 15.7.2, fix that unable to debug an unittest for the second time. * Disable parallelization for InternalLoggerFactoryTest.TestMockReturned to avoid an rare test failure. * Remove `dotnet-xunit` since it's never used and will be discontinued, see https://xunit.github.io/releases/2.4-beta2 * Remove space from filename * Switch back to `DiscardSomeReadBytes` since it's avaliable * Rework some logic in TlsHandler * Make sure TlsHandler.MediationStream works well with different style of aync calls(Still not work for Mono, see #374) * Rework some logic in #366, now always close TlsHandler.MediationStream in TlsHandler.HandleFailure since it's never exported. * Workaround to fix issue 'Microsoft/vstest#1129'. * Change the default of TcpServerSocketChannel.Metadata.defaultMaxMessagesPerRead to 1 --- Directory.Build.targets | 4 + after.DotNetty.sln.targets | 5 + ...tric .cs => IByteBufferAllocatorMetric.cs} | 0 .../Multipart/{IHttpData .cs => IHttpData.cs} | 0 ...t .cs => DefaultBulkStringRedisContent.cs} | 0 src/DotNetty.Codecs/ByteToMessageDecoder.cs | 2 +- .../Logging/LoggingHandler.cs | 42 + src/DotNetty.Handlers/Tls/SniHandler.cs | 2 +- .../Tls/TlsHandler.Extensions.cs | 8 +- src/DotNetty.Handlers/Tls/TlsHandler.cs | 2011 +++++++++-------- .../Sockets/TcpServerSocketChannel.cs | 2 +- .../DotNetty.Codecs.Protobuf.Tests.csproj | 2 +- ...tNetty.Codecs.ProtocolBuffers.Tests.csproj | 2 +- .../DotNetty.Common.Tests.csproj | 1 - .../Logging/InternalLoggerFactoryTest.cs | 99 +- .../DotNetty.Handlers.Tests/TlsHandlerTest.cs | 12 +- 16 files changed, 1138 insertions(+), 1054 deletions(-) create mode 100644 after.DotNetty.sln.targets rename src/DotNetty.Buffers/{IByteBufferAllocatorMetric .cs => IByteBufferAllocatorMetric.cs} (100%) rename src/DotNetty.Codecs.Http/Multipart/{IHttpData .cs => IHttpData.cs} (100%) rename src/DotNetty.Codecs.Redis/Messages/{DefaultBulkStringRedisContent .cs => DefaultBulkStringRedisContent.cs} (100%) diff --git a/Directory.Build.targets b/Directory.Build.targets index 0df77903e..e823f5dec 100644 --- a/Directory.Build.targets +++ b/Directory.Build.targets @@ -8,4 +8,8 @@ $(Version). Commit Hash: $(GitHeadSha) + + + + diff --git a/after.DotNetty.sln.targets b/after.DotNetty.sln.targets new file mode 100644 index 000000000..0c0fbc6ae --- /dev/null +++ b/after.DotNetty.sln.targets @@ -0,0 +1,5 @@ + + + + + \ No newline at end of file diff --git a/src/DotNetty.Buffers/IByteBufferAllocatorMetric .cs b/src/DotNetty.Buffers/IByteBufferAllocatorMetric.cs similarity index 100% rename from src/DotNetty.Buffers/IByteBufferAllocatorMetric .cs rename to src/DotNetty.Buffers/IByteBufferAllocatorMetric.cs diff --git a/src/DotNetty.Codecs.Http/Multipart/IHttpData .cs b/src/DotNetty.Codecs.Http/Multipart/IHttpData.cs similarity index 100% rename from src/DotNetty.Codecs.Http/Multipart/IHttpData .cs rename to src/DotNetty.Codecs.Http/Multipart/IHttpData.cs diff --git a/src/DotNetty.Codecs.Redis/Messages/DefaultBulkStringRedisContent .cs b/src/DotNetty.Codecs.Redis/Messages/DefaultBulkStringRedisContent.cs similarity index 100% rename from src/DotNetty.Codecs.Redis/Messages/DefaultBulkStringRedisContent .cs rename to src/DotNetty.Codecs.Redis/Messages/DefaultBulkStringRedisContent.cs diff --git a/src/DotNetty.Codecs/ByteToMessageDecoder.cs b/src/DotNetty.Codecs/ByteToMessageDecoder.cs index 460013e09..f8148ab41 100644 --- a/src/DotNetty.Codecs/ByteToMessageDecoder.cs +++ b/src/DotNetty.Codecs/ByteToMessageDecoder.cs @@ -249,7 +249,7 @@ protected void DiscardSomeReadBytes() // See: // - https://github.com/netty/netty/issues/2327 // - https://github.com/netty/netty/issues/1764 - this.cumulation.DiscardReadBytes(); // todo: use discardSomeReadBytes + this.cumulation.DiscardSomeReadBytes(); } } diff --git a/src/DotNetty.Handlers/Logging/LoggingHandler.cs b/src/DotNetty.Handlers/Logging/LoggingHandler.cs index 3079fe8b6..5c2b1b2a4 100644 --- a/src/DotNetty.Handlers/Logging/LoggingHandler.cs +++ b/src/DotNetty.Handlers/Logging/LoggingHandler.cs @@ -209,6 +209,48 @@ public override void ChannelRead(IChannelHandlerContext ctx, object message) ctx.FireChannelRead(message); } + public override void ChannelReadComplete(IChannelHandlerContext ctx) + { + if (this.Logger.IsEnabled(this.InternalLevel)) + { + this.Logger.Log(this.InternalLevel, this.Format(ctx, "RECEIVED_COMPLETE")); + } + ctx.FireChannelReadComplete(); + } + + public override void ChannelWritabilityChanged(IChannelHandlerContext ctx) + { + if (this.Logger.IsEnabled(this.InternalLevel)) + { + this.Logger.Log(this.InternalLevel, this.Format(ctx, "WRITABILITY", ctx.Channel.IsWritable)); + } + ctx.FireChannelWritabilityChanged(); + } + + public override void HandlerAdded(IChannelHandlerContext ctx) + { + if (this.Logger.IsEnabled(this.InternalLevel)) + { + this.Logger.Log(this.InternalLevel, this.Format(ctx, "HANDLER_ADDED")); + } + } + public override void HandlerRemoved(IChannelHandlerContext ctx) + { + if (this.Logger.IsEnabled(this.InternalLevel)) + { + this.Logger.Log(this.InternalLevel, this.Format(ctx, "HANDLER_REMOVED")); + } + } + + public override void Read(IChannelHandlerContext ctx) + { + if (this.Logger.IsEnabled(this.InternalLevel)) + { + this.Logger.Log(this.InternalLevel, this.Format(ctx, "READ")); + } + ctx.Read(); + } + public override Task WriteAsync(IChannelHandlerContext ctx, object msg) { if (this.Logger.IsEnabled(this.InternalLevel)) diff --git a/src/DotNetty.Handlers/Tls/SniHandler.cs b/src/DotNetty.Handlers/Tls/SniHandler.cs index 2e5546151..d7082e370 100644 --- a/src/DotNetty.Handlers/Tls/SniHandler.cs +++ b/src/DotNetty.Handlers/Tls/SniHandler.cs @@ -28,7 +28,7 @@ public sealed class SniHandler : ByteToMessageDecoder bool readPending; public SniHandler(ServerTlsSniSettings settings) - : this(stream => new SslStream(stream, false), settings) + : this(stream => new SslStream(stream, true), settings) { } diff --git a/src/DotNetty.Handlers/Tls/TlsHandler.Extensions.cs b/src/DotNetty.Handlers/Tls/TlsHandler.Extensions.cs index 311d7440b..c7d44425c 100644 --- a/src/DotNetty.Handlers/Tls/TlsHandler.Extensions.cs +++ b/src/DotNetty.Handlers/Tls/TlsHandler.Extensions.cs @@ -40,18 +40,18 @@ private static SslStream CreateSslStream(TlsSettings settings, Stream stream) // Enable client certificate function only if ClientCertificateRequired is true in the configuration if (serverSettings.ClientCertificateMode == ClientCertificateMode.NoCertificate) { - return new SslStream(stream, false); + return new SslStream(stream, true); } #if DESKTOPCLR // SSL 版本 2 协议不支持客户端证书 if (serverSettings.EnabledProtocols == SslProtocols.Ssl2) { - return new SslStream(stream, false); + return new SslStream(stream, true); } #endif return new SslStream(stream, - leaveInnerStreamOpen: false, + leaveInnerStreamOpen: true, userCertificateValidationCallback: (sender, certificate, chain, sslPolicyErrors) => { if (certificate == null) @@ -84,7 +84,7 @@ private static SslStream CreateSslStream(TlsSettings settings, Stream stream) { var clientSettings = (ClientTlsSettings)settings; return new SslStream(stream, - leaveInnerStreamOpen: false, + leaveInnerStreamOpen: true, userCertificateValidationCallback: (sender, certificate, chain, sslPolicyErrors) => { if (clientSettings.ServerCertificateValidation != null) diff --git a/src/DotNetty.Handlers/Tls/TlsHandler.cs b/src/DotNetty.Handlers/Tls/TlsHandler.cs index fbf11eb2f..9ed1990b8 100644 --- a/src/DotNetty.Handlers/Tls/TlsHandler.cs +++ b/src/DotNetty.Handlers/Tls/TlsHandler.cs @@ -3,21 +3,22 @@ namespace DotNetty.Handlers.Tls { - using System; - using System.Collections.Generic; - using System.Diagnostics.Contracts; - using System.IO; - using System.Net.Security; - using System.Runtime.ExceptionServices; - using System.Security.Cryptography.X509Certificates; - using System.Threading; - using System.Threading.Tasks; - using DotNetty.Buffers; - using DotNetty.Codecs; - using DotNetty.Common.Concurrency; - using DotNetty.Common.Utilities; - using DotNetty.Common.Internal.Logging; - using DotNetty.Transport.Channels; + using System; + using System.Collections.Generic; + using System.Diagnostics; + using System.Diagnostics.Contracts; + using System.IO; + using System.Net.Security; + using System.Runtime.ExceptionServices; + using System.Security.Cryptography.X509Certificates; + using System.Threading; + using System.Threading.Tasks; + using DotNetty.Buffers; + using DotNetty.Codecs; + using DotNetty.Common.Concurrency; + using DotNetty.Common.Utilities; + using DotNetty.Common.Internal.Logging; + using DotNetty.Transport.Channels; #if !DESKTOPCLR && (NET40 || NET45 || NET451 || NET46 || NET461 || NET462 || NET47 || NET471) 确保编译不出问题 @@ -26,632 +27,573 @@ namespace DotNetty.Handlers.Tls 确保编译不出问题 #endif - public sealed partial class TlsHandler : ByteToMessageDecoder - { - #region @@ Fields @@ + public sealed partial class TlsHandler : ByteToMessageDecoder + { + #region @@ Fields @@ - private static readonly IInternalLogger s_logger = InternalLoggerFactory.GetInstance(); + private static readonly IInternalLogger s_logger = InternalLoggerFactory.GetInstance(); - private readonly TlsSettings _settings; - private const int c_fallbackReadBufferSize = 256; - private const int c_unencryptedWriteBatchSize = 14 * 1024; + private readonly TlsSettings _settings; + private const int c_fallbackReadBufferSize = 256; + private const int c_unencryptedWriteBatchSize = 14 * 1024; - private static readonly Exception s_channelClosedException = new IOException("Channel is closed"); + private static readonly Exception s_channelClosedException = new IOException("Channel is closed"); #if !NET40 - private static readonly Action s_handshakeCompletionCallback = new Action(HandleHandshakeCompleted); + private static readonly Action s_handshakeCompletionCallback = new Action(HandleHandshakeCompleted); #endif - private readonly SslStream _sslStream; - private readonly MediationStream _mediationStream; - private readonly TaskCompletionSource _closeFuture; - - private TlsHandlerState _state; - private int _packetLength; - private volatile IChannelHandlerContext _capturedContext; - private BatchingPendingWriteQueue _pendingUnencryptedWrites; - private Task _lastContextWriteTask; - private bool _firedChannelRead; - private IByteBuffer _pendingSslStreamReadBuffer; - private Task _pendingSslStreamReadFuture; + private readonly SslStream _sslStream; + private readonly MediationStream _mediationStream; + private readonly TaskCompletionSource _closeFuture; - #endregion + private TlsHandlerState _state; + private int _packetLength; + private volatile IChannelHandlerContext _capturedContext; + private BatchingPendingWriteQueue _pendingUnencryptedWrites; + private Task _lastContextWriteTask; + private bool _firedChannelRead; + private IByteBuffer _pendingSslStreamReadBuffer; + private Task _pendingSslStreamReadFuture; - #region @@ Constructors @@ + #endregion - //public TlsHandler(TlsSettings settings) - // : this(stream => new SslStream(stream, false), settings) - //{ - //} + #region @@ Constructors @@ - public TlsHandler(Func sslStreamFactory, TlsSettings settings) - { - Contract.Requires(sslStreamFactory != null); - Contract.Requires(settings != null); + //public TlsHandler(TlsSettings settings) + // : this(stream => new SslStream(stream, true), settings) + //{ + //} - _settings = settings; - _closeFuture = new TaskCompletionSource(); - _mediationStream = new MediationStream(this); - _sslStream = sslStreamFactory(_mediationStream); - } + public TlsHandler(Func sslStreamFactory, TlsSettings settings) + { + Contract.Requires(sslStreamFactory != null); + Contract.Requires(settings != null); - public static TlsHandler Client(string targetHost) => new TlsHandler(new ClientTlsSettings(targetHost)); + _settings = settings; + _closeFuture = new TaskCompletionSource(); + _mediationStream = new MediationStream(this); + _sslStream = sslStreamFactory(_mediationStream); + } - public static TlsHandler Client(string targetHost, X509Certificate clientCertificate) => new TlsHandler(new ClientTlsSettings(targetHost, new List { clientCertificate })); + public static TlsHandler Client(string targetHost) => new TlsHandler(new ClientTlsSettings(targetHost)); - public static TlsHandler Server(X509Certificate certificate) => new TlsHandler(new ServerTlsSettings(certificate)); + public static TlsHandler Client(string targetHost, X509Certificate clientCertificate) => new TlsHandler(new ClientTlsSettings(targetHost, new List { clientCertificate })); - #endregion + public static TlsHandler Server(X509Certificate certificate) => new TlsHandler(new ServerTlsSettings(certificate)); - #region @@ Properties @@ + #endregion - // using workaround mentioned here: https://github.com/dotnet/corefx/issues/4510 - public X509Certificate2 LocalCertificate => _sslStream.LocalCertificate as X509Certificate2 ?? new X509Certificate2(_sslStream.LocalCertificate?.Export(X509ContentType.Cert)); + #region @@ Properties @@ - public X509Certificate2 RemoteCertificate => _sslStream.RemoteCertificate as X509Certificate2 ?? new X509Certificate2(_sslStream.RemoteCertificate?.Export(X509ContentType.Cert)); + // using workaround mentioned here: https://github.com/dotnet/corefx/issues/4510 + public X509Certificate2 LocalCertificate => _sslStream.LocalCertificate as X509Certificate2 ?? new X509Certificate2(_sslStream.LocalCertificate?.Export(X509ContentType.Cert)); - private bool IsServer => _settings is ServerTlsSettings; + public X509Certificate2 RemoteCertificate => _sslStream.RemoteCertificate as X509Certificate2 ?? new X509Certificate2(_sslStream.RemoteCertificate?.Export(X509ContentType.Cert)); - #endregion + private bool IsServer => _settings is ServerTlsSettings; - #region -- IDisposable Members -- + #endregion - public void Dispose() => _sslStream?.Dispose(); + #region -- ChannelActive -- - #endregion - - #region -- ChannelActive -- - - public override void ChannelActive(IChannelHandlerContext context) - { - base.ChannelActive(context); + public override void ChannelActive(IChannelHandlerContext context) + { + base.ChannelActive(context); - if (!IsServer) - { - EnsureAuthenticated(); - } - } + if (!IsServer) + { + EnsureAuthenticated(); + } + } - #endregion + #endregion - #region -- ChannelInactive -- + #region -- ChannelInactive -- - public override void ChannelInactive(IChannelHandlerContext context) - { - // Make sure to release SslStream, - // and notify the handshake future if the connection has been closed during handshake. - HandleFailure(s_channelClosedException); + public override void ChannelInactive(IChannelHandlerContext context) + { + // Make sure to release SslStream, + // and notify the handshake future if the connection has been closed during handshake. + HandleFailure(s_channelClosedException); - base.ChannelInactive(context); - } + base.ChannelInactive(context); + } - #endregion + #endregion - #region -- ExceptionCaught -- + #region -- ExceptionCaught -- - public override void ExceptionCaught(IChannelHandlerContext context, Exception exception) - { - if (IgnoreException(exception)) - { - // Close the connection explicitly just in case the transport - // did not close the connection automatically. - if (context.Channel.Active) + public override void ExceptionCaught(IChannelHandlerContext context, Exception exception) { - context.CloseAsync(); + if (IgnoreException(exception)) + { + // Close the connection explicitly just in case the transport + // did not close the connection automatically. + if (context.Channel.Active) + { + context.CloseAsync(); + } + } + else + { + base.ExceptionCaught(context, exception); + } } - } - else - { - base.ExceptionCaught(context, exception); - } - } - - #endregion - - #region ** IgnoreException ** - - private bool IgnoreException(Exception t) - { - if (t is ObjectDisposedException && _closeFuture.Task.IsCompleted) - { - return true; - } - return false; - } - - #endregion - - #region -- HandlerAdded -- - - public override void HandlerAdded(IChannelHandlerContext context) - { - base.HandlerAdded(context); - _capturedContext = context; - _pendingUnencryptedWrites = new BatchingPendingWriteQueue(context, c_unencryptedWriteBatchSize); - if (context.Channel.Active && !IsServer) - { - // todo: support delayed initialization on an existing/active channel if in client mode - EnsureAuthenticated(); - } - } - #endregion - - #region ++ HandlerRemovedInternal ++ - - protected override void HandlerRemovedInternal(IChannelHandlerContext context) - { - if (!_pendingUnencryptedWrites.IsEmpty) - { - // Check if queue is not empty first because create a new ChannelException is expensive - _pendingUnencryptedWrites.RemoveAndFailAll(new ChannelException("Write has failed due to TlsHandler being removed from channel pipeline.")); - } - } - - #endregion + #endregion - #region ++ Decode ++ + #region ** IgnoreException ** - protected override void Decode(IChannelHandlerContext context, IByteBuffer input, List output) - { - int startOffset = input.ReaderIndex; - int endOffset = input.WriterIndex; - int offset = startOffset; - int totalLength = 0; - - List packetLengths; - // if we calculated the length of the current SSL record before, use that information. - if (_packetLength > 0) - { - if (endOffset - startOffset < _packetLength) - { - // input does not contain a single complete SSL record - return; - } - else + private bool IgnoreException(Exception t) { - packetLengths = new List(4) - { - _packetLength - }; - offset += _packetLength; - totalLength = _packetLength; - _packetLength = 0; - } - } - else - { - packetLengths = new List(4); - } - - bool nonSslRecord = false; - - while (totalLength < TlsUtils.MAX_ENCRYPTED_PACKET_LENGTH) - { - int readableBytes = endOffset - offset; - if (readableBytes < TlsUtils.SSL_RECORD_HEADER_LENGTH) - { - break; + if (t is ObjectDisposedException && _closeFuture.Task.IsCompleted) + { + return true; + } + return false; } - int encryptedPacketLength = TlsUtils.GetEncryptedPacketLength(input, offset); - if (encryptedPacketLength == -1) - { - nonSslRecord = true; - break; - } + #endregion - Contract.Assert(encryptedPacketLength > 0); + #region -- HandlerAdded -- - if (encryptedPacketLength > readableBytes) + public override void HandlerAdded(IChannelHandlerContext context) { - // wait until the whole packet can be read - _packetLength = encryptedPacketLength; - break; + base.HandlerAdded(context); + _capturedContext = context; + _pendingUnencryptedWrites = new BatchingPendingWriteQueue(context, c_unencryptedWriteBatchSize); + if (context.Channel.Active && !IsServer) + { + // todo: support delayed initialization on an existing/active channel if in client mode + EnsureAuthenticated(); + } } - int newTotalLength = totalLength + encryptedPacketLength; - if (newTotalLength > TlsUtils.MAX_ENCRYPTED_PACKET_LENGTH) - { - // Don't read too much. - break; - } + #endregion - // 1. call unwrap with packet boundaries - call SslStream.ReadAsync only once. - // 2. once we're through all the whole packets, switch to reading out using fallback sized buffer - - // We have a whole packet. - // Increment the offset to handle the next packet. - packetLengths.Add(encryptedPacketLength); - offset += encryptedPacketLength; - totalLength = newTotalLength; - } - - if (totalLength > 0) - { - // The buffer contains one or more full SSL records. - // Slice out the whole packet so unwrap will only be called with complete packets. - // Also directly reset the packetLength. This is needed as unwrap(..) may trigger - // decode(...) again via: - // 1) unwrap(..) is called - // 2) wrap(...) is called from within unwrap(...) - // 3) wrap(...) calls unwrapLater(...) - // 4) unwrapLater(...) calls decode(...) - // - // See https://github.com/netty/netty/issues/1534 - - input.SkipBytes(totalLength); - Unwrap(context, input, startOffset, totalLength, packetLengths, output); - - if (!_firedChannelRead) + #region ++ HandlerRemovedInternal ++ + + protected override void HandlerRemovedInternal(IChannelHandlerContext context) { - // Check first if firedChannelRead is not set yet as it may have been set in a - // previous decode(...) call. - _firedChannelRead = output.Count > 0; + if (!_pendingUnencryptedWrites.IsEmpty) + { + // Check if queue is not empty first because create a new ChannelException is expensive + _pendingUnencryptedWrites.RemoveAndFailAll(new ChannelException("Write has failed due to TlsHandler being removed from channel pipeline.")); + } } - } - - if (nonSslRecord) - { - // Not an SSL/TLS packet - var ex = new NotSslRecordException( - "not an SSL/TLS record: " + ByteBufferUtil.HexDump(input)); - input.SkipBytes(input.ReadableBytes); - context.FireExceptionCaught(ex); - HandleFailure(ex); - } - } - #endregion + #endregion - #region -- ChannelReadComplete -- + #region ++ Decode ++ - public override void ChannelReadComplete(IChannelHandlerContext ctx) - { - // Discard bytes of the cumulation buffer if needed. - DiscardSomeReadBytes(); - - ReadIfNeeded(ctx); + protected override void Decode(IChannelHandlerContext context, IByteBuffer input, List output) + { + int startOffset = input.ReaderIndex; + int endOffset = input.WriterIndex; + int offset = startOffset; + int totalLength = 0; + + List packetLengths; + // if we calculated the length of the current SSL record before, use that information. + if (_packetLength > 0) + { + if (endOffset - startOffset < _packetLength) + { + // input does not contain a single complete SSL record + return; + } + else + { + packetLengths = new List(4) + { + _packetLength + }; + offset += _packetLength; + totalLength = _packetLength; + _packetLength = 0; + } + } + else + { + packetLengths = new List(4); + } - _firedChannelRead = false; - ctx.FireChannelReadComplete(); - } + bool nonSslRecord = false; - #endregion + while (totalLength < TlsUtils.MAX_ENCRYPTED_PACKET_LENGTH) + { + int readableBytes = endOffset - offset; + if (readableBytes < TlsUtils.SSL_RECORD_HEADER_LENGTH) + { + break; + } + + int encryptedPacketLength = TlsUtils.GetEncryptedPacketLength(input, offset); + if (encryptedPacketLength == -1) + { + nonSslRecord = true; + break; + } + + Contract.Assert(encryptedPacketLength > 0); + + if (encryptedPacketLength > readableBytes) + { + // wait until the whole packet can be read + _packetLength = encryptedPacketLength; + break; + } + + int newTotalLength = totalLength + encryptedPacketLength; + if (newTotalLength > TlsUtils.MAX_ENCRYPTED_PACKET_LENGTH) + { + // Don't read too much. + break; + } + + // 1. call unwrap with packet boundaries - call SslStream.ReadAsync only once. + // 2. once we're through all the whole packets, switch to reading out using fallback sized buffer + + // We have a whole packet. + // Increment the offset to handle the next packet. + packetLengths.Add(encryptedPacketLength); + offset += encryptedPacketLength; + totalLength = newTotalLength; + } - #region ** ReadIfNeeded ** + if (totalLength > 0) + { + // The buffer contains one or more full SSL records. + // Slice out the whole packet so unwrap will only be called with complete packets. + // Also directly reset the packetLength. This is needed as unwrap(..) may trigger + // decode(...) again via: + // 1) unwrap(..) is called + // 2) wrap(...) is called from within unwrap(...) + // 3) wrap(...) calls unwrapLater(...) + // 4) unwrapLater(...) calls decode(...) + // + // See https://github.com/netty/netty/issues/1534 + + input.SkipBytes(totalLength); + Unwrap(context, input, startOffset, totalLength, packetLengths, output); + + if (!_firedChannelRead) + { + // Check first if firedChannelRead is not set yet as it may have been set in a + // previous decode(...) call. + _firedChannelRead = output.Count > 0; + } + } - private void ReadIfNeeded(IChannelHandlerContext ctx) - { - // if handshake is not finished yet, we need more data - if (!ctx.Channel.Configuration.AutoRead && (!_firedChannelRead || !_state.HasAny(TlsHandlerState.AuthenticationCompleted))) - { - // No auto-read used and no message was passed through the ChannelPipeline or the handshake was not completed - // yet, which means we need to trigger the read to ensure we will not stall - ctx.Read(); - } - } + if (nonSslRecord) + { + // Not an SSL/TLS packet + var ex = new NotSslRecordException( + "not an SSL/TLS record: " + ByteBufferUtil.HexDump(input)); + input.SkipBytes(input.ReadableBytes); + context.FireExceptionCaught(ex); + HandleFailure(ex); + } + } - #endregion + #endregion - #region ** Unwrap ** + #region -- ChannelReadComplete -- - /// Unwraps inbound SSL records. - private void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int length, List packetLengths, List output) - { - Contract.Requires(packetLengths.Count > 0); + public override void ChannelReadComplete(IChannelHandlerContext ctx) + { + // Discard bytes of the cumulation buffer if needed. + DiscardSomeReadBytes(); - //bool notifyClosure = false; // todo: netty/issues/137 - bool pending = false; + ReadIfNeeded(ctx); - IByteBuffer outputBuffer = null; + _firedChannelRead = false; + ctx.FireChannelReadComplete(); + } - try - { - ArraySegment inputIoBuffer = packet.GetIoBuffer(offset, length); - _mediationStream.SetSource(inputIoBuffer.Array, inputIoBuffer.Offset); + #endregion - int packetIndex = 0; + #region ** ReadIfNeeded ** - while (!EnsureAuthenticated()) + private void ReadIfNeeded(IChannelHandlerContext ctx) { - _mediationStream.ExpandSource(packetLengths[packetIndex]); - if (++packetIndex == packetLengths.Count) - { - return; - } + // if handshake is not finished yet, we need more data + if (!ctx.Channel.Configuration.AutoRead && (!_firedChannelRead || !_state.HasAny(TlsHandlerState.AuthenticationCompleted))) + { + // No auto-read used and no message was passed through the ChannelPipeline or the handshake was not completed + // yet, which means we need to trigger the read to ensure we will not stall + ctx.Read(); + } } - Task currentReadFuture = _pendingSslStreamReadFuture; + #endregion - int outputBufferLength; + #region ** Unwrap ** - if (currentReadFuture != null) + /// Unwraps inbound SSL records. + private void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int length, List packetLengths, List output) { - // restoring context from previous read - Contract.Assert(_pendingSslStreamReadBuffer != null); + Contract.Requires(packetLengths.Count > 0); - outputBuffer = _pendingSslStreamReadBuffer; - outputBufferLength = outputBuffer.WritableBytes; - } - else - { - outputBufferLength = 0; - } + //bool notifyClosure = false; // todo: netty/issues/137 + bool pending = false; - // go through packets one by one (because SslStream does not consume more than 1 packet at a time) - for (; packetIndex < packetLengths.Count; packetIndex++) - { - int currentPacketLength = packetLengths[packetIndex]; - _mediationStream.ExpandSource(currentPacketLength); - - if (currentReadFuture != null) - { - // there was a read pending already, so we make sure we completed that first + IByteBuffer outputBuffer = null; - if (!currentReadFuture.IsCompleted) + try { - // we did feed the whole current packet to SslStream yet it did not produce any result -> move to the next packet in input - Contract.Assert(_mediationStream.SourceReadableBytes == 0); - - continue; + ArraySegment inputIoBuffer = packet.GetIoBuffer(offset, length); + _mediationStream.SetSource(inputIoBuffer.Array, inputIoBuffer.Offset); + + int packetIndex = 0; + + while (!EnsureAuthenticated()) + { + _mediationStream.ExpandSource(packetLengths[packetIndex]); + if (++packetIndex == packetLengths.Count) + { + return; + } + } + + Task currentReadFuture = _pendingSslStreamReadFuture; + + int outputBufferLength; + + if (currentReadFuture != null) + { + // restoring context from previous read + Contract.Assert(_pendingSslStreamReadBuffer != null); + + outputBuffer = _pendingSslStreamReadBuffer; + outputBufferLength = outputBuffer.WritableBytes; + + _pendingSslStreamReadFuture = null; + _pendingSslStreamReadBuffer = null; + } + else + { + outputBufferLength = 0; + } + + // go through packets one by one (because SslStream does not consume more than 1 packet at a time) + for (; packetIndex < packetLengths.Count; packetIndex++) + { + int currentPacketLength = packetLengths[packetIndex]; + _mediationStream.ExpandSource(currentPacketLength); + + if (currentReadFuture != null) + { + // there was a read pending already, so we make sure we completed that first + + if (!currentReadFuture.IsCompleted) + { + // we did feed the whole current packet to SslStream yet it did not produce any result -> move to the next packet in input + + continue; + } + + int read = currentReadFuture.Result; + + if (read == 0) + { + //Stream closed + return; + } + + // Now output the result of previous read and decide whether to do an extra read on the same source or move forward + AddBufferToOutput(outputBuffer, read, output); + + currentReadFuture = null; + outputBuffer = null; + if (_mediationStream.SourceReadableBytes == 0) + { + // we just made a frame available for reading but there was already pending read so SslStream read it out to make further progress there + + if (read < outputBufferLength) + { + // SslStream returned non-full buffer and there's no more input to go through -> + // typically it means SslStream is done reading current frame so we skip + continue; + } + + // we've read out `read` bytes out of current packet to fulfil previously outstanding read + outputBufferLength = currentPacketLength - read; + if (outputBufferLength <= 0) + { + // after feeding to SslStream current frame it read out more bytes than current packet size + outputBufferLength = c_fallbackReadBufferSize; + } + } + else + { + // SslStream did not get to reading current frame so it completed previous read sync + // and the next read will likely read out the new frame + outputBufferLength = currentPacketLength; + } + } + else + { + // there was no pending read before so we estimate buffer of `currentPacketLength` bytes to be sufficient + outputBufferLength = currentPacketLength; + } + + outputBuffer = ctx.Allocator.Buffer(outputBufferLength); + currentReadFuture = ReadFromSslStreamAsync(outputBuffer, outputBufferLength); + } + + // read out the rest of SslStream's output (if any) at risk of going async + // using FallbackReadBufferSize - buffer size we're ok to have pinned with the SslStream until it's done reading + while (true) + { + if (currentReadFuture != null) + { + if (!currentReadFuture.IsCompleted) + { + break; + } + int read = currentReadFuture.Result; + AddBufferToOutput(outputBuffer, read, output); + } + outputBuffer = ctx.Allocator.Buffer(c_fallbackReadBufferSize); + currentReadFuture = ReadFromSslStreamAsync(outputBuffer, c_fallbackReadBufferSize); + } + + pending = true; + _pendingSslStreamReadBuffer = outputBuffer; + _pendingSslStreamReadFuture = currentReadFuture; } - - int read = currentReadFuture.Result; - - // Now output the result of previous read and decide whether to do an extra read on the same source or move forward - AddBufferToOutput(outputBuffer, read, output); - - currentReadFuture = null; - if (_mediationStream.SourceReadableBytes == 0) + catch (Exception ex) { - // we just made a frame available for reading but there was already pending read so SslStream read it out to make further progress there - - if (read < outputBufferLength) - { - // SslStream returned non-full buffer and there's no more input to go through -> - // typically it means SslStream is done reading current frame so we skip - continue; - } - - // we've read out `read` bytes out of current packet to fulfil previously outstanding read - outputBufferLength = currentPacketLength - read; - if (outputBufferLength <= 0) - { - // after feeding to SslStream current frame it read out more bytes than current packet size - outputBufferLength = c_fallbackReadBufferSize; - } + HandleFailure(ex); + throw; } - else + finally { - // SslStream did not get to reading current frame so it completed previous read sync - // and the next read will likely read out the new frame - outputBufferLength = currentPacketLength; + _mediationStream.ResetSource(); + if (!pending && outputBuffer != null) + { + if (outputBuffer.IsReadable()) + { + output.Add(outputBuffer); + } + else + { + outputBuffer.SafeRelease(); + } + } } - } - else - { - // there was no pending read before so we estimate buffer of `currentPacketLength` bytes to be sufficient - outputBufferLength = currentPacketLength; - } - - outputBuffer = ctx.Allocator.Buffer(outputBufferLength); - currentReadFuture = ReadFromSslStreamAsync(outputBuffer, outputBufferLength); } - // read out the rest of SslStream's output (if any) at risk of going async - // using FallbackReadBufferSize - buffer size we're ok to have pinned with the SslStream until it's done reading - while (true) - { - if (currentReadFuture != null) - { - if (!currentReadFuture.IsCompleted) - { - break; - } - int read = currentReadFuture.Result; - AddBufferToOutput(outputBuffer, read, output); - } - outputBuffer = ctx.Allocator.Buffer(c_fallbackReadBufferSize); - currentReadFuture = ReadFromSslStreamAsync(outputBuffer, c_fallbackReadBufferSize); - } + #endregion + + #region **& AddBufferToOutput &** - pending = true; - _pendingSslStreamReadBuffer = outputBuffer; - _pendingSslStreamReadFuture = currentReadFuture; - } - catch (Exception ex) - { - HandleFailure(ex); - throw; - } - finally - { - _mediationStream.ResetSource(); - if (!pending && outputBuffer != null) + private static void AddBufferToOutput(IByteBuffer outputBuffer, int length, List output) { - if (outputBuffer.IsReadable()) - { - output.Add(outputBuffer); - } - else - { - outputBuffer.SafeRelease(); - } + Contract.Assert(length > 0); + output.Add(outputBuffer.SetWriterIndex(outputBuffer.WriterIndex + length)); } - } - } - #endregion - - #region **& AddBufferToOutput &** - - private static void AddBufferToOutput(IByteBuffer outputBuffer, int length, List output) - { - Contract.Assert(length > 0); - output.Add(outputBuffer.SetWriterIndex(outputBuffer.WriterIndex + length)); - } + #endregion - #endregion + #region ** ReadFromSslStreamAsync ** - #region ** ReadFromSslStreamAsync ** + private Task ReadFromSslStreamAsync(IByteBuffer outputBuffer, int outputBufferLength) + { + ArraySegment outlet = outputBuffer.GetIoBuffer(outputBuffer.WriterIndex, outputBufferLength); + return _sslStream.ReadAsync(outlet.Array, outlet.Offset, outlet.Count); + } - private Task ReadFromSslStreamAsync(IByteBuffer outputBuffer, int outputBufferLength) - { - ArraySegment outlet = outputBuffer.GetIoBuffer(outputBuffer.WriterIndex, outputBufferLength); - return _sslStream.ReadAsync(outlet.Array, outlet.Offset, outlet.Count); - } + #endregion - #endregion + #region -- Read -- - #region -- Read -- - - public override void Read(IChannelHandlerContext context) - { - TlsHandlerState oldState = _state; - if (!oldState.HasAny(TlsHandlerState.AuthenticationCompleted)) - { - _state = oldState | TlsHandlerState.ReadRequestedBeforeAuthenticated; - } + public override void Read(IChannelHandlerContext context) + { + TlsHandlerState oldState = _state; + if (!oldState.HasAny(TlsHandlerState.AuthenticationCompleted)) + { + _state = oldState | TlsHandlerState.ReadRequestedBeforeAuthenticated; + } - context.Read(); - } + context.Read(); + } - #endregion + #endregion - #region ** EnsureAuthenticated ** + #region ** EnsureAuthenticated ** - private bool EnsureAuthenticated() - { - TlsHandlerState oldState = _state; - if (!oldState.HasAny(TlsHandlerState.AuthenticationStarted)) - { - _state = oldState | TlsHandlerState.Authenticating; - if (_settings is ServerTlsSettings serverSettings) + private bool EnsureAuthenticated() { + TlsHandlerState oldState = _state; + if (!oldState.HasAny(TlsHandlerState.AuthenticationStarted)) + { + _state = oldState | TlsHandlerState.Authenticating; + if (_settings is ServerTlsSettings serverSettings) + { #if !NET40 - _sslStream.AuthenticateAsServerAsync(serverSettings.Certificate, - serverSettings.NegotiateClientCertificate, - serverSettings.EnabledProtocols, - serverSettings.CheckCertificateRevocation) - .ContinueWith(s_handshakeCompletionCallback, this, TaskContinuationOptions.ExecuteSynchronously); + _sslStream.AuthenticateAsServerAsync(serverSettings.Certificate, + serverSettings.NegotiateClientCertificate, + serverSettings.EnabledProtocols, + serverSettings.CheckCertificateRevocation) + .ContinueWith(s_handshakeCompletionCallback, this, TaskContinuationOptions.ExecuteSynchronously); #else - _sslStream.BeginAuthenticateAsServer(serverSettings.Certificate, - serverSettings.NegotiateClientCertificate, - serverSettings.EnabledProtocols, - serverSettings.CheckCertificateRevocation, - Server_HandleHandshakeCompleted, - this); + _sslStream.BeginAuthenticateAsServer(serverSettings.Certificate, + serverSettings.NegotiateClientCertificate, + serverSettings.EnabledProtocols, + serverSettings.CheckCertificateRevocation, + Server_HandleHandshakeCompleted, + this); #endif - } - else - { - var clientSettings = (ClientTlsSettings)_settings; + } + else + { + var clientSettings = (ClientTlsSettings)_settings; #if !NET40 - _sslStream.AuthenticateAsClientAsync(clientSettings.TargetHost, - clientSettings.X509CertificateCollection, - clientSettings.EnabledProtocols, - clientSettings.CheckCertificateRevocation) - .ContinueWith(s_handshakeCompletionCallback, this, TaskContinuationOptions.ExecuteSynchronously); + _sslStream.AuthenticateAsClientAsync(clientSettings.TargetHost, + clientSettings.X509CertificateCollection, + clientSettings.EnabledProtocols, + clientSettings.CheckCertificateRevocation) + .ContinueWith(s_handshakeCompletionCallback, this, TaskContinuationOptions.ExecuteSynchronously); #else - _sslStream.BeginAuthenticateAsClient(clientSettings.TargetHost, - clientSettings.X509CertificateCollection, - clientSettings.EnabledProtocols, - clientSettings.CheckCertificateRevocation, - Client_HandleHandshakeCompleted, - this); + _sslStream.BeginAuthenticateAsClient(clientSettings.TargetHost, + clientSettings.X509CertificateCollection, + clientSettings.EnabledProtocols, + clientSettings.CheckCertificateRevocation, + Client_HandleHandshakeCompleted, + this); #endif - } - return false; - } + } + return false; + } - return oldState.Has(TlsHandlerState.Authenticated); - } + return oldState.Has(TlsHandlerState.Authenticated); + } - #endregion + #endregion - #region **& HandleHandshakeCompleted &** + #region **& HandleHandshakeCompleted &** #if NET40 - private static void Client_HandleHandshakeCompleted(IAsyncResult result) - { - var self = (TlsHandler)result.AsyncState; - TlsHandlerState oldState; - try - { - self._sslStream.EndAuthenticateAsClient(result); - } - catch (Exception ex) - { - // ReSharper disable once AssignNullToNotNullAttribute -- task.Exception will be present as task is faulted - oldState = self._state; - Contract.Assert(!oldState.HasAny(TlsHandlerState.Authenticated)); - self.HandleFailure(ex); - return; - } - - oldState = self._state; - - Contract.Assert(!oldState.HasAny(TlsHandlerState.AuthenticationCompleted)); - self._state = (oldState | TlsHandlerState.Authenticated) & ~(TlsHandlerState.Authenticating | TlsHandlerState.FlushedBeforeHandshake); - - self._capturedContext.FireUserEventTriggered(TlsHandshakeCompletionEvent.Success); - - if (oldState.Has(TlsHandlerState.ReadRequestedBeforeAuthenticated) && !self._capturedContext.Channel.Configuration.AutoRead) - { - self._capturedContext.Read(); - } - - if (oldState.Has(TlsHandlerState.FlushedBeforeHandshake)) - { - self.Wrap(self._capturedContext); - self._capturedContext.Flush(); - } - } + private static void Client_HandleHandshakeCompleted(IAsyncResult result) + { + var self = (TlsHandler)result.AsyncState; + TlsHandlerState oldState; + try + { + self._sslStream.EndAuthenticateAsClient(result); + } + catch (Exception ex) + { + // ReSharper disable once AssignNullToNotNullAttribute -- task.Exception will be present as task is faulted + oldState = self._state; + Contract.Assert(!oldState.HasAny(TlsHandlerState.Authenticated)); + self.HandleFailure(ex); + return; + } - private static void Server_HandleHandshakeCompleted(IAsyncResult result) - { - var self = (TlsHandler)result.AsyncState; - TlsHandlerState oldState; - try - { - self._sslStream.EndAuthenticateAsServer(result); - } - catch (Exception ex) - { - // ReSharper disable once AssignNullToNotNullAttribute -- task.Exception will be present as task is faulted - oldState = self._state; - Contract.Assert(!oldState.HasAny(TlsHandlerState.Authenticated)); - self.HandleFailure(ex); - return; - } - - oldState = self._state; - - Contract.Assert(!oldState.HasAny(TlsHandlerState.AuthenticationCompleted)); - self._state = (oldState | TlsHandlerState.Authenticated) & ~(TlsHandlerState.Authenticating | TlsHandlerState.FlushedBeforeHandshake); - - self._capturedContext.FireUserEventTriggered(TlsHandshakeCompletionEvent.Success); - - if (oldState.Has(TlsHandlerState.ReadRequestedBeforeAuthenticated) && !self._capturedContext.Channel.Configuration.AutoRead) - { - self._capturedContext.Read(); - } - - if (oldState.Has(TlsHandlerState.FlushedBeforeHandshake)) - { - self.Wrap(self._capturedContext); - self._capturedContext.Flush(); - } - } -#else - private static void HandleHandshakeCompleted(Task task, object state) - { - var self = (TlsHandler)state; - switch (task.Status) - { - case TaskStatus.RanToCompletion: - { - TlsHandlerState oldState = self._state; + oldState = self._state; Contract.Assert(!oldState.HasAny(TlsHandlerState.AuthenticationCompleted)); self._state = (oldState | TlsHandlerState.Authenticated) & ~(TlsHandlerState.Authenticating | TlsHandlerState.FlushedBeforeHandshake); @@ -660,580 +602,661 @@ private static void HandleHandshakeCompleted(Task task, object state) if (oldState.Has(TlsHandlerState.ReadRequestedBeforeAuthenticated) && !self._capturedContext.Channel.Configuration.AutoRead) { - self._capturedContext.Read(); + self._capturedContext.Read(); } if (oldState.Has(TlsHandlerState.FlushedBeforeHandshake)) { - self.Wrap(self._capturedContext); - self._capturedContext.Flush(); + self.Wrap(self._capturedContext); + self._capturedContext.Flush(); } - break; - } - case TaskStatus.Canceled: - case TaskStatus.Faulted: - { - // ReSharper disable once AssignNullToNotNullAttribute -- task.Exception will be present as task is faulted - TlsHandlerState oldState = self._state; - Contract.Assert(!oldState.HasAny(TlsHandlerState.Authenticated)); - self.HandleFailure(task.Exception); - break; - } - default: - ThrowHelper.ThrowArgumentOutOfRangeException_HandshakeCompleted(task.Status); break; - } - } -#endif + } - #endregion + private static void Server_HandleHandshakeCompleted(IAsyncResult result) + { + var self = (TlsHandler)result.AsyncState; + TlsHandlerState oldState; + try + { + self._sslStream.EndAuthenticateAsServer(result); + } + catch (Exception ex) + { + // ReSharper disable once AssignNullToNotNullAttribute -- task.Exception will be present as task is faulted + oldState = self._state; + Contract.Assert(!oldState.HasAny(TlsHandlerState.Authenticated)); + self.HandleFailure(ex); + return; + } - #region -- WriteAsync -- + oldState = self._state; - public override Task WriteAsync(IChannelHandlerContext context, object message) - { - if (!(message is IByteBuffer)) - { - return TaskUtil.FromException(new UnsupportedMessageTypeException(message, typeof(IByteBuffer))); - } - return _pendingUnencryptedWrites.Add(message); - } + Contract.Assert(!oldState.HasAny(TlsHandlerState.AuthenticationCompleted)); + self._state = (oldState | TlsHandlerState.Authenticated) & ~(TlsHandlerState.Authenticating | TlsHandlerState.FlushedBeforeHandshake); - #endregion + self._capturedContext.FireUserEventTriggered(TlsHandshakeCompletionEvent.Success); - #region -- Flush -- + if (oldState.Has(TlsHandlerState.ReadRequestedBeforeAuthenticated) && !self._capturedContext.Channel.Configuration.AutoRead) + { + self._capturedContext.Read(); + } - public override void Flush(IChannelHandlerContext context) - { - if (_pendingUnencryptedWrites.IsEmpty) - { - _pendingUnencryptedWrites.Add(Unpooled.Empty); - } - - if (!EnsureAuthenticated()) - { - _state |= TlsHandlerState.FlushedBeforeHandshake; - return; - } - - try - { - Wrap(context); - } - finally - { - // We may have written some parts of data before an exception was thrown so ensure we always flush. - context.Flush(); - } - } + if (oldState.Has(TlsHandlerState.FlushedBeforeHandshake)) + { + self.Wrap(self._capturedContext); + self._capturedContext.Flush(); + } + } +#else + private static void HandleHandshakeCompleted(Task task, object state) + { + var self = (TlsHandler)state; + switch (task.Status) + { + case TaskStatus.RanToCompletion: + { + TlsHandlerState oldState = self._state; + + Contract.Assert(!oldState.HasAny(TlsHandlerState.AuthenticationCompleted)); + self._state = (oldState | TlsHandlerState.Authenticated) & ~(TlsHandlerState.Authenticating | TlsHandlerState.FlushedBeforeHandshake); + + self._capturedContext.FireUserEventTriggered(TlsHandshakeCompletionEvent.Success); + + if (oldState.Has(TlsHandlerState.ReadRequestedBeforeAuthenticated) && !self._capturedContext.Channel.Configuration.AutoRead) + { + self._capturedContext.Read(); + } + + if (oldState.Has(TlsHandlerState.FlushedBeforeHandshake)) + { + self.Wrap(self._capturedContext); + self._capturedContext.Flush(); + } + break; + } + case TaskStatus.Canceled: + case TaskStatus.Faulted: + { + // ReSharper disable once AssignNullToNotNullAttribute -- task.Exception will be present as task is faulted + TlsHandlerState oldState = self._state; + Contract.Assert(!oldState.HasAny(TlsHandlerState.Authenticated)); + self.HandleFailure(task.Exception); + break; + } + default: + ThrowHelper.ThrowArgumentOutOfRangeException_HandshakeCompleted(task.Status); break; + } + } +#endif - #endregion + #endregion - #region ** Wrap ** + #region -- WriteAsync -- - private void Wrap(IChannelHandlerContext context) - { - Contract.Assert(context == _capturedContext); + public override Task WriteAsync(IChannelHandlerContext context, object message) + { + if (!(message is IByteBuffer)) + { + return TaskUtil.FromException(new UnsupportedMessageTypeException(message, typeof(IByteBuffer))); + } + return _pendingUnencryptedWrites.Add(message); + } + + #endregion - IByteBuffer buf = null; - try - { - while (true) + #region -- Flush -- + + public override void Flush(IChannelHandlerContext context) { - List messages = _pendingUnencryptedWrites.Current; - if (messages == null || messages.Count == 0) - { - break; - } + if (_pendingUnencryptedWrites.IsEmpty) + { + _pendingUnencryptedWrites.Add(Unpooled.Empty); + } - if (messages.Count == 1) - { - buf = (IByteBuffer)messages[0]; - } - else - { - buf = context.Allocator.Buffer((int)_pendingUnencryptedWrites.CurrentSize); - foreach (IByteBuffer buffer in messages) + if (!EnsureAuthenticated()) { - buffer.ReadBytes(buf, buffer.ReadableBytes); - buffer.Release(); + _state |= TlsHandlerState.FlushedBeforeHandshake; + return; } - } - buf.ReadBytes(_sslStream, buf.ReadableBytes); // this leads to FinishWrap being called 0+ times - buf.Release(); - TaskCompletionSource promise = _pendingUnencryptedWrites.Remove(); - Task task = _lastContextWriteTask; - if (task != null) - { - task.LinkOutcome(promise); - _lastContextWriteTask = null; - } - else - { - promise.TryComplete(); - } + try + { + Wrap(context); + } + finally + { + // We may have written some parts of data before an exception was thrown so ensure we always flush. + context.Flush(); + } } - } - catch (Exception ex) - { - buf.SafeRelease(); - HandleFailure(ex); - throw; - } - } - #endregion + #endregion - #region ** FinishWrap ** + #region ** Wrap ** - private void FinishWrap(byte[] buffer, int offset, int count) - { - IByteBuffer output; - if (count == 0) - { - output = Unpooled.Empty; - } - else - { - output = _capturedContext.Allocator.Buffer(count); - output.WriteBytes(buffer, offset, count); - } - - _lastContextWriteTask = _capturedContext.WriteAsync(output); - } + private void Wrap(IChannelHandlerContext context) + { + Contract.Assert(context == _capturedContext); - #endregion + IByteBuffer buf = null; + try + { + while (true) + { + List messages = _pendingUnencryptedWrites.Current; + if (messages == null || messages.Count == 0) + { + break; + } + + if (messages.Count == 1) + { + buf = (IByteBuffer)messages[0]; + } + else + { + buf = context.Allocator.Buffer((int)_pendingUnencryptedWrites.CurrentSize); + foreach (IByteBuffer buffer in messages) + { + buffer.ReadBytes(buf, buffer.ReadableBytes); + buffer.Release(); + } + } + buf.ReadBytes(_sslStream, buf.ReadableBytes); // this leads to FinishWrap being called 0+ times + buf.Release(); + + TaskCompletionSource promise = _pendingUnencryptedWrites.Remove(); + Task task = _lastContextWriteTask; + if (task != null) + { + task.LinkOutcome(promise); + _lastContextWriteTask = null; + } + else + { + promise.TryComplete(); + } + } + } + catch (Exception ex) + { + buf.SafeRelease(); + HandleFailure(ex); + throw; + } + } - #region ** FinishWrapNonAppDataAsync ** + #endregion - private Task FinishWrapNonAppDataAsync(byte[] buffer, int offset, int count) - { - var future = _capturedContext.WriteAndFlushAsync(Unpooled.WrappedBuffer(buffer, offset, count)); - this.ReadIfNeeded(_capturedContext); - return future; - } + #region ** FinishWrap ** - #endregion + private void FinishWrap(byte[] buffer, int offset, int count) + { + IByteBuffer output; + if (count == 0) + { + output = Unpooled.Empty; + } + else + { + output = _capturedContext.Allocator.Buffer(count); + output.WriteBytes(buffer, offset, count); + } - #region -- CloseAsync -- + _lastContextWriteTask = _capturedContext.WriteAsync(output); + } - public override Task CloseAsync(IChannelHandlerContext context) - { - _closeFuture.TryComplete(); - _sslStream.Dispose(); - return base.CloseAsync(context); - } + #endregion - #endregion + #region ** FinishWrapNonAppDataAsync ** - #region ** HandleFailure ** + private Task FinishWrapNonAppDataAsync(byte[] buffer, int offset, int count) + { + var future = _capturedContext.WriteAndFlushAsync(Unpooled.WrappedBuffer(buffer, offset, count)); + this.ReadIfNeeded(_capturedContext); + return future; + } - private void HandleFailure(Exception cause) - { - // Release all resources such as internal buffers that SSLEngine - // is managing. - - try - { - _sslStream.Dispose(); - } - catch (Exception) - { - // todo: evaluate following: - // only log in Debug mode as it most likely harmless and latest chrome still trigger - // this all the time. - // - // See https://github.com/netty/netty/issues/1340 - //string msg = ex.Message; - //if (msg == null || !msg.contains("possible truncation attack")) - //{ - // //Logger.Debug("{} SSLEngine.closeInbound() raised an exception.", ctx.channel(), e); - //} - } - _pendingSslStreamReadBuffer?.SafeRelease(); - _pendingSslStreamReadBuffer = null; - _pendingSslStreamReadFuture = null; + #endregion - NotifyHandshakeFailure(cause); - _pendingUnencryptedWrites.RemoveAndFailAll(cause); - } + #region -- CloseAsync -- - #endregion + public override Task CloseAsync(IChannelHandlerContext context) + { + _closeFuture.TryComplete(); + _sslStream.Dispose(); + return base.CloseAsync(context); + } - #region ** NotifyHandshakeFailure ** + #endregion - private void NotifyHandshakeFailure(Exception cause) - { - if (!_state.HasAny(TlsHandlerState.AuthenticationCompleted)) - { - // handshake was not completed yet => TlsHandler react to failure by closing the channel - _state = (_state | TlsHandlerState.FailedAuthentication) & ~TlsHandlerState.Authenticating; - _capturedContext.FireUserEventTriggered(new TlsHandshakeCompletionEvent(cause)); - CloseAsync(_capturedContext); - } - } + #region ** HandleFailure ** - #endregion + private void HandleFailure(Exception cause) + { + // Release all resources such as internal buffers that SSLEngine + // is managing. - #region ** class MediationStream ** + _mediationStream.Dispose(); + try + { + _sslStream.Dispose(); + } + catch (Exception) + { + // todo: evaluate following: + // only log in Debug mode as it most likely harmless and latest chrome still trigger + // this all the time. + // + // See https://github.com/netty/netty/issues/1340 + //string msg = ex.Message; + //if (msg == null || !msg.contains("possible truncation attack")) + //{ + // //Logger.Debug("{} SSLEngine.closeInbound() raised an exception.", ctx.channel(), e); + //} + } + _pendingSslStreamReadBuffer?.SafeRelease(); + _pendingSslStreamReadBuffer = null; + _pendingSslStreamReadFuture = null; - private sealed class MediationStream : Stream - { - private readonly TlsHandler _owner; - private byte[] _input; - private int _inputStartOffset; - private int _inputOffset; - private int _inputLength; - private TaskCompletionSource _readCompletionSource; - private ArraySegment _sslOwnedBuffer; + NotifyHandshakeFailure(cause); + _pendingUnencryptedWrites.RemoveAndFailAll(cause); + } + + #endregion + + #region ** NotifyHandshakeFailure ** + + private void NotifyHandshakeFailure(Exception cause) + { + if (!_state.HasAny(TlsHandlerState.AuthenticationCompleted)) + { + // handshake was not completed yet => TlsHandler react to failure by closing the channel + _state = (_state | TlsHandlerState.FailedAuthentication) & ~TlsHandlerState.Authenticating; + _capturedContext.FireUserEventTriggered(new TlsHandshakeCompletionEvent(cause)); + CloseAsync(_capturedContext); + } + } + + #endregion + + #region ** class MediationStream ** + + private sealed class MediationStream : Stream + { + private readonly TlsHandler _owner; + private byte[] _input; + private int _inputStartOffset; + private int _inputOffset; + private int _inputLength; + private TaskCompletionSource _readCompletionSource; + private ArraySegment _sslOwnedBuffer; #if NETSTANDARD - private int _readByteCount; + private int _readByteCount; #else - private SynchronousAsyncResult _syncReadResult; - private AsyncCallback _readCallback; - private TaskCompletionSource _writeCompletion; - private AsyncCallback _writeCallback; + private SynchronousAsyncResult _syncReadResult; + private AsyncCallback _readCallback; + private TaskCompletionSource _writeCompletion; + private AsyncCallback _writeCallback; #endif - public MediationStream(TlsHandler owner) - { - _owner = owner; - } - - public int SourceReadableBytes => _inputLength - _inputOffset; + public MediationStream(TlsHandler owner) + { + _owner = owner; + } - public void SetSource(byte[] source, int offset) - { - _input = source; - _inputStartOffset = offset; - _inputOffset = 0; - _inputLength = 0; - } + public int SourceReadableBytes => _inputLength - _inputOffset; - public void ResetSource() - { - _input = null; - _inputLength = 0; - } + public void SetSource(byte[] source, int offset) + { + _input = source; + _inputStartOffset = offset; + _inputOffset = 0; + _inputLength = 0; + } - public void ExpandSource(int count) - { - Contract.Assert(_input != null); + public void ResetSource() + { + _input = null; + _inputLength = 0; + } - _inputLength += count; + public void ExpandSource(int count) + { + Contract.Assert(_input != null); - TaskCompletionSource promise = _readCompletionSource; - if (promise == null) - { - // there is no pending read operation - keep for future - return; - } + _inputLength += count; - ArraySegment sslBuffer = _sslOwnedBuffer; + ArraySegment sslBuffer = _sslOwnedBuffer; + if (sslBuffer.Array == null) + { + // there is no pending read operation - keep for future + return; + } + _sslOwnedBuffer = default(ArraySegment); #if NETSTANDARD - this._readByteCount = this.ReadFromInput(sslBuffer.Array, sslBuffer.Offset, sslBuffer.Count); - // hack: this tricks SslStream's continuation to run synchronously instead of dispatching to TP. Remove once Begin/EndRead are available. - new Task( - ms => - { - var self = (MediationStream)ms; - TaskCompletionSource p = self._readCompletionSource; - this._readCompletionSource = null; - p.TrySetResult(self._readByteCount); - }, - this) - .RunSynchronously(TaskScheduler.Default); + _readByteCount = this.ReadFromInput(sslBuffer.Array, sslBuffer.Offset, sslBuffer.Count); + // hack: this tricks SslStream's continuation to run synchronously instead of dispatching to TP. Remove once Begin/EndRead are available. + new Task( + ms => + { + var self = (MediationStream)ms; + TaskCompletionSource p = self._readCompletionSource; + self._readCompletionSource = null; + p.TrySetResult(self._readByteCount); + }, + this) + .RunSynchronously(TaskScheduler.Default); #else - int read = ReadFromInput(sslBuffer.Array, sslBuffer.Offset, sslBuffer.Count); - _readCompletionSource = null; - promise.TrySetResult(read); - _readCallback?.Invoke(promise.Task); + int read = ReadFromInput(sslBuffer.Array, sslBuffer.Offset, sslBuffer.Count); + + TaskCompletionSource promise = _readCompletionSource; + _readCompletionSource = null; + promise.TrySetResult(read); + + AsyncCallback callback = _readCallback; + _readCallback = null; + callback?.Invoke(promise.Task); #endif - } + } #if NETSTANDARD - public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - { - if (_inputLength - _inputOffset > 0) - { - // we have the bytes available upfront - write out synchronously - int read = ReadFromInput(buffer, offset, count); - return Task.FromResult(read); - } - - // take note of buffer - we will pass bytes there once available - _sslOwnedBuffer = new ArraySegment(buffer, offset, count); - _readCompletionSource = new TaskCompletionSource(); - return _readCompletionSource.Task; - } + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + if (this.SourceReadableBytes > 0) + { + // we have the bytes available upfront - write out synchronously + int read = ReadFromInput(buffer, offset, count); + return Task.FromResult(read); + } + + Contract.Assert(_sslOwnedBuffer.Array == null); + // take note of buffer - we will pass bytes there once available + _sslOwnedBuffer = new ArraySegment(buffer, offset, count); + _readCompletionSource = new TaskCompletionSource(); + return _readCompletionSource.Task; + } #else - public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) - { - if (_inputLength - _inputOffset > 0) - { - // we have the bytes available upfront - write out synchronously - int read = ReadFromInput(buffer, offset, count); - return PrepareSyncReadResult(read, state); - } - - // take note of buffer - we will pass bytes there once available - _sslOwnedBuffer = new ArraySegment(buffer, offset, count); - _readCompletionSource = new TaskCompletionSource(state); - _readCallback = callback; - return _readCompletionSource.Task; - } - - public override int EndRead(IAsyncResult asyncResult) - { - SynchronousAsyncResult syncResult = _syncReadResult; - if (ReferenceEquals(asyncResult, syncResult)) - { - return syncResult.Result; - } - - Contract.Assert(!((Task)asyncResult).IsCanceled); + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + if (this.SourceReadableBytes > 0) + { + // we have the bytes available upfront - write out synchronously + int read = ReadFromInput(buffer, offset, count); + var res = this.PrepareSyncReadResult(read, state); + callback?.Invoke(res); + return res; + } + + Contract.Assert(_sslOwnedBuffer.Array == null); + // take note of buffer - we will pass bytes there once available + _sslOwnedBuffer = new ArraySegment(buffer, offset, count); + _readCompletionSource = new TaskCompletionSource(state); + _readCallback = callback; + return _readCompletionSource.Task; + } - try - { - return ((Task)asyncResult).Result; - } - catch (AggregateException ex) - { + public override int EndRead(IAsyncResult asyncResult) + { + SynchronousAsyncResult syncResult = _syncReadResult; + if (ReferenceEquals(asyncResult, syncResult)) + { + return syncResult.Result; + } + + Debug.Assert(_readCompletionSource == null || _readCompletionSource.Task == asyncResult); + Contract.Assert(!((Task)asyncResult).IsCanceled); + + try + { + return ((Task)asyncResult).Result; + } + catch (AggregateException ex) + { #if !NET40 - ExceptionDispatchInfo.Capture(ex.InnerException).Throw(); + ExceptionDispatchInfo.Capture(ex.InnerException).Throw(); #else - throw ExceptionEnlightenment.PrepareForRethrow(ex.InnerException); + throw ExceptionEnlightenment.PrepareForRethrow(ex.InnerException); #endif - throw; // unreachable - } - finally - { - _readCompletionSource = null; - _readCallback = null; - _sslOwnedBuffer = default(ArraySegment); - } - } - - private IAsyncResult PrepareSyncReadResult(int readBytes, object state) - { - // it is safe to reuse sync result object as it can't lead to leak (no way to attach to it via handle) - SynchronousAsyncResult result = _syncReadResult ?? (_syncReadResult = new SynchronousAsyncResult()); - result.Result = readBytes; - result.AsyncState = state; - return result; - } + throw; // unreachable + } + } + + private IAsyncResult PrepareSyncReadResult(int readBytes, object state) + { + // it is safe to reuse sync result object as it can't lead to leak (no way to attach to it via handle) + SynchronousAsyncResult result = _syncReadResult ?? (_syncReadResult = new SynchronousAsyncResult()); + result.Result = readBytes; + result.AsyncState = state; + return result; + } #endif - public override void Write(byte[] buffer, int offset, int count) => _owner.FinishWrap(buffer, offset, count); + public override void Write(byte[] buffer, int offset, int count) => _owner.FinishWrap(buffer, offset, count); #if !NET40 - public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - => _owner.FinishWrapNonAppDataAsync(buffer, offset, count); + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + => _owner.FinishWrapNonAppDataAsync(buffer, offset, count); #endif #if DESKTOPCLR #if !NET40 - private static readonly Action s_writeCompleteCallback = HandleChannelWriteComplete; + private static readonly Action s_writeCompleteCallback = HandleChannelWriteComplete; #endif - public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) - { + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { #if NET40 - Task task = _owner.FinishWrapNonAppDataAsync(buffer, offset, count); + Task task = _owner.FinishWrapNonAppDataAsync(buffer, offset, count); #else - Task task = this.WriteAsync(buffer, offset, count); + Task task = this.WriteAsync(buffer, offset, count); #endif - switch (task.Status) - { - case TaskStatus.RanToCompletion: - // write+flush completed synchronously (and successfully) - var result = new SynchronousAsyncResult - { - AsyncState = state - }; - callback(result); - return result; - - default: - _writeCallback = callback; - var tcs = new TaskCompletionSource(state); - _writeCompletion = tcs; + switch (task.Status) + { + case TaskStatus.RanToCompletion: + // write+flush completed synchronously (and successfully) + var result = new SynchronousAsyncResult + { + AsyncState = state + }; + callback?.Invoke(result); + return result; + + default: + if (callback != null || state != task.AsyncState) + { + Contract.Assert(_writeCompletion == null); + _writeCallback = callback; + var tcs = new TaskCompletionSource(state); + _writeCompletion = tcs; #if !NET40 - task.ContinueWith(s_writeCompleteCallback, this, TaskContinuationOptions.ExecuteSynchronously); + task.ContinueWith(s_writeCompleteCallback, this, TaskContinuationOptions.ExecuteSynchronously); #else - Action continuationAction = completed => HandleChannelWriteComplete(completed, this); - task.ContinueWith(continuationAction, TaskContinuationOptions.ExecuteSynchronously); + Action continuationAction = completed => HandleChannelWriteComplete(completed, this); + task.ContinueWith(continuationAction, TaskContinuationOptions.ExecuteSynchronously); #endif - return tcs.Task; - } - } + return tcs.Task; + } + else + { + return task; + } + } + } - private static void HandleChannelWriteComplete(Task writeTask, object state) - { - var self = (MediationStream)state; - switch (writeTask.Status) - { - case TaskStatus.RanToCompletion: - self._writeCompletion.TryComplete(); - break; + private static void HandleChannelWriteComplete(Task writeTask, object state) + { + var self = (MediationStream)state; - case TaskStatus.Canceled: - self._writeCompletion.TrySetCanceled(); - break; + AsyncCallback callback = self._writeCallback; + self._writeCallback = null; - case TaskStatus.Faulted: - self._writeCompletion.TrySetException(writeTask.Exception); - break; + var promise = self._writeCompletion; + self._writeCompletion = null; - default: - ThrowHelper.ThrowArgumentOutOfRangeException_WriteComplete(writeTask.Status); break; - } + switch (writeTask.Status) + { + case TaskStatus.RanToCompletion: + promise.TryComplete(); + break; - self._writeCallback?.Invoke(self._writeCompletion.Task); - } + case TaskStatus.Canceled: + promise.TrySetCanceled(); + break; - public override void EndWrite(IAsyncResult asyncResult) - { - _writeCallback = null; - _writeCompletion = null; + case TaskStatus.Faulted: + promise.TrySetException(writeTask.Exception); + break; - if (asyncResult is SynchronousAsyncResult) - { - return; - } + default: + ThrowHelper.ThrowArgumentOutOfRangeException_WriteComplete(writeTask.Status); break; + } - try - { - ((Task)asyncResult).Wait(); - } - catch (AggregateException ex) - { + callback?.Invoke(promise.Task); + } + + public override void EndWrite(IAsyncResult asyncResult) + { + if (asyncResult is SynchronousAsyncResult) + { + return; + } + + try + { + ((Task)asyncResult).Wait(); + } + catch (AggregateException ex) + { #if !NET40 - ExceptionDispatchInfo.Capture(ex.InnerException).Throw(); + ExceptionDispatchInfo.Capture(ex.InnerException).Throw(); #else - throw ExceptionEnlightenment.PrepareForRethrow(ex.InnerException); + throw ExceptionEnlightenment.PrepareForRethrow(ex.InnerException); #endif - throw; - } - } + throw; + } + } #endif - private int ReadFromInput(byte[] destination, int destinationOffset, int destinationCapacity) - { - Contract.Assert(destination != null); - - byte[] source = _input; - int readableBytes = _inputLength - _inputOffset; - int length = Math.Min(readableBytes, destinationCapacity); - Buffer.BlockCopy(source, _inputStartOffset + _inputOffset, destination, destinationOffset, length); - _inputOffset += length; - return length; - } - - public override void Flush() - { - // NOOP: called on SslStream.Close - } - - protected override void Dispose(bool disposing) - { - base.Dispose(disposing); - if (disposing) - { - TaskCompletionSource p = _readCompletionSource; - _readCompletionSource = null; - p?.TrySetResult(0); - } - } + private int ReadFromInput(byte[] destination, int destinationOffset, int destinationCapacity) + { + Contract.Assert(destination != null); + + byte[] source = _input; + int readableBytes = this.SourceReadableBytes; + int length = Math.Min(readableBytes, destinationCapacity); + Buffer.BlockCopy(source, _inputStartOffset + _inputOffset, destination, destinationOffset, length); + _inputOffset += length; + return length; + } + + public override void Flush() + { + // NOOP: called on SslStream.Close + } + + protected override void Dispose(bool disposing) + { + base.Dispose(disposing); + if (disposing) + { + TaskCompletionSource p = _readCompletionSource; + if (p != null) + { + _readCompletionSource = null; + p.TrySetResult(0); + } + } + } + + #region plumbing + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(); + } - #region plumbing + public override void SetLength(long value) + { + throw new NotSupportedException(); + } - public override long Seek(long offset, SeekOrigin origin) - { - throw new NotSupportedException(); - } + public override int Read(byte[] buffer, int offset, int count) + { + throw new NotSupportedException(); + } - public override void SetLength(long value) - { - throw new NotSupportedException(); - } + public override bool CanRead => true; - public override int Read(byte[] buffer, int offset, int count) - { - throw new NotSupportedException(); - } + public override bool CanSeek => false; - public override bool CanRead => true; + public override bool CanWrite => true; - public override bool CanSeek => false; + public override long Length + { + get { throw new NotSupportedException(); } + } + + public override long Position + { + get { throw new NotSupportedException(); } + set { throw new NotSupportedException(); } + } - public override bool CanWrite => true; + #endregion - public override long Length - { - get { throw new NotSupportedException(); } - } + #region sync result - public override long Position - { - get { throw new NotSupportedException(); } - set { throw new NotSupportedException(); } - } + private sealed class SynchronousAsyncResult : IAsyncResult + { + public T Result { get; set; } - #endregion + public bool IsCompleted => true; - #region sync result + public WaitHandle AsyncWaitHandle + { + get { throw new InvalidOperationException("Cannot wait on a synchronous result."); } + } - private sealed class SynchronousAsyncResult : IAsyncResult - { - public T Result { get; set; } + public object AsyncState { get; set; } - public bool IsCompleted => true; + public bool CompletedSynchronously => true; + } - public WaitHandle AsyncWaitHandle - { - get { throw new InvalidOperationException("Cannot wait on a synchronous result."); } + #endregion } - public object AsyncState { get; set; } + #endregion + } - public bool CompletedSynchronously => true; - } + #region == enum TlsHandlerState == - #endregion + [Flags] + internal enum TlsHandlerState + { + Authenticating = 1, + Authenticated = 1 << 1, + FailedAuthentication = 1 << 2, + ReadRequestedBeforeAuthenticated = 1 << 3, + FlushedBeforeHandshake = 1 << 4, + AuthenticationStarted = Authenticating | Authenticated | FailedAuthentication, + AuthenticationCompleted = Authenticated | FailedAuthentication } #endregion - } - - #region == enum TlsHandlerState == - [Flags] - internal enum TlsHandlerState - { - Authenticating = 1, - Authenticated = 1 << 1, - FailedAuthentication = 1 << 2, - ReadRequestedBeforeAuthenticated = 1 << 3, - FlushedBeforeHandshake = 1 << 4, - AuthenticationStarted = Authenticating | Authenticated | FailedAuthentication, - AuthenticationCompleted = Authenticated | FailedAuthentication - } + #region == class TlsHandlerStateExtensions == - #endregion - - #region == class TlsHandlerStateExtensions == - - internal static class TlsHandlerStateExtensions - { - public static bool Has(this TlsHandlerState value, TlsHandlerState testValue) => (value & testValue) == testValue; + internal static class TlsHandlerStateExtensions + { + public static bool Has(this TlsHandlerState value, TlsHandlerState testValue) => (value & testValue) == testValue; - public static bool HasAny(this TlsHandlerState value, TlsHandlerState testValue) => (value & testValue) != 0; - } + public static bool HasAny(this TlsHandlerState value, TlsHandlerState testValue) => (value & testValue) != 0; + } - #endregion + #endregion } \ No newline at end of file diff --git a/src/DotNetty.Transport/Channels/Sockets/TcpServerSocketChannel.cs b/src/DotNetty.Transport/Channels/Sockets/TcpServerSocketChannel.cs index 0106a7012..5de3e5717 100644 --- a/src/DotNetty.Transport/Channels/Sockets/TcpServerSocketChannel.cs +++ b/src/DotNetty.Transport/Channels/Sockets/TcpServerSocketChannel.cs @@ -18,7 +18,7 @@ public class TcpServerSocketChannel : AbstractS where TChannelFactory : ITcpSocketChannelFactory, new() { //static readonly IInternalLogger Logger = InternalLoggerFactory.GetInstance(); ## 苦竹 屏蔽 ## - static readonly ChannelMetadata METADATA = new ChannelMetadata(false, 16); + static readonly ChannelMetadata METADATA = new ChannelMetadata(false); static readonly Action ReadCompletedSyncCallback = OnReadCompletedSync; diff --git a/test/DotNetty.Codecs.Protobuf.Tests/DotNetty.Codecs.Protobuf.Tests.csproj b/test/DotNetty.Codecs.Protobuf.Tests/DotNetty.Codecs.Protobuf.Tests.csproj index d05f4d68a..fbda200fe 100644 --- a/test/DotNetty.Codecs.Protobuf.Tests/DotNetty.Codecs.Protobuf.Tests.csproj +++ b/test/DotNetty.Codecs.Protobuf.Tests/DotNetty.Codecs.Protobuf.Tests.csproj @@ -8,7 +8,7 @@ - + diff --git a/test/DotNetty.Codecs.ProtocolBuffers.Tests/DotNetty.Codecs.ProtocolBuffers.Tests.csproj b/test/DotNetty.Codecs.ProtocolBuffers.Tests/DotNetty.Codecs.ProtocolBuffers.Tests.csproj index 930b7fdcd..80b7207f9 100644 --- a/test/DotNetty.Codecs.ProtocolBuffers.Tests/DotNetty.Codecs.ProtocolBuffers.Tests.csproj +++ b/test/DotNetty.Codecs.ProtocolBuffers.Tests/DotNetty.Codecs.ProtocolBuffers.Tests.csproj @@ -7,7 +7,7 @@ - + diff --git a/test/DotNetty.Common.Tests/DotNetty.Common.Tests.csproj b/test/DotNetty.Common.Tests/DotNetty.Common.Tests.csproj index 4643775b6..f28d8af8f 100644 --- a/test/DotNetty.Common.Tests/DotNetty.Common.Tests.csproj +++ b/test/DotNetty.Common.Tests/DotNetty.Common.Tests.csproj @@ -16,7 +16,6 @@ - diff --git a/test/DotNetty.Common.Tests/Internal/Logging/InternalLoggerFactoryTest.cs b/test/DotNetty.Common.Tests/Internal/Logging/InternalLoggerFactoryTest.cs index 7c1d0f7f9..7f125f33e 100644 --- a/test/DotNetty.Common.Tests/Internal/Logging/InternalLoggerFactoryTest.cs +++ b/test/DotNetty.Common.Tests/Internal/Logging/InternalLoggerFactoryTest.cs @@ -3,59 +3,62 @@ namespace DotNetty.Common.Tests.Internal.Logging { - using System; - using DotNetty.Common.Internal.Logging; - using DotNetty.Tests.Common; - using Microsoft.Extensions.Logging; - using Moq; - using Xunit; - - public class InternalLoggerFactoryTest - { - // todo: CodeContracts on CI - //[Fact] - //public void ShouldNotAllowNullDefaultFactory() - //{ - // Assert.ThrowsAny(() => InternalLoggerFactory.DefaultFactory = null); - //} - - [Fact] - public void ShouldGetInstance() + using System; + using DotNetty.Common.Internal.Logging; + using DotNetty.Tests.Common; + using Microsoft.Extensions.Logging; + using Moq; + using Xunit; + +#if !TEST40 + [CollectionDefinition(nameof(InternalLoggerFactoryTest), DisableParallelization = true)] +#endif + public class InternalLoggerFactoryTest { - IInternalLogger one = InternalLoggerFactory.GetInstance("helloWorld"); - IInternalLogger two = InternalLoggerFactory.GetInstance(); + // todo: CodeContracts on CI + //[Fact] + //public void ShouldNotAllowNullDefaultFactory() + //{ + // Assert.ThrowsAny(() => InternalLoggerFactory.DefaultFactory = null); + //} - Assert.NotNull(one); - Assert.NotNull(two); - Assert.NotSame(one, two); - } + [Fact] + public void ShouldGetInstance() + { + IInternalLogger one = InternalLoggerFactory.GetInstance("helloWorld"); + IInternalLogger two = InternalLoggerFactory.GetInstance(); - [Fact] - public void TestMockReturned() - { - Mock mock; - using (SetupMockLogger(out mock)) - { - mock.Setup(x => x.IsEnabled(LogLevel.Trace)).Returns(true).Verifiable(); + Assert.NotNull(one); + Assert.NotNull(two); + Assert.NotSame(one, two); + } - IInternalLogger logger = InternalLoggerFactory.GetInstance("mock"); + [Fact] + public void TestMockReturned() + { + Mock mock; + using (SetupMockLogger(out mock)) + { + mock.Setup(x => x.IsEnabled(LogLevel.Trace)).Returns(true).Verifiable(); - Assert.True(logger.TraceEnabled); - mock.Verify(x => x.IsEnabled(LogLevel.Trace), Times.Once); - } - } + IInternalLogger logger = InternalLoggerFactory.GetInstance("mock"); - static IDisposable SetupMockLogger(out Mock loggerMock) - { - ILoggerFactory oldLoggerFactory = InternalLoggerFactory.DefaultFactory; - var loggerFactory = new LoggerFactory(); - var factoryMock = new Mock(MockBehavior.Strict); - ILoggerProvider mockFactory = factoryMock.Object; - loggerMock = new Mock(MockBehavior.Strict); - loggerFactory.AddProvider(mockFactory); - factoryMock.Setup(x => x.CreateLogger("mock")).Returns(loggerMock.Object); - InternalLoggerFactory.DefaultFactory = loggerFactory; - return new Disposable(() => InternalLoggerFactory.DefaultFactory = oldLoggerFactory); + Assert.True(logger.TraceEnabled); + mock.Verify(x => x.IsEnabled(LogLevel.Trace), Times.Once); + } + } + + static IDisposable SetupMockLogger(out Mock loggerMock) + { + ILoggerFactory oldLoggerFactory = InternalLoggerFactory.DefaultFactory; + var loggerFactory = new LoggerFactory(); + var factoryMock = new Mock(MockBehavior.Strict); + ILoggerProvider mockFactory = factoryMock.Object; + loggerMock = new Mock(MockBehavior.Strict); + loggerFactory.AddProvider(mockFactory); + factoryMock.Setup(x => x.CreateLogger("mock")).Returns(loggerMock.Object); + InternalLoggerFactory.DefaultFactory = loggerFactory; + return new Disposable(() => InternalLoggerFactory.DefaultFactory = oldLoggerFactory); + } } - } } diff --git a/test/DotNetty.Handlers.Tests/TlsHandlerTest.cs b/test/DotNetty.Handlers.Tests/TlsHandlerTest.cs index 0d016ce3f..7aff77094 100644 --- a/test/DotNetty.Handlers.Tests/TlsHandlerTest.cs +++ b/test/DotNetty.Handlers.Tests/TlsHandlerTest.cs @@ -104,7 +104,11 @@ public async Task TlsRead(int[] frameLengths, bool isClient, IWriteStrategy writ await Task.WhenAll(writeTasks).WithTimeout(TimeSpan.FromSeconds(5)); IByteBuffer finalReadBuffer = Unpooled.Buffer(16 * 1024); await ReadOutboundAsync(async () => ch.ReadInbound(), expectedBuffer.ReadableBytes, finalReadBuffer, TestTimeout); - Assert.True(ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer), $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}"); + bool isEqual = ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer); + if (!isEqual) + { + Assert.True(isEqual, $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}"); + } driverStream.Dispose(); Assert.False(ch.Finish()); } @@ -192,7 +196,11 @@ await ReadOutboundAsync( return Unpooled.WrappedBuffer(readBuffer, 0, read); }, expectedBuffer.ReadableBytes, finalReadBuffer, TestTimeout); - Assert.True(ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer), $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}"); + bool isEqual = ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer); + if (!isEqual) + { + Assert.True(isEqual, $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}"); + } driverStream.Dispose(); Assert.False(ch.Finish()); }