diff --git a/shared/contoso.com.pfx b/shared/contoso.com.pfx
new file mode 100644
index 000000000..bacddeff7
Binary files /dev/null and b/shared/contoso.com.pfx differ
diff --git a/src/DotNetty.Codecs/ByteToMessageDecoder.cs b/src/DotNetty.Codecs/ByteToMessageDecoder.cs
index f950a44ea..1f676e978 100644
--- a/src/DotNetty.Codecs/ByteToMessageDecoder.cs
+++ b/src/DotNetty.Codecs/ByteToMessageDecoder.cs
@@ -151,6 +151,9 @@ static IByteBuffer ExpandCumulation(IByteBufferAllocator allocator, IByteBuffer
public override void HandlerRemoved(IChannelHandlerContext context)
{
IByteBuffer buf = this.InternalBuffer;
+
+ // Directly set this to null so we are sure we not access it in any other method here anymore.
+ this.cumulation = null;
int readable = buf.ReadableBytes;
if (readable > 0)
{
@@ -162,7 +165,7 @@ public override void HandlerRemoved(IChannelHandlerContext context)
{
buf.Release();
}
- this.cumulation = null;
+
context.FireChannelReadComplete();
this.HandlerRemovedInternal(context);
}
diff --git a/src/DotNetty.Handlers/DotNetty.Handlers.csproj b/src/DotNetty.Handlers/DotNetty.Handlers.csproj
index 465ff7cf2..9fde1728b 100644
--- a/src/DotNetty.Handlers/DotNetty.Handlers.csproj
+++ b/src/DotNetty.Handlers/DotNetty.Handlers.csproj
@@ -37,4 +37,9 @@
+
+
+ 4.3.0
+
+
\ No newline at end of file
diff --git a/src/DotNetty.Handlers/Tls/ServerTlsSniSettings.cs b/src/DotNetty.Handlers/Tls/ServerTlsSniSettings.cs
new file mode 100644
index 000000000..5dcd941d5
--- /dev/null
+++ b/src/DotNetty.Handlers/Tls/ServerTlsSniSettings.cs
@@ -0,0 +1,23 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+namespace DotNetty.Handlers.Tls
+{
+ using System;
+ using System.Diagnostics.Contracts;
+ using System.Threading.Tasks;
+
+ public sealed class ServerTlsSniSettings
+ {
+ public ServerTlsSniSettings(Func> serverTlsSettingMap, string defaultServerHostName = null)
+ {
+ Contract.Requires(serverTlsSettingMap != null);
+ this.ServerTlsSettingMap = serverTlsSettingMap;
+ this.DefaultServerHostName = defaultServerHostName;
+ }
+
+ public Func> ServerTlsSettingMap { get; }
+
+ public string DefaultServerHostName { get; }
+ }
+}
\ No newline at end of file
diff --git a/src/DotNetty.Handlers/Tls/SniHandler.cs b/src/DotNetty.Handlers/Tls/SniHandler.cs
new file mode 100644
index 000000000..d9398da49
--- /dev/null
+++ b/src/DotNetty.Handlers/Tls/SniHandler.cs
@@ -0,0 +1,316 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+namespace DotNetty.Handlers.Tls
+{
+ using System;
+ using System.Collections.Generic;
+ using System.Diagnostics.Contracts;
+ using System.Globalization;
+ using System.IO;
+ using System.Net.Security;
+ using System.Text;
+ using DotNetty.Buffers;
+ using DotNetty.Codecs;
+ using DotNetty.Common.Internal.Logging;
+ using DotNetty.Transport.Channels;
+
+ public sealed class SniHandler : ByteToMessageDecoder
+ {
+ // Maximal number of ssl records to inspect before fallback to the default (aligned with netty)
+ const int MAX_SSL_RECORDS = 4;
+ static readonly IInternalLogger Logger = InternalLoggerFactory.GetInstance(typeof(SniHandler));
+ readonly Func sslStreamFactory;
+ readonly ServerTlsSniSettings serverTlsSniSettings;
+
+ bool handshakeFailed;
+ bool suppressRead;
+ bool readPending;
+
+ public SniHandler(ServerTlsSniSettings settings)
+ : this(stream => new SslStream(stream, true), settings)
+ {
+ }
+
+ public SniHandler(Func sslStreamFactory, ServerTlsSniSettings settings)
+ {
+ Contract.Requires(settings != null);
+ Contract.Requires(sslStreamFactory != null);
+ this.sslStreamFactory = sslStreamFactory;
+ this.serverTlsSniSettings = settings;
+ }
+
+ protected override void Decode(IChannelHandlerContext context, IByteBuffer input, List output)
+ {
+ if (!this.suppressRead && !this.handshakeFailed)
+ {
+ int writerIndex = input.WriterIndex;
+ Exception error = null;
+ try
+ {
+ bool continueLoop = true;
+ for (int i = 0; i < MAX_SSL_RECORDS && continueLoop; i++)
+ {
+ int readerIndex = input.ReaderIndex;
+ int readableBytes = writerIndex - readerIndex;
+ if (readableBytes < TlsUtils.SSL_RECORD_HEADER_LENGTH)
+ {
+ // Not enough data to determine the record type and length.
+ return;
+ }
+
+ int command = input.GetByte(readerIndex);
+ // tls, but not handshake command
+ switch (command)
+ {
+ case TlsUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
+ case TlsUtils.SSL_CONTENT_TYPE_ALERT:
+ int len = TlsUtils.GetEncryptedPacketLength(input, readerIndex);
+
+ // Not an SSL/TLS packet
+ if (len == TlsUtils.NOT_ENCRYPTED)
+ {
+ this.handshakeFailed = true;
+ var e = new NotSslRecordException(
+ "not an SSL/TLS record: " + ByteBufferUtil.HexDump(input));
+ input.SkipBytes(input.ReadableBytes);
+
+ TlsUtils.NotifyHandshakeFailure(context, e);
+ throw e;
+ }
+ if (len == TlsUtils.NOT_ENOUGH_DATA ||
+ writerIndex - readerIndex - TlsUtils.SSL_RECORD_HEADER_LENGTH < len)
+ {
+ // Not enough data
+ return;
+ }
+
+ // increase readerIndex and try again.
+ input.SkipBytes(len);
+ continue;
+
+ case TlsUtils.SSL_CONTENT_TYPE_HANDSHAKE:
+ int majorVersion = input.GetByte(readerIndex + 1);
+
+ // SSLv3 or TLS
+ if (majorVersion == 3)
+ {
+ int packetLength = input.GetUnsignedShort(readerIndex + 3) + TlsUtils.SSL_RECORD_HEADER_LENGTH;
+
+ if (readableBytes < packetLength)
+ {
+ // client hello incomplete; try again to decode once more data is ready.
+ return;
+ }
+
+ // See https://tools.ietf.org/html/rfc5246#section-7.4.1.2
+ //
+ // Decode the ssl client hello packet.
+ // We have to skip bytes until SessionID (which sum to 43 bytes).
+ //
+ // struct {
+ // ProtocolVersion client_version;
+ // Random random;
+ // SessionID session_id;
+ // CipherSuite cipher_suites<2..2^16-2>;
+ // CompressionMethod compression_methods<1..2^8-1>;
+ // select (extensions_present) {
+ // case false:
+ // struct {};
+ // case true:
+ // Extension extensions<0..2^16-1>;
+ // };
+ // } ClientHello;
+ //
+
+ int endOffset = readerIndex + packetLength;
+ int offset = readerIndex + 43;
+
+ if (endOffset - offset < 6)
+ {
+ continueLoop = false;
+ break;
+ }
+
+ int sessionIdLength = input.GetByte(offset);
+ offset += sessionIdLength + 1;
+
+ int cipherSuitesLength = input.GetUnsignedShort(offset);
+ offset += cipherSuitesLength + 2;
+
+ int compressionMethodLength = input.GetByte(offset);
+ offset += compressionMethodLength + 1;
+
+ int extensionsLength = input.GetUnsignedShort(offset);
+ offset += 2;
+ int extensionsLimit = offset + extensionsLength;
+
+ if (extensionsLimit > endOffset)
+ {
+ // Extensions should never exceed the record boundary.
+ continueLoop = false;
+ break;
+ }
+
+ for (;;)
+ {
+ if (extensionsLimit - offset < 4)
+ {
+ continueLoop = false;
+ break;
+ }
+
+ int extensionType = input.GetUnsignedShort(offset);
+ offset += 2;
+
+ int extensionLength = input.GetUnsignedShort(offset);
+ offset += 2;
+
+ if (extensionsLimit - offset < extensionLength)
+ {
+ continueLoop = false;
+ break;
+ }
+
+ // SNI
+ // See https://tools.ietf.org/html/rfc6066#page-6
+ if (extensionType == 0)
+ {
+ offset += 2;
+ if (extensionsLimit - offset < 3)
+ {
+ continueLoop = false;
+ break;
+ }
+
+ int serverNameType = input.GetByte(offset);
+ offset++;
+
+ if (serverNameType == 0)
+ {
+ int serverNameLength = input.GetUnsignedShort(offset);
+ offset += 2;
+
+ if (serverNameLength <= 0 || extensionsLimit - offset < serverNameLength)
+ {
+ continueLoop = false;
+ break;
+ }
+
+ string hostname = input.ToString(offset, serverNameLength, Encoding.UTF8);
+ //try
+ //{
+ // select(ctx, IDN.toASCII(hostname,
+ // IDN.ALLOW_UNASSIGNED).toLowerCase(Locale.US));
+ //}
+ //catch (Throwable t)
+ //{
+ // PlatformDependent.throwException(t);
+ //}
+
+ var idn = new IdnMapping()
+ {
+ AllowUnassigned = true
+ };
+
+ hostname = idn.GetAscii(hostname);
+#if NETSTANDARD1_3
+ // TODO: netcore does not have culture sensitive tolower()
+ hostname = hostname.ToLowerInvariant();
+#else
+ hostname = hostname.ToLower(new CultureInfo("en-US"));
+#endif
+ this.Select(context, hostname);
+ return;
+ }
+ else
+ {
+ // invalid enum value
+ continueLoop = false;
+ break;
+ }
+ }
+
+ offset += extensionLength;
+ }
+ }
+
+ break;
+ // Fall-through
+ default:
+ //not tls, ssl or application data, do not try sni
+ continueLoop = false;
+ break;
+ }
+ }
+ }
+ catch (Exception e)
+ {
+ error = e;
+
+ // unexpected encoding, ignore sni and use default
+ if (Logger.DebugEnabled)
+ {
+ Logger.Warn($"Unexpected client hello packet: {ByteBufferUtil.HexDump(input)}", e);
+ }
+ }
+
+ if (this.serverTlsSniSettings.DefaultServerHostName != null)
+ {
+ // Just select the default certifcate
+ this.Select(context, this.serverTlsSniSettings.DefaultServerHostName);
+ }
+ else
+ {
+ this.handshakeFailed = true;
+ var e = new DecoderException($"failed to get the Tls Certificate {error}");
+ TlsUtils.NotifyHandshakeFailure(context, e);
+ throw e;
+ }
+ }
+ }
+
+ async void Select(IChannelHandlerContext context, string hostName)
+ {
+ Contract.Requires(hostName != null);
+ this.suppressRead = true;
+ try
+ {
+ var serverTlsSetting = await this.serverTlsSniSettings.ServerTlsSettingMap(hostName);
+ this.ReplaceHandler(context, serverTlsSetting);
+ }
+ catch (Exception ex)
+ {
+ this.ExceptionCaught(context, new DecoderException($"failed to get the Tls Certificate for {hostName}, {ex}"));
+ }
+ finally
+ {
+ this.suppressRead = false;
+ if (this.readPending)
+ {
+ this.readPending = false;
+ context.Read();
+ }
+ }
+ }
+
+ void ReplaceHandler(IChannelHandlerContext context, ServerTlsSettings serverTlsSetting)
+ {
+ Contract.Requires(serverTlsSetting != null);
+ var tlsHandler = new TlsHandler(this.sslStreamFactory, serverTlsSetting);
+ context.Channel.Pipeline.Replace(this, nameof(TlsHandler), tlsHandler);
+ }
+
+ public override void Read(IChannelHandlerContext context)
+ {
+ if (this.suppressRead)
+ {
+ this.readPending = true;
+ }
+ else
+ {
+ base.Read(context);
+ }
+ }
+ }
+}
diff --git a/src/DotNetty.Handlers/Tls/TlsUtils.cs b/src/DotNetty.Handlers/Tls/TlsUtils.cs
index b647aef56..c271c823e 100644
--- a/src/DotNetty.Handlers/Tls/TlsUtils.cs
+++ b/src/DotNetty.Handlers/Tls/TlsUtils.cs
@@ -33,6 +33,12 @@ static class TlsUtils
/// the length of the ssl record header (in bytes)
public const int SSL_RECORD_HEADER_LENGTH = 5;
+ // Not enough data in buffer to parse the record length
+ public const int NOT_ENOUGH_DATA = -1;
+
+ // data is not encrypted
+ public const int NOT_ENCRYPTED = -2;
+
///
/// Return how much bytes can be read out of the encrypted data. Be aware that this method will not increase
/// the readerIndex of the given .
diff --git a/test/DotNetty.Handlers.Tests/SniHandlerTest.cs b/test/DotNetty.Handlers.Tests/SniHandlerTest.cs
new file mode 100644
index 000000000..9f43d5c4e
--- /dev/null
+++ b/test/DotNetty.Handlers.Tests/SniHandlerTest.cs
@@ -0,0 +1,280 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+namespace DotNetty.Handlers.Tests
+{
+ using System;
+ using System.Collections.Generic;
+ using System.Diagnostics;
+ using System.Linq;
+ using System.Net.Security;
+ using System.Security.Authentication;
+ using System.Security.Cryptography.X509Certificates;
+ using System.Threading.Tasks;
+ using DotNetty.Buffers;
+ using DotNetty.Common.Concurrency;
+ using DotNetty.Handlers.Tls;
+ using DotNetty.Tests.Common;
+ using DotNetty.Transport.Channels;
+ using DotNetty.Transport.Channels.Embedded;
+ using Xunit;
+ using Xunit.Abstractions;
+
+ public class SniHandlerTest : TestBase
+ {
+ static readonly TimeSpan TestTimeout = TimeSpan.FromSeconds(10);
+ static readonly Dictionary SettingMap = new Dictionary();
+
+ static SniHandlerTest()
+ {
+ X509Certificate2 tlsCertificate = TestResourceHelper.GetTestCertificate();
+ X509Certificate2 tlsCertificate2 = TestResourceHelper.GetTestCertificate2();
+
+ SettingMap[tlsCertificate.GetNameInfo(X509NameType.DnsName, false)] = new ServerTlsSettings(tlsCertificate, false, false, SslProtocols.Tls12);
+ SettingMap[tlsCertificate2.GetNameInfo(X509NameType.DnsName, false)] = new ServerTlsSettings(tlsCertificate2, false, false, SslProtocols.Tls12);
+ }
+
+ public SniHandlerTest(ITestOutputHelper output)
+ : base(output)
+ {
+ }
+
+ public static IEnumerable GetTlsReadTestData()
+ {
+ var lengthVariations =
+ new[]
+ {
+ new[] { 1 }
+ };
+ var boolToggle = new[] { false, true };
+ var protocols = new[] { SslProtocols.Tls12 };
+ var writeStrategyFactories = new Func[]
+ {
+ () => new AsIsWriteStrategy()
+ };
+
+ return
+ from frameLengths in lengthVariations
+ from isClient in boolToggle
+ from writeStrategyFactory in writeStrategyFactories
+ from protocol in protocols
+ from targetHost in SettingMap.Keys
+ select new object[] { frameLengths, isClient, writeStrategyFactory(), protocol, targetHost };
+ }
+
+
+ [Theory]
+ [MemberData(nameof(GetTlsReadTestData))]
+ public async Task TlsRead(int[] frameLengths, bool isClient, IWriteStrategy writeStrategy, SslProtocols protocol, string targetHost)
+ {
+ this.Output.WriteLine($"frameLengths: {string.Join(", ", frameLengths)}");
+ this.Output.WriteLine($"writeStrategy: {writeStrategy}");
+ this.Output.WriteLine($"protocol: {protocol}");
+ this.Output.WriteLine($"targetHost: {targetHost}");
+
+ var executor = new SingleThreadEventExecutor("test executor", TimeSpan.FromMilliseconds(10));
+
+ try
+ {
+ var writeTasks = new List();
+ var pair = await SetupStreamAndChannelAsync(isClient, executor, writeStrategy, protocol, writeTasks, targetHost).WithTimeout(TimeSpan.FromSeconds(10));
+ EmbeddedChannel ch = pair.Item1;
+ SslStream driverStream = pair.Item2;
+
+ int randomSeed = Environment.TickCount;
+ var random = new Random(randomSeed);
+ IByteBuffer expectedBuffer = Unpooled.Buffer(16 * 1024);
+ foreach (int len in frameLengths)
+ {
+ var data = new byte[len];
+ random.NextBytes(data);
+ expectedBuffer.WriteBytes(data);
+ await driverStream.WriteAsync(data, 0, data.Length).WithTimeout(TimeSpan.FromSeconds(5));
+ }
+ await Task.WhenAll(writeTasks).WithTimeout(TimeSpan.FromSeconds(5));
+ IByteBuffer finalReadBuffer = Unpooled.Buffer(16 * 1024);
+ await ReadOutboundAsync(async () => ch.ReadInbound(), expectedBuffer.ReadableBytes, finalReadBuffer, TestTimeout);
+ Assert.True(ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer), $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}");
+
+ if (!isClient)
+ {
+ // check if snihandler got replaced with tls handler
+ Assert.Null(ch.Pipeline.Get());
+ Assert.NotNull(ch.Pipeline.Get());
+ }
+
+ driverStream.Dispose();
+ }
+ finally
+ {
+ await executor.ShutdownGracefullyAsync(TimeSpan.FromMilliseconds(100), TimeSpan.FromMilliseconds(300));
+ }
+ }
+
+ public static IEnumerable GetTlsWriteTestData()
+ {
+ var lengthVariations =
+ new[]
+ {
+ new[] { 1 }
+ };
+ var boolToggle = new[] { false, true };
+ var protocols = new[] { SslProtocols.Tls12 };
+
+ return
+ from frameLengths in lengthVariations
+ from isClient in boolToggle
+ from protocol in protocols
+ from targetHost in SettingMap.Keys
+ select new object[] { frameLengths, isClient, protocol, targetHost };
+ }
+
+ [Theory]
+ [MemberData(nameof(GetTlsWriteTestData))]
+ public async Task TlsWrite(int[] frameLengths, bool isClient, SslProtocols protocol, string targetHost)
+ {
+ this.Output.WriteLine("frameLengths: " + string.Join(", ", frameLengths));
+ this.Output.WriteLine($"protocol: {protocol}");
+ this.Output.WriteLine($"targetHost: {targetHost}");
+
+ var writeStrategy = new AsIsWriteStrategy();
+ var executor = new SingleThreadEventExecutor("test executor", TimeSpan.FromMilliseconds(10));
+
+ try
+ {
+ var writeTasks = new List();
+ var pair = await SetupStreamAndChannelAsync(isClient, executor, writeStrategy, protocol, writeTasks, targetHost);
+ EmbeddedChannel ch = pair.Item1;
+ SslStream driverStream = pair.Item2;
+
+ int randomSeed = Environment.TickCount;
+ var random = new Random(randomSeed);
+ IByteBuffer expectedBuffer = Unpooled.Buffer(16 * 1024);
+ foreach (IEnumerable lengths in frameLengths.Split(x => x < 0))
+ {
+ ch.WriteOutbound(lengths.Select(len =>
+ {
+ var data = new byte[len];
+ random.NextBytes(data);
+ expectedBuffer.WriteBytes(data);
+ return (object)Unpooled.WrappedBuffer(data);
+ }).ToArray());
+ }
+
+ IByteBuffer finalReadBuffer = Unpooled.Buffer(16 * 1024);
+ var readBuffer = new byte[16 * 1024 * 10];
+ await ReadOutboundAsync(
+ async () =>
+ {
+ int read = await driverStream.ReadAsync(readBuffer, 0, readBuffer.Length);
+ return Unpooled.WrappedBuffer(readBuffer, 0, read);
+ },
+ expectedBuffer.ReadableBytes, finalReadBuffer, TestTimeout);
+ Assert.True(ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer), $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}");
+
+ if (!isClient)
+ {
+ // check if snihandler got replaced with tls handler
+ Assert.Null(ch.Pipeline.Get());
+ Assert.NotNull(ch.Pipeline.Get());
+ }
+
+ driverStream.Dispose();
+ }
+ finally
+ {
+ await executor.ShutdownGracefullyAsync(TimeSpan.FromMilliseconds(100), TimeSpan.FromMilliseconds(300));
+ }
+ }
+
+ static async Task> SetupStreamAndChannelAsync(bool isClient, IEventExecutor executor, IWriteStrategy writeStrategy, SslProtocols protocol, List writeTasks, string targetHost)
+ {
+ IChannelHandler tlsHandler = isClient ?
+ (IChannelHandler)new TlsHandler(stream => new SslStream(stream, true, (sender, certificate, chain, errors) =>
+ {
+ Assert.Equal(targetHost, certificate.Issuer.Replace("CN=", string.Empty));
+ return true;
+ }), new ClientTlsSettings(SslProtocols.Tls12, false, new List(), targetHost)) :
+ new SniHandler(stream => new SslStream(stream, true, (sender, certificate, chain, errors) => true), new ServerTlsSniSettings(CertificateSelector));
+ //var ch = new EmbeddedChannel(new LoggingHandler("BEFORE"), tlsHandler, new LoggingHandler("AFTER"));
+ var ch = new EmbeddedChannel(tlsHandler);
+
+ if (!isClient)
+ {
+ // check if in the beginning snihandler exists in the pipeline, but not tls handler
+ Assert.NotNull(ch.Pipeline.Get());
+ Assert.Null(ch.Pipeline.Get());
+ }
+
+ IByteBuffer readResultBuffer = Unpooled.Buffer(4 * 1024);
+ Func, Task> readDataFunc = async output =>
+ {
+ if (writeTasks.Count > 0)
+ {
+ await Task.WhenAll(writeTasks).WithTimeout(TestTimeout);
+ writeTasks.Clear();
+ }
+
+ if (readResultBuffer.ReadableBytes < output.Count)
+ {
+ await ReadOutboundAsync(async () => ch.ReadOutbound(), output.Count - readResultBuffer.ReadableBytes, readResultBuffer, TestTimeout);
+ }
+ Assert.NotEqual(0, readResultBuffer.ReadableBytes);
+ int read = Math.Min(output.Count, readResultBuffer.ReadableBytes);
+ readResultBuffer.ReadBytes(output.Array, output.Offset, read);
+ return read;
+ };
+ var mediationStream = new MediationStream(readDataFunc, input =>
+ {
+ Task task = executor.SubmitAsync(() => writeStrategy.WriteToChannelAsync(ch, input)).Unwrap();
+ writeTasks.Add(task);
+ return task;
+ });
+
+ var driverStream = new SslStream(mediationStream, true, (_1, _2, _3, _4) => true);
+ if (isClient)
+ {
+ await Task.Run(() => driverStream.AuthenticateAsServerAsync(CertificateSelector(targetHost).Result.Certificate).WithTimeout(TimeSpan.FromSeconds(5)));
+ }
+ else
+ {
+ await Task.Run(() => driverStream.AuthenticateAsClientAsync(targetHost, null, protocol, false)).WithTimeout(TimeSpan.FromSeconds(5));
+ }
+ writeTasks.Clear();
+
+ return Tuple.Create(ch, driverStream);
+ }
+
+ static Task CertificateSelector(string hostName)
+ {
+ Assert.NotNull(hostName);
+ Assert.Contains(hostName, SettingMap.Keys);
+ return Task.FromResult(SettingMap[hostName]);
+ }
+
+ static Task ReadOutboundAsync(Func> readFunc, int expectedBytes, IByteBuffer result, TimeSpan timeout)
+ {
+ Stopwatch stopwatch = Stopwatch.StartNew();
+ int remaining = expectedBytes;
+ return AssertEx.EventuallyAsync(
+ async () =>
+ {
+ TimeSpan readTimeout = timeout - stopwatch.Elapsed;
+ if (readTimeout <= TimeSpan.Zero)
+ {
+ return false;
+ }
+
+ IByteBuffer output = await readFunc().WithTimeout(readTimeout);//inbound ? ch.ReadInbound() : ch.ReadOutbound();
+ if (output != null)
+ {
+ remaining -= output.ReadableBytes;
+ result.WriteBytes(output);
+ }
+ return remaining <= 0;
+ },
+ TimeSpan.FromMilliseconds(10),
+ timeout);
+ }
+ }
+}
\ No newline at end of file
diff --git a/test/DotNetty.Handlers.Tests/TlsHandlerTest.cs b/test/DotNetty.Handlers.Tests/TlsHandlerTest.cs
index 9e25e2c49..6aa93f183 100644
--- a/test/DotNetty.Handlers.Tests/TlsHandlerTest.cs
+++ b/test/DotNetty.Handlers.Tests/TlsHandlerTest.cs
@@ -93,6 +93,7 @@ public async Task TlsRead(int[] frameLengths, bool isClient, IWriteStrategy writ
IByteBuffer finalReadBuffer = Unpooled.Buffer(16 * 1024);
await ReadOutboundAsync(async () => ch.ReadInbound(), expectedBuffer.ReadableBytes, finalReadBuffer, TestTimeout);
Assert.True(ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer), $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}");
+ driverStream.Dispose();
}
finally
{
@@ -166,6 +167,7 @@ await ReadOutboundAsync(
},
expectedBuffer.ReadableBytes, finalReadBuffer, TestTimeout);
Assert.True(ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer), $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}");
+ driverStream.Dispose();
}
finally
{
diff --git a/test/DotNetty.Tests.Common/DotNetty.Tests.Common.csproj b/test/DotNetty.Tests.Common/DotNetty.Tests.Common.csproj
index 68bbc48b7..1f6ce1395 100644
--- a/test/DotNetty.Tests.Common/DotNetty.Tests.Common.csproj
+++ b/test/DotNetty.Tests.Common/DotNetty.Tests.Common.csproj
@@ -8,6 +8,7 @@
+
win-x64
diff --git a/test/DotNetty.Tests.Common/TestResourceHelper.cs b/test/DotNetty.Tests.Common/TestResourceHelper.cs
index 334b296db..fae5670ec 100644
--- a/test/DotNetty.Tests.Common/TestResourceHelper.cs
+++ b/test/DotNetty.Tests.Common/TestResourceHelper.cs
@@ -21,5 +21,18 @@ public static X509Certificate2 GetTestCertificate()
return new X509Certificate2(certData, "password");
}
+
+ public static X509Certificate2 GetTestCertificate2()
+ {
+ byte[] certData;
+ using (Stream resStream = typeof(TestResourceHelper).GetTypeInfo().Assembly.GetManifestResourceStream(typeof(TestResourceHelper).Namespace + "." + "contoso.com.pfx"))
+ using (var memStream = new MemoryStream())
+ {
+ resStream.CopyTo(memStream);
+ certData = memStream.ToArray();
+ }
+
+ return new X509Certificate2(certData, "password");
+ }
}
}
\ No newline at end of file