diff --git a/Directory.Build.props b/Directory.Build.props index dc40ddfdd..b594669f4 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -1,5 +1,5 @@ - - 8.0 + + 12.0 \ No newline at end of file diff --git a/examples/WebRTCScenarios/DataChannelBandwidth/BandwidthTestProgram.cs b/examples/WebRTCScenarios/DataChannelBandwidth/BandwidthTestProgram.cs index be21d322c..e6050fd8a 100644 --- a/examples/WebRTCScenarios/DataChannelBandwidth/BandwidthTestProgram.cs +++ b/examples/WebRTCScenarios/DataChannelBandwidth/BandwidthTestProgram.cs @@ -135,4 +135,4 @@ void SendRecv(RTCDataChannel channel, ref long received, { Interlocked.Add(ref received, stream.Read(buffer)); } -} \ No newline at end of file +} diff --git a/examples/WebRTCScenarios/DataChannelBandwidth/DataChannelBandwidth.csproj b/examples/WebRTCScenarios/DataChannelBandwidth/DataChannelBandwidth.csproj index d82759f1d..3fc2aee1b 100644 --- a/examples/WebRTCScenarios/DataChannelBandwidth/DataChannelBandwidth.csproj +++ b/examples/WebRTCScenarios/DataChannelBandwidth/DataChannelBandwidth.csproj @@ -2,8 +2,7 @@ Exe - net6.0 - 12 + net8.0 enable enable diff --git a/examples/WebRTCScenarios/DataChannelBandwidth/DataChannelStream.cs b/examples/WebRTCScenarios/DataChannelBandwidth/DataChannelStream.cs index 84fe9e539..1a2e2431b 100644 --- a/examples/WebRTCScenarios/DataChannelBandwidth/DataChannelStream.cs +++ b/examples/WebRTCScenarios/DataChannelBandwidth/DataChannelStream.cs @@ -10,7 +10,7 @@ class DataChannelStream : Stream { readonly RTCDataChannel channel; int currentMessageOffset; - byte[] message = []; + ArraySegment message = Array.Empty(); readonly CancellationTokenSource closed = new(); readonly SemaphoreSlim messageNeeded = new(0, 1); readonly SemaphoreSlim messageAvailable = new(0, maxCount: 1); @@ -75,7 +75,7 @@ or RTCPeerConnectionState.disconnected } int messages; - void OnMessage(RTCDataChannel _, DataChannelPayloadProtocols protocol, byte[] data) + void OnMessage(RTCDataChannel _, DataChannelPayloadProtocols protocol, ReadOnlySpan data) { int seq = Interlocked.Increment(ref messages); log?.LogDebug("{Seq} received", seq); @@ -84,7 +84,11 @@ void OnMessage(RTCDataChannel _, DataChannelPayloadProtocols protocol, byte[] da messageNeeded.Wait(); lock (sync) { - message = data; + if (message.Array is not null) + { + ArrayPool.Shared.Return(message.Array); + } + message = data.ToArraySegment(ArrayPool.Shared); currentMessageOffset = 0; } messageAvailable.Release(); @@ -139,7 +143,7 @@ public override int Read(Span buffer) lock (sync) { - int remaining = message.Length - currentMessageOffset; + int remaining = message.Count - currentMessageOffset; int toCopy = Math.Min(remaining, buffer.Length); message.AsSpan(currentMessageOffset, toCopy).CopyTo(buffer); currentMessageOffset += toCopy; @@ -193,7 +197,7 @@ public override async ValueTask ReadAsync(Memory buffer, lock (sync) { - int remaining = message.Length - currentMessageOffset; + int remaining = message.Count - currentMessageOffset; int toCopy = Math.Min(remaining, buffer.Length); message.AsSpan(currentMessageOffset, toCopy).CopyTo(buffer.Span); currentMessageOffset += toCopy; @@ -208,7 +212,7 @@ public override async ValueTask ReadAsync(Memory buffer, } long totalRead; - bool MessageNeeded() => message.Length - currentMessageOffset == 0; + bool MessageNeeded() => message.Count - currentMessageOffset == 0; public override void Write(byte[] buffer, int offset, int count) { @@ -235,13 +239,13 @@ public override async ValueTask WriteAsync(ReadOnlyMemory buffer, { cancellationToken.ThrowIfCancellationRequested(); var packet = buffer[..MaxSendBytes]; - await Task.Run(() => Send(packet.Span.ToArray()), cancellationToken) + await Task.Run(() => Send(packet.Span), cancellationToken) .ConfigureAwait(false); buffer = buffer[MaxSendBytes..]; } cancellationToken.ThrowIfCancellationRequested(); if (buffer.Length > 0) - Send(buffer.Span.ToArray()); + Send(buffer.Span); } finally { @@ -250,7 +254,7 @@ await Task.Run(() => Send(packet.Span.ToArray()), cancellationToken) } long totalSent; - void Send(byte[] buffer) + void Send(ReadOnlySpan buffer) { channel.send(buffer); Interlocked.Add(ref totalSent, buffer.Length); @@ -264,11 +268,11 @@ public override void Write(ReadOnlySpan buffer) while (buffer.Length > MaxSendBytes) { var packet = buffer[..MaxSendBytes]; - Send(packet.ToArray()); + Send(packet); buffer = buffer[MaxSendBytes..]; } if (buffer.Length > 0) - Send(buffer.ToArray()); + Send(buffer); } finally { diff --git a/examples/WebRTCScenarios/DataChannelBandwidth/SpanExtensions.cs b/examples/WebRTCScenarios/DataChannelBandwidth/SpanExtensions.cs new file mode 100644 index 000000000..ebd8c8e9f --- /dev/null +++ b/examples/WebRTCScenarios/DataChannelBandwidth/SpanExtensions.cs @@ -0,0 +1,13 @@ +using System.Buffers; + +namespace DataChannelBandwidth; + +static class SpanExtensions +{ + public static ArraySegment ToArraySegment(this ReadOnlySpan span, ArrayPool pool) + { + var result = pool.Rent(span.Length); + span.CopyTo(result); + return new ArraySegment(result, 0, span.Length); + } +} diff --git a/src/SIPSorcery.csproj b/src/SIPSorcery.csproj old mode 100755 new mode 100644 index 5ce11441d..c884407c8 --- a/src/SIPSorcery.csproj +++ b/src/SIPSorcery.csproj @@ -17,16 +17,17 @@ - + + - netstandard2.0;netstandard2.1;netcoreapp3.1;net461;net5.0;net6.0;net8.0 - true + netstandard2.0;net461;net6.0;net8.0 + true $(NoWarn);SYSLIB0050 true @@ -41,7 +42,6 @@ SIP Sorcery PTY LTD SIPSorcery SIPSorcery - true https://sipsorcery-org.github.io/sipsorcery/ http://www.sipsorcery.com/mainsite/favicon.ico icon.png @@ -50,6 +50,7 @@ master SIP WebRTC VoIP RTP SDP STUN ICE SIPSorcery -v8.0.0: RTP header extension improvements (thanks to @ChristopheI). Major version to 8 to reflect highest .net runtime supported. +-v7.0.0: upgraded BouncyCastle to v2. Set minimal supported DTLS to 1.2, and enabled 1.3. -v6.2.4: WebRTC fix for DTLS change in Chrome v124. -v6.2.3: Bug fixes. -v6.2.1: Bug fixes. @@ -67,9 +68,9 @@ -v6.0.2: Set .net6 targetted version as stable. -v6.0.1-pre: Added .net6 target. en - 8.0.0 - 8.0.0 - 8.0.0 + 9.0.0-PacketView-24.7.10.0 + 9.0.0 + 9.0.0 @@ -80,8 +81,8 @@ true + Embedded true - snupkg true diff --git a/src/net/DtlsSrtp/DtlsSrtpClient.cs b/src/net/DtlsSrtp/DtlsSrtpClient.cs index b892e8f72..8e7d00d74 100644 --- a/src/net/DtlsSrtp/DtlsSrtpClient.cs +++ b/src/net/DtlsSrtp/DtlsSrtpClient.cs @@ -16,11 +16,13 @@ using System; using System.Collections; using Microsoft.Extensions.Logging; -using Org.BouncyCastle.Crypto; -using Org.BouncyCastle.Crypto.Tls; +using Org.BouncyCastle.Tls; using Org.BouncyCastle.Security; using Org.BouncyCastle.Utilities; using SIPSorcery.Sys; +using System.Collections.Generic; +using Org.BouncyCastle.Crypto; +using Org.BouncyCastle.Tls.Crypto; namespace SIPSorcery.Net { @@ -36,7 +38,7 @@ internal DtlsSrtpTlsAuthentication(DtlsSrtpClient client) this.mContext = client.TlsContext; } - public virtual void NotifyServerCertificate(Certificate serverCertificate) + public virtual void NotifyServerCertificate(TlsServerCertificate serverCertificate) { //Console.WriteLine("DTLS client received server certificate chain of length " + chain.Length); mClient.ServerCertificate = serverCertificate; @@ -44,8 +46,8 @@ public virtual void NotifyServerCertificate(Certificate serverCertificate) public virtual TlsCredentials GetClientCredentials(CertificateRequest certificateRequest) { - byte[] certificateTypes = certificateRequest.CertificateTypes; - if (certificateTypes == null || !Arrays.Contains(certificateTypes, ClientCertificateType.rsa_sign) || !Arrays.Contains(certificateTypes, ClientCertificateType.ecdsa_sign)) + short[] certificateTypes = certificateRequest.CertificateTypes; + if (certificateTypes == null || !Arrays.Contains(certificateTypes, ClientCertificateType.rsa_sign)) { return null; } @@ -56,11 +58,6 @@ public virtual TlsCredentials GetClientCredentials(CertificateRequest certificat mClient.mCertificateChain, mClient.mPrivateKey); } - - public TlsCredentials GetClientCredentials(TlsContext context, CertificateRequest certificateRequest) - { - return GetClientCredentials(certificateRequest); - } }; public class DtlsSrtpClient : DefaultTlsClient, IDtlsSrtpPeer @@ -72,7 +69,7 @@ public class DtlsSrtpClient : DefaultTlsClient, IDtlsSrtpPeer internal TlsClientContext TlsContext { - get { return mContext; } + get { return m_context; } } protected internal TlsSession mSession; @@ -80,7 +77,7 @@ internal TlsClientContext TlsContext public bool ForceUseExtendedMasterSecret { get; set; } = true; //Received from server - public Certificate ServerCertificate { get; internal set; } + public TlsServerCertificate ServerCertificate { get; internal set; } public RTCDtlsFingerprint Fingerprint { get; private set; } @@ -105,36 +102,37 @@ internal TlsClientContext TlsContext /// public event Action OnAlert; - public DtlsSrtpClient() : - this(null, null, null) + public DtlsSrtpClient(TlsCrypto crypto) : + this(crypto, null, null, null) { } - public DtlsSrtpClient(System.Security.Cryptography.X509Certificates.X509Certificate2 certificate) : - this(DtlsUtils.LoadCertificateChain(certificate), DtlsUtils.LoadPrivateKeyResource(certificate)) + public DtlsSrtpClient(TlsCrypto crypto, System.Security.Cryptography.X509Certificates.X509Certificate2 certificate) : + this(crypto, DtlsUtils.LoadCertificateChain(crypto, certificate), DtlsUtils.LoadPrivateKeyResource(certificate)) { } - public DtlsSrtpClient(string certificatePath, string keyPath) : - this(new string[] { certificatePath }, keyPath) + public DtlsSrtpClient(TlsCrypto crypto, string certificatePath, string keyPath) : + this(crypto, new string[] { certificatePath }, keyPath) { } - public DtlsSrtpClient(string[] certificatesPath, string keyPath) : - this(DtlsUtils.LoadCertificateChain(certificatesPath), DtlsUtils.LoadPrivateKeyResource(keyPath)) + public DtlsSrtpClient(TlsCrypto crypto, string[] certificatesPath, string keyPath) : + this(crypto, DtlsUtils.LoadCertificateChain(crypto, certificatesPath), DtlsUtils.LoadPrivateKeyResource(keyPath)) { } - public DtlsSrtpClient(Certificate certificateChain, AsymmetricKeyParameter privateKey) : - this(certificateChain, privateKey, null) + public DtlsSrtpClient(TlsCrypto crypto, Certificate certificateChain, Org.BouncyCastle.Crypto.AsymmetricKeyParameter privateKey) : + this(crypto, certificateChain, privateKey, null) { } - public DtlsSrtpClient(Certificate certificateChain, AsymmetricKeyParameter privateKey, UseSrtpData clientSrtpData) + public DtlsSrtpClient(TlsCrypto crypto, Certificate certificateChain, Org.BouncyCastle.Crypto.AsymmetricKeyParameter privateKey, UseSrtpData clientSrtpData) : base(crypto) { + if (certificateChain == null && privateKey == null) { - (certificateChain, privateKey) = DtlsUtils.CreateSelfSignedTlsCert(); + (certificateChain, privateKey) = DtlsUtils.CreateSelfSignedTlsCert(crypto); } if (clientSrtpData == null) @@ -158,31 +156,33 @@ public DtlsSrtpClient(Certificate certificateChain, AsymmetricKeyParameter priva Fingerprint = certificate != null ? DtlsUtils.Fingerprint(certificate) : null; } - public DtlsSrtpClient(UseSrtpData clientSrtpData) : this(null, null, clientSrtpData) + public DtlsSrtpClient(TlsCrypto crypto, UseSrtpData clientSrtpData) : this(crypto, null, null, clientSrtpData) { } - public override IDictionary GetClientExtensions() + + public override IDictionary GetClientExtensions() { var clientExtensions = base.GetClientExtensions(); - if (TlsSRTPUtils.GetUseSrtpExtension(clientExtensions) == null) + if (TlsSrpUtilities.GetSrpExtension(clientExtensions) == null) { if (clientExtensions == null) { - clientExtensions = new Hashtable(); + clientExtensions = new Hashtable() as IDictionary; } - TlsSRTPUtils.AddUseSrtpExtension(clientExtensions, clientSrtpData); + TlsSrtpUtilities.AddUseSrtpExtension(clientExtensions, clientSrtpData); } return clientExtensions; } - public override void ProcessServerExtensions(IDictionary clientExtensions) + + public override void ProcessServerExtensions(IDictionary serverExtensions) { - base.ProcessServerExtensions(clientExtensions); + base.ProcessServerExtensions(serverExtensions); // set to some reasonable default value int chosenProfile = SrtpProtectionProfile.SRTP_AES128_CM_HMAC_SHA1_80; - UseSrtpData clientSrtpData = TlsSRTPUtils.GetUseSrtpExtension(clientExtensions); + clientSrtpData = TlsSrtpUtilities.GetUseSrtpExtension(serverExtensions); foreach (int profile in clientSrtpData.ProtectionProfiles) { @@ -244,12 +244,12 @@ public override void NotifyHandshakeComplete() { base.NotifyHandshakeComplete(); - //Copy master Secret (will be inaccessible after this call) - masterSecret = new byte[mContext.SecurityParameters.MasterSecret != null ? mContext.SecurityParameters.MasterSecret.Length : 0]; - Buffer.BlockCopy(mContext.SecurityParameters.MasterSecret, 0, masterSecret, 0, masterSecret.Length); - //Prepare Srtp Keys (we must to it here because master key will be cleared after that) PrepareSrtpSharedSecret(); + + //Copy master Secret (will be inaccessible after this call) + masterSecret = new byte[m_context.SecurityParameters.MasterSecret != null ? m_context.SecurityParameters.MasterSecret.Length : 0]; + Buffer.BlockCopy(m_context.SecurityParameters.MasterSecret.Extract(), 0, masterSecret, 0, masterSecret.Length); } public bool IsClient() @@ -269,7 +269,7 @@ protected virtual byte[] GetKeyingMaterial(string asciiLabel, byte[] context_val throw new ArgumentException("must have length less than 2^16 (or be null)", "context_value"); } - SecurityParameters sp = mContext.SecurityParameters; + SecurityParameters sp = m_context.SecurityParameters; if (!sp.IsExtendedMasterSecret && RequiresExtendedMasterSecret()) { /* @@ -309,7 +309,7 @@ protected virtual byte[] GetKeyingMaterial(string asciiLabel, byte[] context_val throw new InvalidOperationException("error in calculation of seed for export"); } - return TlsUtilities.PRF(mContext, sp.MasterSecret, asciiLabel, seed, length); + return TlsUtilities.Prf(sp, sp.MasterSecret, asciiLabel, seed, length).Extract(); } public override bool RequiresExtendedMasterSecret() @@ -371,22 +371,12 @@ protected virtual void PrepareSrtpSharedSecret() Buffer.BlockCopy(sharedSecret, (2 * keyLen + saltLen), srtpMasterServerSalt, 0, saltLen); } - public override ProtocolVersion ClientVersion - { - get { return ProtocolVersion.DTLSv12; } - } - - public override ProtocolVersion MinimumVersion - { - get { return ProtocolVersion.DTLSv10; } - } - public override TlsSession GetSessionToResume() { return this.mSession; } - public override void NotifyAlertRaised(byte alertLevel, byte alertDescription, string message, Exception cause) + public override void NotifyAlertRaised(short alertLevel, short alertDescription, string message, Exception cause) { string description = null; if (message != null) @@ -401,7 +391,7 @@ public override void NotifyAlertRaised(byte alertLevel, byte alertDescription, s string alertMessage = $"{AlertLevel.GetText(alertLevel)}, {AlertDescription.GetText(alertDescription)}"; alertMessage += !string.IsNullOrEmpty(description) ? $", {description}." : "."; - if (alertDescription == AlertTypesEnum.close_notify.GetHashCode()) + if (alertDescription == (byte)AlertTypesEnum.close_notify) { logger.LogDebug($"DTLS client raised close notification: {alertMessage}"); } @@ -418,22 +408,31 @@ public override void NotifyServerVersion(ProtocolVersion serverVersion) public Certificate GetRemoteCertificate() { - return ServerCertificate; + return ServerCertificate.Certificate; + } + + protected override ProtocolVersion[] GetSupportedVersions() + { + return new ProtocolVersion[] + { + ProtocolVersion.DTLSv10, + ProtocolVersion.DTLSv12, + }; } - public override void NotifyAlertReceived(byte alertLevel, byte alertDescription) + public override void NotifyAlertReceived(short alertLevel, short alertDescription) { string description = AlertDescription.GetText(alertDescription); AlertLevelsEnum level = AlertLevelsEnum.Warning; AlertTypesEnum alertType = AlertTypesEnum.unknown; - if (Enum.IsDefined(typeof(AlertLevelsEnum), alertLevel)) + if (Enum.IsDefined(typeof(AlertLevelsEnum), checked((byte)alertLevel))) { level = (AlertLevelsEnum)alertLevel; } - if (Enum.IsDefined(typeof(AlertTypesEnum), alertDescription)) + if (Enum.IsDefined(typeof(AlertTypesEnum), checked((byte)alertDescription))) { alertType = (AlertTypesEnum)alertDescription; } diff --git a/src/net/DtlsSrtp/DtlsSrtpServer.cs b/src/net/DtlsSrtp/DtlsSrtpServer.cs index cf1988b20..df458fd3c 100644 --- a/src/net/DtlsSrtp/DtlsSrtpServer.cs +++ b/src/net/DtlsSrtp/DtlsSrtpServer.cs @@ -22,7 +22,9 @@ using System.Collections.Generic; using Microsoft.Extensions.Logging; using Org.BouncyCastle.Crypto; -using Org.BouncyCastle.Crypto.Tls; +using Org.BouncyCastle.Tls; +using Org.BouncyCastle.Tls.Crypto; +using Org.BouncyCastle.Tls.Crypto.Impl.BC; using Org.BouncyCastle.Utilities; using SIPSorcery.Sys; @@ -114,8 +116,6 @@ public class DtlsSrtpServer : DefaultTlsServer, IDtlsSrtpPeer private SrtpPolicy srtpPolicy; private SrtpPolicy srtcpPolicy; - private int[] cipherSuites; - /// /// Parameters: /// - alert level, @@ -124,32 +124,30 @@ public class DtlsSrtpServer : DefaultTlsServer, IDtlsSrtpPeer /// public event Action OnAlert; - public DtlsSrtpServer() : this((Certificate)null, null) + public DtlsSrtpServer(TlsCrypto crypto) : this(crypto, (Certificate)null, null) { } - public DtlsSrtpServer(System.Security.Cryptography.X509Certificates.X509Certificate2 certificate) : this(DtlsUtils.LoadCertificateChain(certificate), DtlsUtils.LoadPrivateKeyResource(certificate)) + public DtlsSrtpServer(TlsCrypto crypto, System.Security.Cryptography.X509Certificates.X509Certificate2 certificate) : this(crypto,DtlsUtils.LoadCertificateChain(crypto,certificate), DtlsUtils.LoadPrivateKeyResource(certificate)) { } - public DtlsSrtpServer(string certificatePath, string keyPath) : this(new string[] { certificatePath }, keyPath) + public DtlsSrtpServer(TlsCrypto crypto, string certificatePath, string keyPath) : this(crypto, new string[] { certificatePath }, keyPath) { } - public DtlsSrtpServer(string[] certificatesPath, string keyPath) : - this(DtlsUtils.LoadCertificateChain(certificatesPath), DtlsUtils.LoadPrivateKeyResource(keyPath)) + public DtlsSrtpServer(TlsCrypto crypto, string[] certificatesPath, string keyPath) : + this(crypto, DtlsUtils.LoadCertificateChain(crypto, certificatesPath), DtlsUtils.LoadPrivateKeyResource(keyPath)) { } - public DtlsSrtpServer(Certificate certificateChain, AsymmetricKeyParameter privateKey) + public DtlsSrtpServer(TlsCrypto crypto, Certificate certificateChain, AsymmetricKeyParameter privateKey) : base(crypto) { if (certificateChain == null && privateKey == null) { - (certificateChain, privateKey) = DtlsUtils.CreateSelfSignedTlsCert(); + (certificateChain, privateKey) = DtlsUtils.CreateSelfSignedTlsCert(crypto); } - this.cipherSuites = base.GetCipherSuites(); - this.mPrivateKey = privateKey; mCertificateChain = certificateChain; @@ -183,19 +181,19 @@ public Certificate CertificateChain } } - protected override ProtocolVersion MaximumVersion + protected ProtocolVersion MaximumVersion { get { - return ProtocolVersion.DTLSv12; + return ProtocolVersion.DTLSv13; } } - protected override ProtocolVersion MinimumVersion + protected ProtocolVersion MinimumVersion { get { - return ProtocolVersion.DTLSv10; + return ProtocolVersion.DTLSv12; } } @@ -213,20 +211,32 @@ public override int GetSelectedCipherSuite() * must be negotiated only if the server can successfully complete the handshake while using the curves and point * formats supported by the client [...]. */ - bool eccCipherSuitesEnabled = SupportsClientEccCapabilities(this.mNamedCurves, this.mClientECPointFormats); int[] cipherSuites = GetCipherSuites(); for (int i = 0; i < cipherSuites.Length; ++i) { int cipherSuite = cipherSuites[i]; - if (Arrays.Contains(this.mOfferedCipherSuites, cipherSuite) - && (eccCipherSuitesEnabled || !TlsEccUtilities.IsEccCipherSuite(cipherSuite)) - && TlsUtilities.IsValidCipherSuiteForVersion(cipherSuite, mServerVersion)) + if (Arrays.Contains(this.m_offeredCipherSuites, cipherSuite) + && !TlsEccUtilities.IsEccCipherSuite(cipherSuite) + && TlsUtilities.IsValidVersionForCipherSuite(cipherSuite, GetServerVersion())) { - return this.mSelectedCipherSuite = cipherSuite; + return this.m_selectedCipherSuite = cipherSuite; } } + + for (int i = 0; i < cipherSuites.Length; ++i) + { + int cipherSuite = cipherSuites[i]; + + if (Arrays.Contains(this.m_offeredCipherSuites, cipherSuite) + && TlsEccUtilities.IsEccCipherSuite(cipherSuite) + && TlsUtilities.IsValidVersionForCipherSuite(cipherSuite, GetServerVersion())) + { + return this.m_selectedCipherSuite = cipherSuite; + } + } + throw new TlsFatalAlert(AlertDescription.handshake_failure); } @@ -234,10 +244,10 @@ public override CertificateRequest GetCertificateRequest() { List serverSigAlgs = new List(); - if (TlsUtilities.IsSignatureAlgorithmsExtensionAllowed(mServerVersion)) + if (TlsUtilities.IsSignatureAlgorithmsExtensionAllowed(GetServerVersion())) { - byte[] hashAlgorithms = new byte[] { HashAlgorithm.sha512, HashAlgorithm.sha384, HashAlgorithm.sha256, HashAlgorithm.sha224, HashAlgorithm.sha1 }; - byte[] signatureAlgorithms = new byte[] { SignatureAlgorithm.rsa, SignatureAlgorithm.ecdsa }; + short[] hashAlgorithms = new short[] { HashAlgorithm.sha512, HashAlgorithm.sha384, HashAlgorithm.sha256, HashAlgorithm.sha224, HashAlgorithm.sha1 }; + short[] signatureAlgorithms = new short[] { SignatureAlgorithm.rsa, SignatureAlgorithm.ecdsa }; serverSigAlgs = new List(); for (int i = 0; i < hashAlgorithms.Length; ++i) @@ -248,7 +258,7 @@ public override CertificateRequest GetCertificateRequest() } } } - return new CertificateRequest(new byte[] { ClientCertificateType.rsa_sign, ClientCertificateType.ecdsa_sign }, serverSigAlgs, null); + return new CertificateRequest(new short[] { ClientCertificateType.rsa_sign, ClientCertificateType.ecdsa_sign }, serverSigAlgs, null); } public override void NotifyClientCertificate(Certificate clientCertificate) @@ -256,27 +266,27 @@ public override void NotifyClientCertificate(Certificate clientCertificate) ClientCertificate = clientCertificate; } - public override IDictionary GetServerExtensions() + public override IDictionary GetServerExtensions() { - Hashtable serverExtensions = (Hashtable)base.GetServerExtensions(); - if (TlsSRTPUtils.GetUseSrtpExtension(serverExtensions) == null) + var serverExtensions = base.GetServerExtensions(); + if (TlsSrtpUtilities.GetUseSrtpExtension(serverExtensions) == null) { if (serverExtensions == null) { - serverExtensions = new Hashtable(); + serverExtensions = (IDictionary)new Hashtable(); } - TlsSRTPUtils.AddUseSrtpExtension(serverExtensions, serverSrtpData); + TlsSrtpUtilities.AddUseSrtpExtension(serverExtensions, serverSrtpData); } return serverExtensions; } - public override void ProcessClientExtensions(IDictionary clientExtensions) + public override void ProcessClientExtensions(IDictionary clientExtensions) { base.ProcessClientExtensions(clientExtensions); // set to some reasonable default value int chosenProfile = SrtpProtectionProfile.SRTP_AES128_CM_HMAC_SHA1_80; - UseSrtpData clientSrtpData = TlsSRTPUtils.GetUseSrtpExtension(clientExtensions); + UseSrtpData clientSrtpData = TlsSrtpUtilities.GetUseSrtpExtension(clientExtensions); foreach (int profile in clientSrtpData.ProtectionProfiles) { @@ -331,12 +341,12 @@ public byte[] GetSrtpMasterClientSalt() public override void NotifyHandshakeComplete() { + //Prepare Srtp Keys (we must to it here because master key will be cleared after that) + PrepareSrtpSharedSecret(); //Copy master Secret (will be inaccessible after this call) - masterSecret = new byte[mContext.SecurityParameters.MasterSecret != null ? mContext.SecurityParameters.MasterSecret.Length : 0]; - Buffer.BlockCopy(mContext.SecurityParameters.MasterSecret, 0, masterSecret, 0, masterSecret.Length); + masterSecret = new byte[m_context.SecurityParameters.MasterSecret != null ? m_context.SecurityParameters.MasterSecret.Length : 0]; + Buffer.BlockCopy(m_context.SecurityParameters.MasterSecret.Extract(), 0, masterSecret, 0, masterSecret.Length); - //Prepare Srtp Keys (we must to it here because master key will be cleared after that) - PrepareSrtpSharedSecret(); } public bool IsClient() @@ -344,42 +354,19 @@ public bool IsClient() return false; } - protected override TlsSignerCredentials GetECDsaSignerCredentials() + protected override TlsCredentialedSigner GetECDsaSignerCredentials() { - return DtlsUtils.LoadSignerCredentials(mContext, mCertificateChain, mPrivateKey, new SignatureAndHashAlgorithm(HashAlgorithm.sha256, SignatureAlgorithm.ecdsa)); + return new BcDefaultTlsCredentialedSigner(new TlsCryptoParameters(m_context), this.Crypto as BcTlsCrypto, mPrivateKey, mCertificateChain, new SignatureAndHashAlgorithm(HashAlgorithm.sha256, SignatureAlgorithm.ecdsa)); } - protected override TlsEncryptionCredentials GetRsaEncryptionCredentials() - { - return DtlsUtils.LoadEncryptionCredentials(mContext, mCertificateChain, mPrivateKey); - } - - protected override TlsSignerCredentials GetRsaSignerCredentials() - { - /* - * TODO Note that this code fails to provide default value for the client supported - * algorithms if it wasn't sent. - */ - SignatureAndHashAlgorithm signatureAndHashAlgorithm = null; - IList sigAlgs = mSupportedSignatureAlgorithms; - if (sigAlgs != null) - { - foreach (var sigAlgUncasted in sigAlgs) - { - SignatureAndHashAlgorithm sigAlg = sigAlgUncasted as SignatureAndHashAlgorithm; - if (sigAlg != null && sigAlg.Signature == SignatureAlgorithm.rsa) - { - signatureAndHashAlgorithm = sigAlg; - break; - } - } + protected override TlsCredentialedDecryptor GetRsaEncryptionCredentials() + { + return new BcDefaultTlsCredentialedDecryptor(Crypto as BcTlsCrypto, mCertificateChain, mPrivateKey); + } - if (signatureAndHashAlgorithm == null) - { - return null; - } - } - return DtlsUtils.LoadSignerCredentials(mContext, mCertificateChain, mPrivateKey, signatureAndHashAlgorithm); + protected override TlsCredentialedSigner GetRsaSignerCredentials() + { + return new BcDefaultTlsCredentialedSigner(new TlsCryptoParameters(m_context), this.Crypto as BcTlsCrypto, mPrivateKey, mCertificateChain, new SignatureAndHashAlgorithm(HashAlgorithm.sha256, SignatureAlgorithm.rsa)); } protected virtual void PrepareSrtpSharedSecret() @@ -448,7 +435,7 @@ protected virtual byte[] GetKeyingMaterial(string asciiLabel, byte[] context_val throw new ArgumentException("must have length less than 2^16 (or be null)", "context_value"); } - SecurityParameters sp = mContext.SecurityParameters; + SecurityParameters sp = m_context.SecurityParameters; if (!sp.IsExtendedMasterSecret && RequiresExtendedMasterSecret()) { /* @@ -488,7 +475,7 @@ protected virtual byte[] GetKeyingMaterial(string asciiLabel, byte[] context_val throw new InvalidOperationException("error in calculation of seed for export"); } - return TlsUtilities.PRF(mContext, sp.MasterSecret, asciiLabel, seed, length); + return TlsUtilities.Prf(sp, sp.MasterSecret, asciiLabel, seed, length).Extract(); } public override bool RequiresExtendedMasterSecret() @@ -496,22 +483,21 @@ public override bool RequiresExtendedMasterSecret() return ForceUseExtendedMasterSecret; } - protected override int[] GetCipherSuites() - { - int[] cipherSuites = new int[this.cipherSuites.Length]; - for (int i = 0; i < this.cipherSuites.Length; i++) - { - cipherSuites[i] = this.cipherSuites[i]; - } - return cipherSuites; - } - public Certificate GetRemoteCertificate() { return ClientCertificate; } - public override void NotifyAlertRaised(byte alertLevel, byte alertDescription, string message, Exception cause) + protected override ProtocolVersion[] GetSupportedVersions() + { + return new ProtocolVersion[] + { + ProtocolVersion.DTLSv10, + ProtocolVersion.DTLSv12 + }; + } + + public override void NotifyAlertRaised(short alertLevel, short alertDescription, string message, Exception cause) { string description = null; if (message != null) @@ -536,19 +522,19 @@ public override void NotifyAlertRaised(byte alertLevel, byte alertDescription, s } } - public override void NotifyAlertReceived(byte alertLevel, byte alertDescription) + public override void NotifyAlertReceived(short alertLevel, short alertDescription) { string description = AlertDescription.GetText(alertDescription); AlertLevelsEnum level = AlertLevelsEnum.Warning; AlertTypesEnum alertType = AlertTypesEnum.unknown; - if (Enum.IsDefined(typeof(AlertLevelsEnum), alertLevel)) + if (Enum.IsDefined(typeof(AlertLevelsEnum), checked((byte)alertLevel))) { level = (AlertLevelsEnum)alertLevel; } - if (Enum.IsDefined(typeof(AlertTypesEnum), alertDescription)) + if (Enum.IsDefined(typeof(AlertTypesEnum), checked((byte)alertDescription))) { alertType = (AlertTypesEnum)alertDescription; } diff --git a/src/net/DtlsSrtp/DtlsSrtpTransport.cs b/src/net/DtlsSrtp/DtlsSrtpTransport.cs old mode 100755 new mode 100644 index 3a72203d4..b4b514b40 --- a/src/net/DtlsSrtp/DtlsSrtpTransport.cs +++ b/src/net/DtlsSrtp/DtlsSrtpTransport.cs @@ -19,9 +19,10 @@ using System; using System.Collections.Concurrent; using Microsoft.Extensions.Logging; -using Org.BouncyCastle.Crypto.Tls; +using Org.BouncyCastle.Tls; using Org.BouncyCastle.Security; using SIPSorcery.Sys; +using System.Buffers; namespace SIPSorcery.Net { @@ -47,7 +48,7 @@ public class DtlsSrtpTransport : DatagramTransport, IDisposable IDtlsSrtpPeer connection = null; /// The collection of chunks to be written. - private BlockingCollection _chunks = new BlockingCollection(new ConcurrentQueue()); + private BlockingCollection> _chunks = new(new ConcurrentQueue>()); public DtlsTransport Transport { get; private set; } @@ -62,7 +63,8 @@ public class DtlsSrtpTransport : DatagramTransport, IDisposable /// public int RetransmissionMilliseconds = DEFAULT_RETRANSMISSION_WAIT_MILLIS; - public Action OnDataReady; + public delegate void OnBytesReadyDelegate(ReadOnlySpan bytes); + public OnBytesReadyDelegate OnDataReady; /// /// Parameters: @@ -73,7 +75,7 @@ public class DtlsSrtpTransport : DatagramTransport, IDisposable public event Action OnAlert; private System.DateTime _startTime = System.DateTime.MinValue; - private bool _isClosed = false; + private Once _isClosed; // Network properties private int _waitMillis = DEFAULT_RETRANSMISSION_WAIT_MILLIS; @@ -171,8 +173,7 @@ private bool DoHandshakeAsClient(out string handshakeError) this._waitMillis = RetransmissionMilliseconds; this._startTime = System.DateTime.Now; this._handshaking = true; - SecureRandom secureRandom = new SecureRandom(); - DtlsClientProtocol clientProtocol = new DtlsClientProtocol(secureRandom); + DtlsClientProtocol clientProtocol = new DtlsClientProtocol(); try { var client = (DtlsSrtpClient)connection; @@ -208,9 +209,9 @@ private bool DoHandshakeAsClient(out string handshakeError) else { handshakeError = "unknown"; - if (excp is Org.BouncyCastle.Crypto.Tls.TlsFatalAlert) + if (excp is Org.BouncyCastle.Tls.TlsFatalAlert) { - handshakeError = (excp as Org.BouncyCastle.Crypto.Tls.TlsFatalAlert).Message; + handshakeError = (excp as Org.BouncyCastle.Tls.TlsFatalAlert).Message; } logger.LogWarning(excp, $"DTLS handshake as client failed. {excp.Message}"); @@ -238,8 +239,7 @@ private bool DoHandshakeAsServer(out string handshakeError) this._waitMillis = RetransmissionMilliseconds; this._startTime = System.DateTime.Now; this._handshaking = true; - SecureRandom secureRandom = new SecureRandom(); - DtlsServerProtocol serverProtocol = new DtlsServerProtocol(secureRandom); + DtlsServerProtocol serverProtocol = new DtlsServerProtocol(); try { var server = (DtlsSrtpServer)connection; @@ -274,9 +274,9 @@ private bool DoHandshakeAsServer(out string handshakeError) else { handshakeError = "unknown"; - if (excp is Org.BouncyCastle.Crypto.Tls.TlsFatalAlert) + if (excp is Org.BouncyCastle.Tls.TlsFatalAlert) { - handshakeError = (excp as Org.BouncyCastle.Crypto.Tls.TlsFatalAlert).Message; + handshakeError = (excp as Org.BouncyCastle.Tls.TlsFatalAlert).Message; } logger.LogWarning(excp, $"DTLS handshake as server failed. {excp.Message}"); @@ -373,7 +373,7 @@ protected IPacketTransformer GenerateTransformer(bool isClient, bool isRtp) } } - public byte[] UnprotectRTP(byte[] packet, int offset, int length) + public byte[] UnprotectRTP(Span packet, int offset, int length) { lock (this.srtpDecoder) { @@ -381,7 +381,7 @@ public byte[] UnprotectRTP(byte[] packet, int offset, int length) } } - public int UnprotectRTP(byte[] payload, int length, out int outLength) + public int UnprotectRTP(Span payload, int length, out int outLength) { var result = UnprotectRTP(payload, 0, length); @@ -391,13 +391,13 @@ public int UnprotectRTP(byte[] payload, int length, out int outLength) return -1; } - System.Buffer.BlockCopy(result, 0, payload, 0, result.Length); + result.AsSpan().CopyTo(payload); outLength = result.Length; return 0; //No Errors } - public byte[] ProtectRTP(byte[] packet, int offset, int length) + public byte[] ProtectRTP(Span packet, int offset, int length) { lock (this.srtpEncoder) { @@ -405,7 +405,7 @@ public byte[] ProtectRTP(byte[] packet, int offset, int length) } } - public int ProtectRTP(byte[] payload, int length, out int outLength) + public int ProtectRTP(Span payload, int length, out int outLength) { var result = ProtectRTP(payload, 0, length); @@ -415,13 +415,13 @@ public int ProtectRTP(byte[] payload, int length, out int outLength) return -1; } - System.Buffer.BlockCopy(result, 0, payload, 0, result.Length); + result.AsSpan().CopyTo(payload); outLength = result.Length; return 0; //No Errors } - public byte[] UnprotectRTCP(byte[] packet, int offset, int length) + public byte[] UnprotectRTCP(Span packet, int offset, int length) { lock (this.srtcpDecoder) { @@ -429,7 +429,7 @@ public byte[] UnprotectRTCP(byte[] packet, int offset, int length) } } - public int UnprotectRTCP(byte[] payload, int length, out int outLength) + public int UnprotectRTCP(Span payload, int length, out int outLength) { var result = UnprotectRTCP(payload, 0, length); if (result == null) @@ -438,13 +438,13 @@ public int UnprotectRTCP(byte[] payload, int length, out int outLength) return -1; } - System.Buffer.BlockCopy(result, 0, payload, 0, result.Length); + result.AsSpan().CopyTo(payload); outLength = result.Length; return 0; //No Errors } - public byte[] ProtectRTCP(byte[] packet, int offset, int length) + public byte[] ProtectRTCP(Span packet, int offset, int length) { lock (this.srtcpEncoder) { @@ -452,7 +452,7 @@ public byte[] ProtectRTCP(byte[] packet, int offset, int length) } } - public int ProtectRTCP(byte[] payload, int length, out int outLength) + public int ProtectRTCP(Span payload, int length, out int outLength) { var result = ProtectRTCP(payload, 0, length); if (result == null) @@ -461,7 +461,7 @@ public int ProtectRTCP(byte[] payload, int length, out int outLength) return -1; } - System.Buffer.BlockCopy(result, 0, payload, 0, result.Length); + result.AsSpan().CopyTo(payload); outLength = result.Length; return 0; //No Errors @@ -485,27 +485,63 @@ public int GetSendLimit() return this._sendLimit; } - public void WriteToRecvStream(byte[] buf) + public void WriteToRecvStream(ReadOnlySpan buf) { - if (!_isClosed) + if (!_isClosed.HasOccurred) { - _chunks.Add(buf); + var chunk = ArrayPool.Shared.Rent(buf.Length); + buf.CopyTo(chunk); + try + { + _chunks.Add(new(chunk, 0, buf.Length)); + } + catch (Exception) when (_isClosed.HasOccurred) + { + ArrayPool.Shared.Return(chunk); + } } } + private ArraySegment _partialChunk = default; + private int _partialChunkOffset = 0; private int Read(byte[] buffer, int offset, int count, int timeout) { try { - if (_isClosed) + if (_isClosed.HasOccurred) { throw new System.Net.Sockets.SocketException((int)System.Net.Sockets.SocketError.NotConnected); //return DTLS_RECEIVE_ERROR_CODE; } + else if (_partialChunk.Array != null) + { + int bytesToCopy = Math.Min(count, _partialChunk.Count - _partialChunkOffset); + Buffer.BlockCopy(_partialChunk.Array, _partialChunkOffset, buffer, offset, bytesToCopy); + _partialChunkOffset += bytesToCopy; + + if (_partialChunkOffset == _partialChunk.Count) + { + ArrayPool.Shared.Return(_partialChunk.Array); + _partialChunk = default; + _partialChunkOffset = 0; + } + + return bytesToCopy; + } else if (_chunks.TryTake(out var item, timeout)) { - Buffer.BlockCopy(item, 0, buffer, 0, item.Length); - return item.Length; + int bytesToCopy = Math.Min(count, item.Count); + Buffer.BlockCopy(item.Array, 0, buffer, offset, bytesToCopy); + if (bytesToCopy < item.Count) + { + _partialChunk = item; + _partialChunkOffset = bytesToCopy; + } + else + { + ArrayPool.Shared.Return(item.Array); + } + return bytesToCopy; } } catch (ObjectDisposedException) { } @@ -514,6 +550,14 @@ private int Read(byte[] buffer, int offset, int count, int timeout) return DTLS_RETRANSMISSION_CODE; } +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER + public int Receive(Span buf, int waitMillis) + { + throw new NotImplementedException(); + return Receive(buf.ToArray(), 0, buf.Length, waitMillis); + } +#endif + public int Receive(byte[] buf, int off, int len, int waitMillis) { if (!_handshakeComplete) @@ -533,7 +577,7 @@ public int Receive(byte[] buf, int off, int len, int waitMillis) logger.LogWarning($"DTLS transport timed out after {TimeoutMilliseconds}ms waiting for handshake from remote {(connection.IsClient() ? "server" : "client")}."); throw new TimeoutException(); } - else if (!_isClosed) + else if (!_isClosed.HasOccurred) { waitMillis = Math.Min(waitMillis, millisecondsRemaining); var receiveLen = Read(buf, off, len, waitMillis); @@ -557,14 +601,14 @@ public int Receive(byte[] buf, int off, int len, int waitMillis) //return DTLS_RECEIVE_ERROR_CODE; } } - else if (!_isClosed) + else if (!_isClosed.HasOccurred) { return Read(buf, off, len, waitMillis); } else { - //throw new System.Net.Sockets.SocketException((int)System.Net.Sockets.SocketError.NotConnected); - return DTLS_RECEIVE_ERROR_CODE; + throw new System.Net.Sockets.SocketException((int)System.Net.Sockets.SocketError.NotConnected); + //return DTLS_RECEIVE_ERROR_CODE; } } @@ -580,16 +624,33 @@ public void Send(byte[] buf, int off, int len) OnDataReady?.Invoke(buf); } +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + public void Send(ReadOnlySpan buf) + { + OnDataReady?.Invoke(buf); + } +#endif + public virtual void Close() { - if (!_isClosed) + if (!_isClosed.TryMarkOccurred()) + { + return; + } + + this._startTime = System.DateTime.MinValue; + _chunks.CompleteAdding(); + foreach(var chunk in _chunks.GetConsumingEnumerable()) + { + ArrayPool.Shared.Return(chunk.Array); + } + if (_partialChunk.Array is { } partialChunk) { - _isClosed = true; - this._startTime = System.DateTime.MinValue; - this._chunks?.Dispose(); - Transport?.Close(); + ArrayPool.Shared.Return(partialChunk); } + this._chunks?.Dispose(); + Transport?.Close(); } /// @@ -597,7 +658,7 @@ public virtual void Close() /// protected void Dispose(bool disposing) { - if (!_isClosed) + if (!_isClosed.HasOccurred) { Close(); } @@ -608,7 +669,7 @@ protected void Dispose(bool disposing) /// public void Dispose() { - if (!_isClosed) + if (!_isClosed.HasOccurred) { Close(); } diff --git a/src/net/DtlsSrtp/DtlsUtils.cs b/src/net/DtlsSrtp/DtlsUtils.cs index ef2337c93..9dfb6f2dd 100644 --- a/src/net/DtlsSrtp/DtlsUtils.cs +++ b/src/net/DtlsSrtp/DtlsUtils.cs @@ -56,7 +56,7 @@ using Org.BouncyCastle.Crypto.Operators; using Org.BouncyCastle.Crypto.Parameters; using Org.BouncyCastle.Crypto.Prng; -using Org.BouncyCastle.Crypto.Tls; +using Org.BouncyCastle.Tls; using Org.BouncyCastle.Math; using Org.BouncyCastle.Pkcs; using Org.BouncyCastle.Security; @@ -64,6 +64,9 @@ using Org.BouncyCastle.Utilities.IO.Pem; using Org.BouncyCastle.X509; using SIPSorcery.Sys; +using System.Runtime.CompilerServices; +using Org.BouncyCastle.Tls.Crypto.Impl.BC; +using Org.BouncyCastle.Tls.Crypto; namespace SIPSorcery.Net { @@ -76,12 +79,12 @@ public class DtlsUtils private static ILogger logger = SIPSorcery.Sys.Log.Logger; - public static RTCDtlsFingerprint Fingerprint(string hashAlgorithm, X509Certificate2 certificate) + public static RTCDtlsFingerprint Fingerprint(TlsCrypto crypto, string hashAlgorithm, X509Certificate2 certificate) { - return Fingerprint(hashAlgorithm, LoadCertificateResource(certificate)); + return Fingerprint(hashAlgorithm, LoadCertificateResource(crypto, certificate)); } - public static RTCDtlsFingerprint Fingerprint(string hashAlgorithm, Org.BouncyCastle.Asn1.X509.X509CertificateStructure c) + public static RTCDtlsFingerprint Fingerprint(string hashAlgorithm, TlsCertificate c) { if (!IsHashSupported(hashAlgorithm)) { @@ -105,9 +108,9 @@ public static RTCDtlsFingerprint Fingerprint(Certificate certificateChain) return Fingerprint(certificate); } - public static RTCDtlsFingerprint Fingerprint(X509Certificate2 certificate) + public static RTCDtlsFingerprint Fingerprint(TlsCrypto crypto, X509Certificate2 certificate) { - return Fingerprint(LoadCertificateResource(certificate)); + return Fingerprint(LoadCertificateResource(crypto, certificate)); } public static RTCDtlsFingerprint Fingerprint(Org.BouncyCastle.X509.X509Certificate certificate) @@ -128,6 +131,18 @@ public static RTCDtlsFingerprint Fingerprint(X509CertificateStructure c) value = sha256Hash.HexStr(':') }; } + public static RTCDtlsFingerprint Fingerprint(TlsCertificate c) + { + IDigest sha256 = DigestUtilities.GetDigest(HashAlgorithmTag.Sha256.ToString()); + byte[] der = c.GetEncoded(); + byte[] sha256Hash = DigestOf(sha256, der); + + return new RTCDtlsFingerprint + { + algorithm = sha256.AlgorithmName.ToLower(), + value = sha256Hash.HexStr(':') + }; + } public static byte[] DigestOf(IDigest dAlg, byte[] input) { @@ -137,70 +152,56 @@ public static byte[] DigestOf(IDigest dAlg, byte[] input) return result; } - public static TlsAgreementCredentials LoadAgreementCredentials(TlsContext context, + public static TlsCredentialedAgreement LoadAgreementCredentials(TlsContext context, Certificate certificate, AsymmetricKeyParameter privateKey) { - return new DefaultTlsAgreementCredentials(certificate, privateKey); + return new BcDefaultTlsCredentialedAgreement(context.Crypto as BcTlsCrypto, certificate, privateKey); } - public static TlsAgreementCredentials LoadAgreementCredentials(TlsContext context, + public static TlsCredentialedAgreement LoadAgreementCredentials(TlsContext context, string[] certResources, string keyResource) { - Certificate certificate = LoadCertificateChain(certResources); + Certificate certificate = LoadCertificateChain(context.Crypto, certResources); AsymmetricKeyParameter privateKey = LoadPrivateKeyResource(keyResource); return LoadAgreementCredentials(context, certificate, privateKey); } - public static TlsEncryptionCredentials LoadEncryptionCredentials( + public static TlsCredentialedDecryptor LoadEncryptionCredentials( TlsContext context, Certificate certificate, AsymmetricKeyParameter privateKey) { - return new DefaultTlsEncryptionCredentials(context, certificate, + + return new BcDefaultTlsCredentialedDecryptor(context.Crypto as BcTlsCrypto, certificate, privateKey); } - public static TlsEncryptionCredentials LoadEncryptionCredentials( + public static TlsCredentialedDecryptor LoadEncryptionCredentials( TlsContext context, string[] certResources, string keyResource) { - Certificate certificate = LoadCertificateChain(certResources); + Certificate certificate = LoadCertificateChain(context.Crypto, certResources); AsymmetricKeyParameter privateKey = LoadPrivateKeyResource(keyResource); return LoadEncryptionCredentials(context, certificate, privateKey); } - public static TlsSignerCredentials LoadSignerCredentials(TlsContext context, - Certificate certificate, AsymmetricKeyParameter privateKey) - { - return new DefaultTlsSignerCredentials(context, certificate, privateKey); - } - - public static TlsSignerCredentials LoadSignerCredentials(TlsContext context, - string[] certResources, string keyResource) - { - Certificate certificate = LoadCertificateChain(certResources); - AsymmetricKeyParameter privateKey = LoadPrivateKeyResource(keyResource); - return LoadSignerCredentials(context, certificate, privateKey); - } - - public static TlsSignerCredentials LoadSignerCredentials(TlsContext context, + public static TlsCredentialedSigner LoadSignerCredentials(TlsContext context, Certificate certificate, AsymmetricKeyParameter privateKey, SignatureAndHashAlgorithm signatureAndHashAlgorithm) { - return new DefaultTlsSignerCredentials(context, certificate, - privateKey, signatureAndHashAlgorithm); + return new BcDefaultTlsCredentialedSigner(new TlsCryptoParameters(context), context.Crypto as BcTlsCrypto, privateKey, certificate, signatureAndHashAlgorithm); } - public static TlsSignerCredentials LoadSignerCredentials(TlsContext context, + public static TlsCredentialedSigner LoadSignerCredentials(TlsContext context, string[] certResources, string keyResource, SignatureAndHashAlgorithm signatureAndHashAlgorithm) { - Certificate certificate = LoadCertificateChain(certResources); + Certificate certificate = LoadCertificateChain(context.Crypto as BcTlsCrypto, certResources); Org.BouncyCastle.Crypto.AsymmetricKeyParameter privateKey = LoadPrivateKeyResource(keyResource); return LoadSignerCredentials(context, certificate, privateKey, signatureAndHashAlgorithm); } - public static TlsSignerCredentials LoadSignerCredentials(TlsContext context, IList supportedSignatureAlgorithms, - byte signatureAlgorithm, Certificate certificate, AsymmetricKeyParameter privateKey) + public static TlsCredentialedSigner LoadSignerCredentials(TlsContext context, IList supportedSignatureAlgorithms, + short signatureAlgorithm, Certificate certificate, AsymmetricKeyParameter privateKey) { /* * TODO Note that this code fails to provide default value for the client supported @@ -228,59 +229,59 @@ public static TlsSignerCredentials LoadSignerCredentials(TlsContext context, ILi return LoadSignerCredentials(context, certificate, privateKey, signatureAndHashAlgorithm); } - public static TlsSignerCredentials LoadSignerCredentials(TlsContext context, IList supportedSignatureAlgorithms, + public static TlsCredentialedSigner LoadSignerCredentials(TlsContext context, IList supportedSignatureAlgorithms, byte signatureAlgorithm, string certResource, string keyResource) { - Certificate certificate = LoadCertificateChain(new string[] { certResource, "x509-ca.pem" }); + Certificate certificate = LoadCertificateChain(context.Crypto as BcTlsCrypto, new string[] { certResource, "x509-ca.pem" }); AsymmetricKeyParameter privateKey = LoadPrivateKeyResource(keyResource); return LoadSignerCredentials(context, supportedSignatureAlgorithms, signatureAlgorithm, certificate, privateKey); } - public static Certificate LoadCertificateChain(X509Certificate2[] certificates) + public static Certificate LoadCertificateChain(TlsCrypto crypto, X509Certificate2[] certificates) { - var chain = new Org.BouncyCastle.Asn1.X509.X509CertificateStructure[certificates.Length]; + var chain = new TlsCertificate[certificates.Length]; for (int i = 0; i < certificates.Length; i++) { - chain[i] = LoadCertificateResource(certificates[i]); + chain[i] = LoadCertificateResource(crypto, certificates[i]); } return new Certificate(chain); } - public static Certificate LoadCertificateChain(X509Certificate2 certificate) + public static Certificate LoadCertificateChain(TlsCrypto crypto, X509Certificate2 certificate) { - return LoadCertificateChain(new X509Certificate2[] { certificate }); + return LoadCertificateChain(crypto, new X509Certificate2[] { certificate }); } - public static Certificate LoadCertificateChain(string[] resources) + public static Certificate LoadCertificateChain(TlsCrypto crypto, string[] resources) { - Org.BouncyCastle.Asn1.X509.X509CertificateStructure[] - chain = new Org.BouncyCastle.Asn1.X509.X509CertificateStructure[resources.Length]; + TlsCertificate[] + chain = new TlsCertificate[resources.Length]; for (int i = 0; i < resources.Length; ++i) { - chain[i] = LoadCertificateResource(resources[i]); + chain[i] = LoadCertificateResource(crypto, resources[i]); } return new Certificate(chain); } - public static X509CertificateStructure LoadCertificateResource(X509Certificate2 certificate) + public static TlsCertificate LoadCertificateResource(TlsCrypto crypto, X509Certificate2 certificate) { if (certificate != null) { var bouncyCertificate = DotNetUtilities.FromX509Certificate(certificate); - return X509CertificateStructure.GetInstance(bouncyCertificate.GetEncoded()); + return new BcTlsCertificate(crypto as BcTlsCrypto, X509CertificateStructure.GetInstance(bouncyCertificate.GetEncoded())); } throw new Exception("'resource' doesn't specify a valid certificate"); } - public static X509CertificateStructure LoadCertificateResource(string resource) + public static TlsCertificate LoadCertificateResource(TlsCrypto crypto, string resource) { PemObject pem = LoadPemResource(resource); if (pem.Type.EndsWith("CERTIFICATE")) { - return X509CertificateStructure.GetInstance(pem.Content); + return new BcTlsCertificate(crypto as BcTlsCrypto, X509CertificateStructure.GetInstance(pem.Content)); } throw new Exception("'resource' doesn't specify a valid certificate"); } @@ -484,18 +485,18 @@ public static (Org.BouncyCastle.X509.X509Certificate certificate, AsymmetricKeyP return (certificate, subjectKeyPair.Private); } - public static (Org.BouncyCastle.Crypto.Tls.Certificate certificate, AsymmetricKeyParameter privateKey) CreateSelfSignedTlsCert() + public static (Org.BouncyCastle.Tls.Certificate certificate, AsymmetricKeyParameter privateKey) CreateSelfSignedTlsCert(TlsCrypto crypto) { - return CreateSelfSignedTlsCert("CN=localhost", "CN=root", null); + return CreateSelfSignedTlsCert(crypto, "CN=localhost", "CN=root", null); } - public static (Org.BouncyCastle.Crypto.Tls.Certificate certificate, AsymmetricKeyParameter privateKey) CreateSelfSignedTlsCert(string subjectName, string issuerName, AsymmetricKeyParameter issuerPrivateKey) + public static (Org.BouncyCastle.Tls.Certificate certificate, AsymmetricKeyParameter privateKey) CreateSelfSignedTlsCert(TlsCrypto crypto, string subjectName, string issuerName, AsymmetricKeyParameter issuerPrivateKey) { var tuple = CreateSelfSignedBouncyCastleCert(subjectName, issuerName, issuerPrivateKey); var certificate = tuple.certificate; var privateKey = tuple.privateKey; - var chain = new Org.BouncyCastle.Asn1.X509.X509CertificateStructure[] { X509CertificateStructure.GetInstance(certificate.GetEncoded()) }; - var tlsCertificate = new Org.BouncyCastle.Crypto.Tls.Certificate(chain); + var chain = new TlsCertificate[] { new BcTlsCertificate(crypto as BcTlsCrypto, X509CertificateStructure.GetInstance(certificate.GetEncoded())) }; + var tlsCertificate = new Org.BouncyCastle.Tls.Certificate(chain); return (tlsCertificate, privateKey); } @@ -505,63 +506,48 @@ public static (Org.BouncyCastle.Crypto.Tls.Certificate certificate, AsymmetricKe /// use the serialize/deserialize from pfx to get from bouncy castle to .NET Core X509 certificates. public static X509Certificate2 ConvertBouncyCert(Org.BouncyCastle.X509.X509Certificate bouncyCert, AsymmetricCipherKeyPair keyPair) { - var pkcs12Store = new Pkcs12Store(); - var certEntry = new X509CertificateEntry(bouncyCert); - - pkcs12Store.SetCertificateEntry(bouncyCert.SerialNumber.ToString(), certEntry); - pkcs12Store.SetKeyEntry(bouncyCert.SerialNumber.ToString(), - new AsymmetricKeyEntry(keyPair.Private), new[] { certEntry }); +#if !NET461 && !NETSTANDARD2_0 + var info = Org.BouncyCastle.Pkcs.PrivateKeyInfoFactory.CreatePrivateKeyInfo(keyPair.Private); - X509Certificate2 keyedCert; + //// merge into X509Certificate2 + var x509 = new X509Certificate2(bouncyCert.GetEncoded()); - using (MemoryStream pfxStream = new MemoryStream()) + var seq = (Asn1Sequence)Asn1Object.FromByteArray(info.ParsePrivateKey().GetDerEncoded()); + if (seq.Count != 9) { - pkcs12Store.Save(pfxStream, new char[] { }, new SecureRandom()); - pfxStream.Seek(0, SeekOrigin.Begin); - keyedCert = new X509Certificate2(pfxStream.ToArray(), string.Empty, X509KeyStorageFlags.Exportable); + throw new Org.BouncyCastle.OpenSsl.PemException("malformed sequence in RSA private key"); } - return keyedCert; - - //var info = Org.BouncyCastle.Pkcs.PrivateKeyInfoFactory.CreatePrivateKeyInfo(keyPair.Private); + var rsa = RsaPrivateKeyStructure.GetInstance(seq); //new RsaPrivateKeyStructure(seq); + var rsaparams = new RsaPrivateCrtKeyParameters( + rsa.Modulus, rsa.PublicExponent, rsa.PrivateExponent, rsa.Prime1, rsa.Prime2, rsa.Exponent1, rsa.Exponent2, rsa.Coefficient); - //// merge into X509Certificate2 - //var x509 = new X509Certificate2(bouncyCert.GetEncoded()); - - //var seq = (Asn1Sequence)Asn1Object.FromByteArray(info.ParsePrivateKey().GetDerEncoded()); - //if (seq.Count != 9) - //{ - // throw new Org.BouncyCastle.OpenSsl.PemException("malformed sequence in RSA private key"); - //} - - //var rsa = RsaPrivateKeyStructure.GetInstance(seq); //new RsaPrivateKeyStructure(seq); - //var rsaparams = new RsaPrivateCrtKeyParameters( - // rsa.Modulus, rsa.PublicExponent, rsa.PrivateExponent, rsa.Prime1, rsa.Prime2, rsa.Exponent1, rsa.Exponent2, rsa.Coefficient); - - //return x509.CopyWithPrivateKey(ToRSA(rsaparams)); + return x509.CopyWithPrivateKey(ToRSA(rsaparams)); - //X509Certificate2 x509 = null; +#else + X509Certificate2 x509 = null; - //using (MemoryStream ms = new MemoryStream()) - //{ - // using (StreamWriter tw = new StreamWriter(ms)) - // { - // PemWriter pw = new PemWriter(tw); - // //PemObject po = new PemObject("CERTIFICATE", bouncyCert.GetEncoded()); - // PemObject po = new PemObject("CERTIFICATE", bouncyCert.GetEncoded()); - // pw.WriteObject(po); + using (MemoryStream ms = new MemoryStream()) + { + using (StreamWriter tw = new StreamWriter(ms)) + { + PemWriter pw = new PemWriter(tw); + //PemObject po = new PemObject("CERTIFICATE", bouncyCert.GetEncoded()); + PemObject po = new PemObject("CERTIFICATE", bouncyCert.GetEncoded()); + pw.WriteObject(po); - // logger.LogDebug(Encoding.UTF8.GetString(ms.GetBuffer())); + logger.LogDebug(System.Text.Encoding.UTF8.GetString(ms.GetBuffer())); - // StreamWriter sw2 = new StreamWriter("test.cer"); - // sw2.Write(ms.GetBuffer()); - // sw2.Close(); + StreamWriter sw2 = new StreamWriter("test.cer"); + sw2.Write(ms.GetBuffer()); + sw2.Close(); - // x509 = new X509Certificate2(bouncyCert.GetEncoded()); - // } - //} + x509 = new X509Certificate2(bouncyCert.GetEncoded()); + } + } - //return x509; + return x509; +#endif } @@ -602,7 +588,7 @@ public static AsymmetricKeyParameter CreatePrivateKeyResource(string subjectName return subjectKeyPair.Private; } - #endregion +#endregion /// /// This method and the related ones have been copied from the BouncyCode DotNetUtilities diff --git a/src/net/DtlsSrtp/SrtpHandler.cs b/src/net/DtlsSrtp/SrtpHandler.cs index 2fed6ecf6..509a86947 100644 --- a/src/net/DtlsSrtp/SrtpHandler.cs +++ b/src/net/DtlsSrtp/SrtpHandler.cs @@ -166,7 +166,7 @@ private IPacketTransformer GenerateTransformer(SDPSecurityDescription securityDe } } - public byte[] UnprotectRTP(byte[] packet, int offset, int length) + public byte[] UnprotectRTP(Span packet, int offset, int length) { lock (SrtpDecoder) { @@ -174,7 +174,7 @@ public byte[] UnprotectRTP(byte[] packet, int offset, int length) } } - public int UnprotectRTP(byte[] payload, int length, out int outLength) + public int UnprotectRTP(Span payload, int length, out int outLength) { var result = UnprotectRTP(payload, 0, length); @@ -184,13 +184,13 @@ public int UnprotectRTP(byte[] payload, int length, out int outLength) return -1; } - System.Buffer.BlockCopy(result, 0, payload, 0, result.Length); + result.AsSpan().CopyTo(payload); outLength = result.Length; return 0; //No Errors } - public byte[] ProtectRTP(byte[] packet, int offset, int length) + public byte[] ProtectRTP(Span packet, int offset, int length) { lock (SrtpEncoder) { @@ -198,7 +198,7 @@ public byte[] ProtectRTP(byte[] packet, int offset, int length) } } - public int ProtectRTP(byte[] payload, int length, out int outLength) + public int ProtectRTP(Span payload, int length, out int outLength) { var result = ProtectRTP(payload, 0, length); @@ -208,13 +208,13 @@ public int ProtectRTP(byte[] payload, int length, out int outLength) return -1; } - System.Buffer.BlockCopy(result, 0, payload, 0, result.Length); + result.AsSpan().CopyTo(payload); outLength = result.Length; return 0; //No Errors } - public byte[] UnprotectRTCP(byte[] packet, int offset, int length) + public byte[] UnprotectRTCP(Span packet, int offset, int length) { lock (SrtcpDecoder) { @@ -222,7 +222,7 @@ public byte[] UnprotectRTCP(byte[] packet, int offset, int length) } } - public int UnprotectRTCP(byte[] payload, int length, out int outLength) + public int UnprotectRTCP(Span payload, int length, out int outLength) { var result = UnprotectRTCP(payload, 0, length); if (result == null) @@ -231,13 +231,13 @@ public int UnprotectRTCP(byte[] payload, int length, out int outLength) return -1; } - System.Buffer.BlockCopy(result, 0, payload, 0, result.Length); + result.AsSpan().CopyTo(payload); outLength = result.Length; return 0; //No Errors } - public byte[] ProtectRTCP(byte[] packet, int offset, int length) + public byte[] ProtectRTCP(Span packet, int offset, int length) { lock (SrtcpEncoder) { @@ -245,7 +245,7 @@ public byte[] ProtectRTCP(byte[] packet, int offset, int length) } } - public int ProtectRTCP(byte[] payload, int length, out int outLength) + public int ProtectRTCP(Span payload, int length, out int outLength) { var result = ProtectRTCP(payload, 0, length); if (result == null) @@ -254,7 +254,7 @@ public int ProtectRTCP(byte[] payload, int length, out int outLength) return -1; } - System.Buffer.BlockCopy(result, 0, payload, 0, result.Length); + result.AsSpan().CopyTo(payload); outLength = result.Length; return 0; //No Errors diff --git a/src/net/DtlsSrtp/SrtpParameters.cs b/src/net/DtlsSrtp/SrtpParameters.cs index 8d6ed5582..b70d78244 100644 --- a/src/net/DtlsSrtp/SrtpParameters.cs +++ b/src/net/DtlsSrtp/SrtpParameters.cs @@ -18,7 +18,7 @@ //----------------------------------------------------------------------------- using System; -using Org.BouncyCastle.Crypto.Tls; +using Org.BouncyCastle.Tls; namespace SIPSorcery.Net { diff --git a/src/net/DtlsSrtp/Transform/IPackerTransformer.cs b/src/net/DtlsSrtp/Transform/IPackerTransformer.cs index 4f570f782..ef4af9b10 100644 --- a/src/net/DtlsSrtp/Transform/IPackerTransformer.cs +++ b/src/net/DtlsSrtp/Transform/IPackerTransformer.cs @@ -19,6 +19,8 @@ // Original Source: AGPL-3.0 License //----------------------------------------------------------------------------- +using System; + namespace SIPSorcery.Net { public interface IPacketTransformer @@ -30,7 +32,7 @@ public interface IPacketTransformer * the packet to be transformed * @return The transformed packet. Returns null if the packet cannot be transformed. */ - byte[] Transform(byte[] pkt); + byte[] Transform(ReadOnlySpan pkt); /** * Transforms a specific non-secure packet. @@ -44,7 +46,7 @@ public interface IPacketTransformer * @return The transformed packet. Returns null if the packet cannot be * transformed. */ - byte[] Transform(byte[] pkt, int offset, int length); + byte[] Transform(ReadOnlySpan pkt, int offset, int length); /** * Reverse-transforms a specific packet (i.e. transforms a transformed @@ -54,7 +56,7 @@ public interface IPacketTransformer * the transformed packet to be restored * @return Whether the packet was successfully restored */ - byte[] ReverseTransform(byte[] pkt); + byte[] ReverseTransform(ReadOnlySpan pkt); /** * Reverse-transforms a specific packet (i.e. transforms a transformed @@ -68,7 +70,7 @@ public interface IPacketTransformer * the length of data in the packet * @return The restored packet. Returns null if packet cannot be restored. */ - byte[] ReverseTransform(byte[] pkt, int offset, int length); + byte[] ReverseTransform(ReadOnlySpan pkt, int offset, int length); /** * Close the transformer and underlying transform engine. diff --git a/src/net/DtlsSrtp/Transform/RawPacket.cs b/src/net/DtlsSrtp/Transform/RawPacket.cs index 7da6fb3a7..ee3e38804 100644 --- a/src/net/DtlsSrtp/Transform/RawPacket.cs +++ b/src/net/DtlsSrtp/Transform/RawPacket.cs @@ -40,6 +40,7 @@ * */ +using System; using System.IO; namespace SIPSorcery.Net @@ -82,16 +83,16 @@ public RawPacket() * @param length the number of bytes in buffer which * constitute the actual data to be represented by the new instance */ - public RawPacket(byte[] data, int offset, int length) + public RawPacket(ReadOnlySpan data, int offset, int length) { this.buffer = new MemoryStream(RTP_PACKET_MAX_SIZE); Wrap(data, offset, length); } - public void Wrap(byte[] data, int offset, int length) + public void Wrap(ReadOnlySpan data, int offset, int length) { - this.buffer.Position = 0; - this.buffer.Write(data, offset, length); + this.buffer.Position = 0; + this.buffer.Write(data.ToArray(), offset, length); this.buffer.SetLength(length - offset); this.buffer.Position = 0; } diff --git a/src/net/DtlsSrtp/Transform/SrtcpTransformer.cs b/src/net/DtlsSrtp/Transform/SrtcpTransformer.cs index 734f9b120..0b9fe45ee 100644 --- a/src/net/DtlsSrtp/Transform/SrtcpTransformer.cs +++ b/src/net/DtlsSrtp/Transform/SrtcpTransformer.cs @@ -15,58 +15,59 @@ // License: // BSD 3-Clause "New" or "Revised" License, see included LICENSE.md file. // Original Source: AGPL-3.0 License -//----------------------------------------------------------------------------- - +//----------------------------------------------------------------------------- + +using System; using System.Collections.Concurrent; -using System.Collections.Generic; +using System.Collections.Generic; using System.Threading; -namespace SIPSorcery.Net +namespace SIPSorcery.Net { /// - /// SRTCPTransformer implements PacketTransformer. - /// It encapsulate the encryption / decryption logic for SRTCP packets - /// - /// @author Bing SU (nova.su @gmail.com) + /// SRTCPTransformer implements PacketTransformer. + /// It encapsulate the encryption / decryption logic for SRTCP packets + /// + /// @author Bing SU (nova.su @gmail.com) /// @author Werner Dittmann (Werner.Dittmann@t-online.de) - /// - public class SrtcpTransformer : IPacketTransformer - { - private int _isLocked = 0; - private RawPacket packet; - - private SrtpTransformEngine forwardEngine; - private SrtpTransformEngine reverseEngine; - - /** All the known SSRC's corresponding SRTCPCryptoContexts */ - private ConcurrentDictionary contexts; - - public SrtcpTransformer(SrtpTransformEngine engine) : this(engine, engine) - { - - } - - public SrtcpTransformer(SrtpTransformEngine forwardEngine, SrtpTransformEngine reverseEngine) - { - this.packet = new RawPacket(); - this.forwardEngine = forwardEngine; - this.reverseEngine = reverseEngine; - this.contexts = new ConcurrentDictionary(); - } - - /// - /// Encrypts a SRTCP packet + /// + public class SrtcpTransformer : IPacketTransformer + { + private int _isLocked = 0; + private RawPacket packet; + + private SrtpTransformEngine forwardEngine; + private SrtpTransformEngine reverseEngine; + + /** All the known SSRC's corresponding SRTCPCryptoContexts */ + private ConcurrentDictionary contexts; + + public SrtcpTransformer(SrtpTransformEngine engine) : this(engine, engine) + { + + } + + public SrtcpTransformer(SrtpTransformEngine forwardEngine, SrtpTransformEngine reverseEngine) + { + this.packet = new RawPacket(); + this.forwardEngine = forwardEngine; + this.reverseEngine = reverseEngine; + this.contexts = new ConcurrentDictionary(); + } + + /// + /// Encrypts a SRTCP packet /// /// plain SRTCP packet to be encrypted. /// encrypted SRTCP packet. - public byte[] Transform(byte[] pkt) - { - return Transform(pkt, 0, pkt.Length); - } - - public byte[] Transform(byte[] pkt, int offset, int length) - { - var isLocked = Interlocked.CompareExchange(ref _isLocked, 1, 0) != 0; + public byte[] Transform(ReadOnlySpan pkt) + { + return Transform(pkt, 0, pkt.Length); + } + + public byte[] Transform(ReadOnlySpan pkt, int offset, int length) + { + var isLocked = Interlocked.CompareExchange(ref _isLocked, 1, 0) != 0; try { // Wrap the data into raw packet for readable format @@ -90,23 +91,25 @@ public byte[] Transform(byte[] pkt, int offset, int length) byte[] result = packet.GetData(); return result; - } - finally + } + finally { //Unlock if (!isLocked) + { Interlocked.CompareExchange(ref _isLocked, 0, 1); - } - } - - public byte[] ReverseTransform(byte[] pkt) - { - return ReverseTransform(pkt, 0, pkt.Length); - } - - public byte[] ReverseTransform(byte[] pkt, int offset, int length) - { - var isLocked = Interlocked.CompareExchange(ref _isLocked, 1, 0) != 0; + } + } + } + + public byte[] ReverseTransform(ReadOnlySpan pkt) + { + return ReverseTransform(pkt, 0, pkt.Length); + } + + public byte[] ReverseTransform(ReadOnlySpan pkt, int offset, int length) + { + var isLocked = Interlocked.CompareExchange(ref _isLocked, 1, 0) != 0; try { // wrap data into raw packet for readable format @@ -133,36 +136,38 @@ public byte[] ReverseTransform(byte[] pkt, int offset, int length) result = packet.GetData(); } return result; - } - finally + } + finally { //Unlock if (!isLocked) + { Interlocked.CompareExchange(ref _isLocked, 0, 1); - } - } - + } + } + } + /// /// Close the transformer and underlying transform engine. /// The close functions closes all stored crypto contexts. This deletes key data /// and forces a cleanup of the crypto contexts. - /// - public void Close() - { - forwardEngine.Close(); + /// + public void Close() + { + forwardEngine.Close(); if (forwardEngine != reverseEngine) { reverseEngine.Close(); - } - - var keys = new List(contexts.Keys); - foreach (var ssrc in keys) + } + + var keys = new List(contexts.Keys); + foreach (var ssrc in keys) { if (contexts.TryRemove(ssrc, out var context)) { context.Close(); - } - } - } - } -} + } + } + } + } +} diff --git a/src/net/DtlsSrtp/Transform/SrtpTransformer.cs b/src/net/DtlsSrtp/Transform/SrtpTransformer.cs index e72816c98..0f800c50b 100644 --- a/src/net/DtlsSrtp/Transform/SrtpTransformer.cs +++ b/src/net/DtlsSrtp/Transform/SrtpTransformer.cs @@ -39,6 +39,7 @@ * */ +using System; using System.Collections.Concurrent; using System.Collections.Generic; using System.Threading; @@ -70,12 +71,12 @@ public SrtpTransformer(SrtpTransformEngine forwardEngine, SrtpTransformEngine re this.rawPacket = new RawPacket(); } - public byte[] Transform(byte[] pkt) + public byte[] Transform(ReadOnlySpan pkt) { return Transform(pkt, 0, pkt.Length); } - public byte[] Transform(byte[] pkt, int offset, int length) + public byte[] Transform(ReadOnlySpan pkt, int offset, int length) { var isLocked = Interlocked.CompareExchange(ref _isLocked, 1, 0) != 0; @@ -107,7 +108,9 @@ public byte[] Transform(byte[] pkt, int offset, int length) { //Unlock if (!isLocked) + { Interlocked.CompareExchange(ref _isLocked, 0, 1); + } } } @@ -119,12 +122,12 @@ public byte[] Transform(byte[] pkt, int offset, int length) * the transformed packet to be restored * @return the restored packet */ - public byte[] ReverseTransform(byte[] pkt) + public byte[] ReverseTransform(ReadOnlySpan pkt) { return ReverseTransform(pkt, 0, pkt.Length); } - public byte[] ReverseTransform(byte[] pkt, int offset, int length) + public byte[] ReverseTransform(ReadOnlySpan pkt, int offset, int length) { var isLocked = Interlocked.CompareExchange(ref _isLocked, 1, 0) != 0; try @@ -157,7 +160,9 @@ public byte[] ReverseTransform(byte[] pkt, int offset, int length) { //Unlock if (!isLocked) + { Interlocked.CompareExchange(ref _isLocked, 0, 1); + } } } diff --git a/src/net/ICE/IceChecklistEntry.cs b/src/net/ICE/IceChecklistEntry.cs index 8317a96eb..472f24a4c 100644 --- a/src/net/ICE/IceChecklistEntry.cs +++ b/src/net/ICE/IceChecklistEntry.cs @@ -182,23 +182,29 @@ public ulong Priority public string RequestTransactionID { get - { - return _cachedRequestTransactionIDs?.Count > 0 ? _cachedRequestTransactionIDs[0] : null; + { + lock (_cachedRequestTransactionIDs) + { + return _cachedRequestTransactionIDs?.Count > 0 ? _cachedRequestTransactionIDs[0] : null; + } } set { - var currentValue = _cachedRequestTransactionIDs?.Count > 0 ? _cachedRequestTransactionIDs[0] : null; - if (value != currentValue) + lock (_cachedRequestTransactionIDs) { - const int MAX_CACHED_REQUEST_IDS = 30; - while (_cachedRequestTransactionIDs.Count >= MAX_CACHED_REQUEST_IDS && _cachedRequestTransactionIDs.Count > 0) - { - _cachedRequestTransactionIDs.RemoveAt(_cachedRequestTransactionIDs.Count - 1); - } - - if (MAX_CACHED_REQUEST_IDS > 0) + var currentValue = _cachedRequestTransactionIDs?.Count > 0 ? _cachedRequestTransactionIDs[0] : null; + if (value != currentValue) { - _cachedRequestTransactionIDs.Insert(0, value); + const int MAX_CACHED_REQUEST_IDS = 30; + while (_cachedRequestTransactionIDs.Count >= MAX_CACHED_REQUEST_IDS && _cachedRequestTransactionIDs.Count > 0) + { + _cachedRequestTransactionIDs.RemoveAt(_cachedRequestTransactionIDs.Count - 1); + } + + if (MAX_CACHED_REQUEST_IDS > 0) + { + _cachedRequestTransactionIDs.Insert(0, value); + } } } } diff --git a/src/net/ICE/RtpIceChannel.cs b/src/net/ICE/RtpIceChannel.cs old mode 100755 new mode 100644 index 29df73134..ef412938e --- a/src/net/ICE/RtpIceChannel.cs +++ b/src/net/ICE/RtpIceChannel.cs @@ -479,7 +479,16 @@ public List Candidates /// /// Creates a copy of the checklist of local and remote candidate pairs /// - internal List Checklist { get { return _checklist.ToList(); } } + internal List Checklist + { + get + { + lock (_checklist) + { + return [.. _checklist]; + } + } + } /// /// For local candidates this implementation takes a shortcut to reduce complexity. @@ -569,7 +578,7 @@ internal int RTO /// public event Action OnStunMessageSent; - public new event Action OnRTPDataReceived; + public new event DataReceivedDelegate OnRTPDataReceived; /// /// An optional callback function to resolve remote ICE candidates with MDNS hostnames. @@ -742,7 +751,7 @@ public void StartGathering() logger.LogDebug($"RTP ICE Channel discovered {_candidates.Count} local candidates."); - if (_iceServerConnections?.Count > 0) + if (_iceServerConnections?.IsEmpty == false) { InitialiseIceServers(_iceServers); _processIceServersTimer = new Timer(CheckIceServers, null, 0, Ta); @@ -1582,7 +1591,12 @@ private async void ProcessChecklist() // Until that happens there is no work to do. if (IceConnectionState == RTCIceConnectionState.checking) { - if (Checklist.Count > 0) + int count; + lock (_checklist) + { + count = _checklist.Count; + } + if (count > 0) { if (RemoteIceUser == null || RemoteIcePassword == null) { @@ -1822,12 +1836,12 @@ private void SendSTUNBindingRequest(ChecklistEntry candidatePair, bool setUseCan { IPEndPoint relayServerEP = candidatePair.LocalCandidate.IceServer.ServerEndPoint; var protocol = candidatePair.LocalCandidate.IceServer.Protocol; - SendRelay(protocol, candidatePair.RemoteCandidate.DestinationEndPoint, stunReqBytes, relayServerEP, candidatePair.LocalCandidate.IceServer); + SendRelay(protocol, candidatePair.RemoteCandidate.DestinationEndPoint, stunReqBytes, relayServerEP, candidatePair.LocalCandidate.IceServer, OnBindingFailure); } else { IPEndPoint remoteEndPoint = candidatePair.RemoteCandidate.DestinationEndPoint; - var sendResult = base.Send(RTPChannelSocketsEnum.RTP, remoteEndPoint, stunReqBytes); + var sendResult = base.Send(RTPChannelSocketsEnum.RTP, remoteEndPoint, stunReqBytes, OnBindingFailure); if (sendResult != SocketError.Success) { @@ -1840,6 +1854,17 @@ private void SendSTUNBindingRequest(ChecklistEntry candidatePair, bool setUseCan } } + bool OnBindingFailure(Exception exception) + { + if (exception is SocketException socketException) + { + logger.LogDebug("Socket exception binding RTP channel: {Code} {Message}.", socketException.ErrorCode, socketException.Message); + return true; + } + + return false; + } + /// /// Builds and sends the connectivity check on a candidate pair that is set /// as the current nominated, connected pair. @@ -1974,7 +1999,10 @@ public async Task ProcessStunMessage(STUNMessage stunMessage, IPEndPoint remoteE else if (IsController) { logger.LogDebug($"ICE RTP channel binding response state {matchingChecklistEntry.State} as Controller for {matchingChecklistEntry.RemoteCandidate.ToShortString()}"); - ProcessNominateLogicAsController(matchingChecklistEntry); + lock (_checklist) + { + ProcessNominateLogicAsController(matchingChecklistEntry); + } } } } @@ -2147,7 +2175,10 @@ private void GotStunBindingRequest(STUNMessage bindingRequest, IPEndPoint remote entry.TurnPermissionsResponseAt = DateTime.Now; } - AddChecklistEntry(entry); + lock (_checklist) + { + AddChecklistEntry(entry); + } matchingChecklistEntry = entry; } @@ -2193,7 +2224,7 @@ private void GotStunBindingRequest(STUNMessage bindingRequest, IPEndPoint remote if (wasRelayed) { var protocol = matchingChecklistEntry.LocalCandidate.IceServer.Protocol; - SendRelay(protocol, remoteEndPoint, stunRespBytes, matchingChecklistEntry.LocalCandidate.IceServer.ServerEndPoint, matchingChecklistEntry.LocalCandidate.IceServer); + SendRelay(protocol, remoteEndPoint, stunRespBytes, matchingChecklistEntry.LocalCandidate.IceServer.ServerEndPoint, matchingChecklistEntry.LocalCandidate.IceServer, onFailure: null); OnStunMessageSent?.Invoke(stunResponse, remoteEndPoint, true); } else @@ -2232,7 +2263,7 @@ private ChecklistEntry GetChecklistEntryForStunResponse(byte[] transactionID) /// If found a matching state object or null if not. private IceServer GetIceServerForTransactionID(byte[] transactionID) { - if (_iceServerConnections == null || _iceServerConnections.Count == 0) + if (_iceServerConnections == null || _iceServerConnections.IsEmpty) { return null; } @@ -2283,7 +2314,7 @@ private SocketError SendStunBindingRequest(IceServer iceServer) var sendResult = iceServer.Protocol == ProtocolType.Tcp ? SendOverTCP(iceServer, stunReqBytes) : - base.Send(RTPChannelSocketsEnum.RTP, iceServer.ServerEndPoint, stunReqBytes); + base.Send(RTPChannelSocketsEnum.RTP, iceServer.ServerEndPoint, stunReqBytes, OnBindingFailure); if (sendResult != SocketError.Success) { @@ -2334,7 +2365,7 @@ private SocketError SendTurnAllocateRequest(IceServer iceServer) var sendResult = iceServer.Protocol == ProtocolType.Tcp ? SendOverTCP(iceServer, allocateReqBytes) : - base.Send(RTPChannelSocketsEnum.RTP, iceServer.ServerEndPoint, allocateReqBytes); + base.Send(RTPChannelSocketsEnum.RTP, iceServer.ServerEndPoint, allocateReqBytes, OnBindingFailure); if (sendResult != SocketError.Success) { @@ -2574,9 +2605,9 @@ private byte[] GetAuthenticatedStunRequest(STUNMessage stunRequest, string usern /// The local port it was received on. /// The remote end point of the sender. /// The raw packet received (note this may not be RTP if other protocols are being multiplexed). - protected override void OnRTPPacketReceived(UdpReceiver receiver, int localPort, IPEndPoint remoteEndPoint, byte[] packet) + protected override void OnRTPPacketReceived(UdpReceiver receiver, int localPort, IPEndPoint remoteEndPoint, ReadOnlySpan packet) { - if (packet?.Length > 0) + if (packet.Length > 0) { bool wasRelayed = false; @@ -2615,16 +2646,16 @@ protected override void OnRTPPacketReceived(UdpReceiver receiver, int localPort, /// The data to send to the peer. /// The TURN server end point to send the relayed request to. /// - private SocketError SendRelay(ProtocolType protocol, IPEndPoint dstEndPoint, byte[] buffer, IPEndPoint relayEndPoint, IceServer iceServer) + private SocketError SendRelay(ProtocolType protocol, IPEndPoint dstEndPoint, ReadOnlySpan buffer, IPEndPoint relayEndPoint, IceServer iceServer, Func? onFailure) { STUNMessage sendReq = new STUNMessage(STUNMessageTypesEnum.SendIndication); sendReq.AddXORPeerAddressAttribute(dstEndPoint.Address, dstEndPoint.Port); - sendReq.Attributes.Add(new STUNAttribute(STUNAttributeTypesEnum.Data, buffer)); + sendReq.Attributes.Add(new STUNAttribute(STUNAttributeTypesEnum.Data, buffer.ToArray())); var request = sendReq.ToByteBuffer(null, false); var sendResult = protocol == ProtocolType.Tcp ? SendOverTCP(iceServer, request) : - base.Send(RTPChannelSocketsEnum.RTP, relayEndPoint, request); + base.Send(RTPChannelSocketsEnum.RTP, relayEndPoint, request, onFailure); if (sendResult != SocketError.Success) { @@ -2692,7 +2723,7 @@ private async Task ResolveMdnsName(RTCIceCandidate candidate) /// The data to send. /// The result of initiating the send. This result does not reflect anything about /// whether the remote party received the packet or not. - public override SocketError Send(RTPChannelSocketsEnum sendOn, IPEndPoint dstEndPoint, byte[] buffer) + public override SocketError Send(RTPChannelSocketsEnum sendOn, IPEndPoint dstEndPoint, ReadOnlySpan buffer, Func? onFailure = null) { if (NominatedEntry != null && NominatedEntry.LocalCandidate.type == RTCIceCandidateType.relay && NominatedEntry.LocalCandidate.IceServer != null && @@ -2702,11 +2733,11 @@ public override SocketError Send(RTPChannelSocketsEnum sendOn, IPEndPoint dstEnd // A TURN relay channel is being used to communicate with the remote peer. var protocol = NominatedEntry.LocalCandidate.IceServer.Protocol; var serverEndPoint = NominatedEntry.LocalCandidate.IceServer.ServerEndPoint; - return SendRelay(protocol, dstEndPoint, buffer, serverEndPoint, NominatedEntry.LocalCandidate.IceServer); + return SendRelay(protocol, dstEndPoint, buffer, serverEndPoint, NominatedEntry.LocalCandidate.IceServer, onFailure); } else { - return base.Send(sendOn, dstEndPoint, buffer); + return base.Send(sendOn, dstEndPoint, buffer, onFailure); } } } diff --git a/src/net/RTCP/RTCPBye.cs b/src/net/RTCP/RTCPBye.cs index 56158e640..c43e1cd98 100644 --- a/src/net/RTCP/RTCPBye.cs +++ b/src/net/RTCP/RTCPBye.cs @@ -27,6 +27,7 @@ //----------------------------------------------------------------------------- using System; +using System.Buffers.Binary; using System.Text; using SIPSorcery.Sys; @@ -73,7 +74,7 @@ public RTCPBye(uint ssrc, string reason) /// Create a new RTCP Goodbye packet from a serialised byte array. /// /// The byte array holding the Goodbye packet. - public RTCPBye(byte[] packet) + public RTCPBye(ReadOnlySpan packet) { if (packet.Length < MIN_PACKET_SIZE) { @@ -81,15 +82,7 @@ public RTCPBye(byte[] packet) } Header = new RTCPHeader(packet); - - if (BitConverter.IsLittleEndian) - { - SSRC = NetConvert.DoReverseEndian(BitConverter.ToUInt32(packet, 4)); - } - else - { - SSRC = BitConverter.ToUInt32(packet, 4); - } + SSRC = BinaryPrimitives.ReadUInt32BigEndian(packet.Slice(4)); if (packet.Length > MIN_PACKET_SIZE) { @@ -97,7 +90,7 @@ public RTCPBye(byte[] packet) if (packet.Length - MIN_PACKET_SIZE - 1 >= reasonLength) { - Reason = Encoding.UTF8.GetString(packet, 9, reasonLength); + Reason = packet.Slice(9, reasonLength).ToString(Encoding.UTF8); } } } diff --git a/src/net/RTCP/RTCPCompoundPacket.cs b/src/net/RTCP/RTCPCompoundPacket.cs index 1ec7be7f1..8ae161c9f 100644 --- a/src/net/RTCP/RTCPCompoundPacket.cs +++ b/src/net/RTCP/RTCPCompoundPacket.cs @@ -60,7 +60,7 @@ public RTCPCompoundPacket(RTCPReceiverReport receiverReport, RTCPSDesReport sdes /// Creates a new RTCP compound packet from a serialised buffer. /// /// The serialised RTCP compound packet to parse. - public RTCPCompoundPacket(byte[] packet) + public RTCPCompoundPacket(ReadOnlySpan packet) { int offset = 0; while (offset < packet.Length) @@ -72,7 +72,7 @@ public RTCPCompoundPacket(byte[] packet) } else { - var buffer = packet.Skip(offset).ToArray(); + var buffer = packet.Slice(offset); // The payload type field is the second byte in the RTCP header. byte packetTypeID = buffer[1]; diff --git a/src/net/RTCP/RTCPFeedback.cs b/src/net/RTCP/RTCPFeedback.cs index f104329f9..5d9ff2b0f 100644 --- a/src/net/RTCP/RTCPFeedback.cs +++ b/src/net/RTCP/RTCPFeedback.cs @@ -27,6 +27,7 @@ //----------------------------------------------------------------------------- using System; +using System.Buffers.Binary; using Microsoft.Extensions.Logging; using SIPSorcery.Sys; @@ -146,21 +147,13 @@ public RTCPFeedback(uint senderSsrc, uint mediaSsrc, PSFBFeedbackTypesEnum feedb /// Create a new RTCP Report from a serialised byte array. /// /// The byte array holding the serialised feedback report. - public RTCPFeedback(byte[] packet) + public RTCPFeedback(ReadOnlySpan packet) { Header = new RTCPHeader(packet); int payloadIndex = RTCPHeader.HEADER_BYTES_LENGTH; - if (BitConverter.IsLittleEndian) - { - SenderSSRC = NetConvert.DoReverseEndian(BitConverter.ToUInt32(packet, payloadIndex)); - MediaSSRC = NetConvert.DoReverseEndian(BitConverter.ToUInt32(packet, payloadIndex + 4)); - } - else - { - SenderSSRC = BitConverter.ToUInt32(packet, payloadIndex); - MediaSSRC = BitConverter.ToUInt32(packet, payloadIndex + 4); - } + SenderSSRC = BinaryPrimitives.ReadUInt32BigEndian(packet.Slice(payloadIndex)); + MediaSSRC = BinaryPrimitives.ReadUInt32BigEndian(packet.Slice(payloadIndex + 4)); switch (Header) { @@ -170,16 +163,8 @@ public RTCPFeedback(byte[] packet) break; case var x when x.PacketType == RTCPReportTypesEnum.RTPFB: SENDER_PAYLOAD_SIZE = 12; - if (BitConverter.IsLittleEndian) - { - PID = NetConvert.DoReverseEndian(BitConverter.ToUInt16(packet, payloadIndex + 8)); - BLP = NetConvert.DoReverseEndian(BitConverter.ToUInt16(packet, payloadIndex + 10)); - } - else - { - PID = BitConverter.ToUInt16(packet, payloadIndex + 8); - BLP = BitConverter.ToUInt16(packet, payloadIndex + 10); - } + PID = BinaryPrimitives.ReadUInt16BigEndian(packet.Slice(payloadIndex + 8)); + BLP = BinaryPrimitives.ReadUInt16BigEndian(packet.Slice(payloadIndex + 10)); break; case var x when x.PacketType == RTCPReportTypesEnum.PSFB && x.PayloadFeedbackMessageType == PSFBFeedbackTypesEnum.PLI: @@ -202,7 +187,7 @@ public RTCPFeedback(byte[] packet) SENDER_PAYLOAD_SIZE = 8 + 12; // 8 bytes from (SenderSSRC + MediaSSRC) + extra 12 bytes from REMB Definition var currentCounter = payloadIndex + 8; - UniqueID = System.Text.ASCIIEncoding.ASCII.GetString(packet, currentCounter, 4); + UniqueID = packet.Slice(currentCounter, 4).ToString(System.Text.Encoding.ASCII); currentCounter += 4; if (string.Equals(UniqueID,"REMB", StringComparison.CurrentCultureIgnoreCase)) @@ -227,14 +212,7 @@ public RTCPFeedback(byte[] packet) currentCounter += 3; - if (BitConverter.IsLittleEndian) - { - FeedbackSSRC = NetConvert.DoReverseEndian(BitConverter.ToUInt32(packet, currentCounter)); - } - else - { - FeedbackSSRC = BitConverter.ToUInt32(packet, currentCounter); - } + FeedbackSSRC = BinaryPrimitives.ReadUInt32BigEndian(packet.Slice(currentCounter)); } break; diff --git a/src/net/RTCP/RTCPHeader.cs b/src/net/RTCP/RTCPHeader.cs index a6ccf244d..81aab080e 100644 --- a/src/net/RTCP/RTCPHeader.cs +++ b/src/net/RTCP/RTCPHeader.cs @@ -33,6 +33,7 @@ //----------------------------------------------------------------------------- using System; +using System.Buffers.Binary; using SIPSorcery.Sys; namespace SIPSorcery.Net @@ -132,24 +133,15 @@ public bool IsFeedbackReport() /// Extract and load the RTCP header from an RTCP packet. /// /// - public RTCPHeader(byte[] packet) + public RTCPHeader(ReadOnlySpan packet) { if (packet.Length < HEADER_BYTES_LENGTH) { throw new ApplicationException("The packet did not contain the minimum number of bytes for an RTCP header packet."); } - UInt16 firstWord = BitConverter.ToUInt16(packet, 0); - - if (BitConverter.IsLittleEndian) - { - firstWord = NetConvert.DoReverseEndian(firstWord); - Length = NetConvert.DoReverseEndian(BitConverter.ToUInt16(packet, 2)); - } - else - { - Length = BitConverter.ToUInt16(packet, 2); - } + UInt16 firstWord = BinaryPrimitives.ReadUInt16BigEndian(packet); + Length = BinaryPrimitives.ReadUInt16BigEndian(packet.Slice(2)); Version = Convert.ToInt32(firstWord >> 14); PaddingFlag = Convert.ToInt32((firstWord >> 13) & 0x1); diff --git a/src/net/RTCP/RTCPReceiverReport.cs b/src/net/RTCP/RTCPReceiverReport.cs index 2819c4dee..cdf23a3c3 100644 --- a/src/net/RTCP/RTCPReceiverReport.cs +++ b/src/net/RTCP/RTCPReceiverReport.cs @@ -45,6 +45,7 @@ //----------------------------------------------------------------------------- using System; +using System.Buffers.Binary; using System.Collections.Generic; using System.Linq; using SIPSorcery.Sys; @@ -76,7 +77,7 @@ public RTCPReceiverReport(uint ssrc, List receptionReport /// Create a new RTCP Receiver Report from a serialised byte array. /// /// The byte array holding the serialised receiver report. - public RTCPReceiverReport(byte[] packet) + public RTCPReceiverReport(ReadOnlySpan packet) { if (packet.Length < MIN_PACKET_SIZE) { @@ -85,20 +86,12 @@ public RTCPReceiverReport(byte[] packet) Header = new RTCPHeader(packet); ReceptionReports = new List(); - - if (BitConverter.IsLittleEndian) - { - SSRC = NetConvert.DoReverseEndian(BitConverter.ToUInt32(packet, 4)); - } - else - { - SSRC = BitConverter.ToUInt32(packet, 4); - } + SSRC = BinaryPrimitives.ReadUInt32BigEndian(packet.Slice(4)); int rrIndex = 8; for (int i = 0; i < Header.ReceptionReportCount; i++) { - var rr = new ReceptionReportSample(packet.Skip(rrIndex + i * ReceptionReportSample.PAYLOAD_SIZE).ToArray()); + var rr = new ReceptionReportSample(packet.Slice(rrIndex + i * ReceptionReportSample.PAYLOAD_SIZE)); ReceptionReports.Add(rr); } } diff --git a/src/net/RTCP/RTCPSdesReport.cs b/src/net/RTCP/RTCPSdesReport.cs index c90a7f6f1..b30e68ec0 100644 --- a/src/net/RTCP/RTCPSdesReport.cs +++ b/src/net/RTCP/RTCPSdesReport.cs @@ -51,6 +51,7 @@ //----------------------------------------------------------------------------- using System; +using System.Buffers.Binary; using System.Text; using SIPSorcery.Sys; @@ -100,7 +101,7 @@ public RTCPSDesReport(uint ssrc, string cname) /// Create a new RTCP SDES item from a serialised byte array. /// /// The byte array holding the SDES report. - public RTCPSDesReport(byte[] packet) + public RTCPSDesReport(ReadOnlySpan packet) { // if (packet.Length < MIN_PACKET_SIZE) // { @@ -119,14 +120,7 @@ public RTCPSDesReport(byte[] packet) if (packet.Length >= RTCPHeader.HEADER_BYTES_LENGTH+4) { - if (BitConverter.IsLittleEndian) - { - SSRC = NetConvert.DoReverseEndian(BitConverter.ToUInt32(packet, RTCPHeader.HEADER_BYTES_LENGTH)); - } - else - { - SSRC = BitConverter.ToUInt32(packet, RTCPHeader.HEADER_BYTES_LENGTH); - } + SSRC = BinaryPrimitives.ReadUInt32BigEndian(packet.Slice(RTCPHeader.HEADER_BYTES_LENGTH)); } if (packet.Length >= MIN_PACKET_SIZE) @@ -137,7 +131,7 @@ public RTCPSDesReport(byte[] packet) CNAME = string.Empty; return; } - CNAME = Encoding.UTF8.GetString(packet, 10, cnameLength); + CNAME = packet.Slice(10, cnameLength).ToString(Encoding.UTF8); } } diff --git a/src/net/RTCP/RTCPSenderReport.cs b/src/net/RTCP/RTCPSenderReport.cs index 0e5d8cb22..590eabd5f 100644 --- a/src/net/RTCP/RTCPSenderReport.cs +++ b/src/net/RTCP/RTCPSenderReport.cs @@ -45,6 +45,7 @@ //----------------------------------------------------------------------------- using System; +using System.Buffers.Binary; using System.Collections.Generic; using System.Linq; using SIPSorcery.Sys; @@ -91,7 +92,7 @@ public RTCPSenderReport(uint ssrc, ulong ntpTimestamp, uint rtpTimestamp, uint p /// Create a new RTCP Sender Report from a serialised byte array. /// /// The byte array holding the serialised sender report. - public RTCPSenderReport(byte[] packet) + public RTCPSenderReport(ReadOnlySpan packet) { if (packet.Length < MIN_PACKET_SIZE) { @@ -101,27 +102,16 @@ public RTCPSenderReport(byte[] packet) Header = new RTCPHeader(packet); ReceptionReports = new List(); - if (BitConverter.IsLittleEndian) - { - SSRC = NetConvert.DoReverseEndian(BitConverter.ToUInt32(packet, 4)); - NtpTimestamp = NetConvert.DoReverseEndian(BitConverter.ToUInt64(packet, 8)); - RtpTimestamp = NetConvert.DoReverseEndian(BitConverter.ToUInt32(packet, 16)); - PacketCount = NetConvert.DoReverseEndian(BitConverter.ToUInt32(packet, 20)); - OctetCount = NetConvert.DoReverseEndian(BitConverter.ToUInt32(packet, 24)); - } - else - { - SSRC = BitConverter.ToUInt32(packet, 4); - NtpTimestamp = BitConverter.ToUInt64(packet, 8); - RtpTimestamp = BitConverter.ToUInt32(packet, 16); - PacketCount = BitConverter.ToUInt32(packet, 20); - OctetCount = BitConverter.ToUInt32(packet, 24); - } + SSRC = BinaryPrimitives.ReadUInt32BigEndian(packet.Slice(4)); + NtpTimestamp = BinaryPrimitives.ReadUInt64BigEndian(packet.Slice(8)); + RtpTimestamp = BinaryPrimitives.ReadUInt32BigEndian(packet.Slice(16)); + PacketCount = BinaryPrimitives.ReadUInt32BigEndian(packet.Slice(20)); + OctetCount = BinaryPrimitives.ReadUInt32BigEndian(packet.Slice(24)); int rrIndex = 28; for (int i = 0; i < Header.ReceptionReportCount; i++) { - var rr = new ReceptionReportSample(packet.Skip(rrIndex + i * ReceptionReportSample.PAYLOAD_SIZE).ToArray()); + var rr = new ReceptionReportSample(packet.Slice(rrIndex + i * ReceptionReportSample.PAYLOAD_SIZE)); ReceptionReports.Add(rr); } } diff --git a/src/net/RTCP/RTCPSession.cs b/src/net/RTCP/RTCPSession.cs index 18af4171b..ebe2a20db 100644 --- a/src/net/RTCP/RTCPSession.cs +++ b/src/net/RTCP/RTCPSession.cs @@ -91,6 +91,11 @@ public class RTCPSession /// public DateTime LastActivityAt { get; private set; } = DateTime.MinValue; + /// + /// Time to wait before classifying the session as timed out due to inactivity. + /// + public TimeSpan NoActivityTimeout { get; private set; } = new(ticks: NO_ACTIVITY_TIMEOUT_MILLISECONDS * TimeSpan.TicksPerMillisecond); + /// /// Indicates whether the session is currently in a timed out state. This /// occurs if no RTP or RTCP packets have been received during an expected @@ -313,8 +318,8 @@ private void SendReportTimerCallback(Object stateInfo) { lock (m_rtcpReportTimer) { - if ((LastActivityAt != DateTime.MinValue && DateTime.Now.Subtract(LastActivityAt).TotalMilliseconds > NO_ACTIVITY_TIMEOUT_MILLISECONDS) || - (LastActivityAt == DateTime.MinValue && DateTime.Now.Subtract(CreatedAt).TotalMilliseconds > NO_ACTIVITY_TIMEOUT_MILLISECONDS)) + if ((LastActivityAt != DateTime.MinValue && DateTime.Now.Subtract(LastActivityAt) > NoActivityTimeout) || + (LastActivityAt == DateTime.MinValue && DateTime.Now.Subtract(CreatedAt) > NoActivityTimeout)) { if (!IsTimedOut) { diff --git a/src/net/RTCP/ReceptionReport.cs b/src/net/RTCP/ReceptionReport.cs index 0fbef3779..17ec2b040 100644 --- a/src/net/RTCP/ReceptionReport.cs +++ b/src/net/RTCP/ReceptionReport.cs @@ -33,6 +33,7 @@ //----------------------------------------------------------------------------- using System; +using System.Buffers.Binary; using SIPSorcery.Sys; namespace SIPSorcery.Net @@ -111,28 +112,16 @@ public ReceptionReportSample( DelaySinceLastSenderReport = delaySinceLastSR; } - public ReceptionReportSample(byte[] packet) + public ReceptionReportSample(ReadOnlySpan packet) { - if (BitConverter.IsLittleEndian) - { - SSRC = NetConvert.DoReverseEndian(BitConverter.ToUInt32(packet, 0)); - FractionLost = packet[4]; - PacketsLost = NetConvert.DoReverseEndian(BitConverter.ToInt32(new byte[] { 0x00, packet[5], packet[6], packet[7] }, 0)); - ExtendedHighestSequenceNumber = NetConvert.DoReverseEndian(BitConverter.ToUInt32(packet, 8)); - Jitter = NetConvert.DoReverseEndian(BitConverter.ToUInt32(packet, 12)); - LastSenderReportTimestamp = NetConvert.DoReverseEndian(BitConverter.ToUInt32(packet, 16)); - DelaySinceLastSenderReport = NetConvert.DoReverseEndian(BitConverter.ToUInt32(packet, 20)); - } - else - { - SSRC = BitConverter.ToUInt32(packet, 4); - FractionLost = packet[4]; - PacketsLost = BitConverter.ToInt32(new byte[] { 0x00, packet[5], packet[6], packet[7] }, 0); - ExtendedHighestSequenceNumber = BitConverter.ToUInt32(packet, 8); - Jitter = BitConverter.ToUInt32(packet, 12); - LastSenderReportTimestamp = BitConverter.ToUInt32(packet, 16); - DelaySinceLastSenderReport = BitConverter.ToUInt32(packet, 20); - } + SSRC = BinaryPrimitives.ReadUInt32BigEndian(packet.Slice(0, 4)); + FractionLost = packet[4]; + Span packetsLost = stackalloc byte[] { 0x00, packet[5], packet[6], packet[7] }; + PacketsLost = BinaryPrimitives.ReadInt32BigEndian(packetsLost); + ExtendedHighestSequenceNumber = BinaryPrimitives.ReadUInt32BigEndian(packet.Slice(8)); + Jitter = BinaryPrimitives.ReadUInt32BigEndian(packet.Slice(12)); + LastSenderReportTimestamp = BinaryPrimitives.ReadUInt32BigEndian(packet.Slice(16)); + DelaySinceLastSenderReport = BinaryPrimitives.ReadUInt32BigEndian(packet.Slice(20)); } /// diff --git a/src/net/RTP/MediaStream.cs b/src/net/RTP/MediaStream.cs index 40ab12b9e..fd9a4ad87 100644 --- a/src/net/RTP/MediaStream.cs +++ b/src/net/RTP/MediaStream.cs @@ -139,6 +139,11 @@ public bool IsClosed } _isClosed = value; + if (value) + { + RtcpSession.OnTimeout -= RaiseOnTimeoutByIndex; + } + //Clear previous buffer ClearPendingPackages(); @@ -286,7 +291,7 @@ public Boolean IsSecurityContextReady() return (SecureContext != null); } - private (bool, byte[]) UnprotectBuffer(byte[] buffer) + private bool UnprotectBuffer(Span buffer, out ReadOnlySpan result) { if (SecureContext != null) { @@ -294,21 +299,23 @@ public Boolean IsSecurityContextReady() if (res == 0) { - return (true, buffer.Take(outBufLen).ToArray()); + result = buffer.Slice(0, outBufLen); + return true; } else { logger.LogWarning($"SRTP unprotect failed for {MediaType}, result {res}."); } } - return (false, buffer); + result = buffer; + return false; } - public bool EnsureBufferUnprotected(byte[] buf, RTPHeader header, out RTPPacket packet) + public bool EnsureBufferUnprotected(Span buf, RTPHeader header, out RTPPacket packet) { if (RtpSessionConfig.IsSecure || RtpSessionConfig.UseSdpCryptoNegotiation) { - var (succeeded, newBuffer) = UnprotectBuffer(buf); + var succeeded = UnprotectBuffer(buf, out var newBuffer); if (!succeeded) { packet = null; @@ -489,7 +496,7 @@ 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 } else { - rtpChannel.Send(RTPChannelSocketsEnum.RTP, DestinationEndPoint, rtpBuffer.Take(outBufLen).ToArray()); + rtpChannel.Send(RTPChannelSocketsEnum.RTP, DestinationEndPoint, rtpBuffer.AsSpan(0, outBufLen)); } } m_lastRtpTimestamp = timestamp; @@ -630,7 +637,7 @@ private bool SendRtcpReport(byte[] reportBuffer) } else { - rtpChannel.Send(sendOnSocket, ControlDestinationEndPoint, sendBuffer.Take(outBufLen).ToArray()); + rtpChannel.Send(sendOnSocket, ControlDestinationEndPoint, sendBuffer.AsSpan(0, outBufLen)); } } } @@ -672,7 +679,7 @@ public void SendRtcpFeedback(RTCPFeedback feedback) #region RECEIVE PACKET - public void OnReceiveRTPPacket(RTPHeader hdr, int localPort, IPEndPoint remoteEndPoint, byte[] buffer, VideoStream videoStream = null) + public void OnReceiveRTPPacket(RTPHeader hdr, int localPort, IPEndPoint remoteEndPoint, Span buffer, VideoStream videoStream = null) { RTPPacket rtpPacket; if (RemoteRtpEventPayloadID != 0 && hdr.PayloadType == RemoteRtpEventPayloadID) @@ -823,7 +830,7 @@ protected virtual void ClearPendingPackages() // Cache pending packages to use it later to prevent missing frames // when DTLS was not completed yet as a Server but already completed as a client - protected virtual bool AddPendingPackage(RTPHeader hdr, int localPort, IPEndPoint remoteEndPoint, byte[] buffer, VideoStream videoStream = null) + protected virtual bool AddPendingPackage(RTPHeader hdr, int localPort, IPEndPoint remoteEndPoint, ReadOnlySpan buffer, VideoStream videoStream = null) { const int MAX_PENDING_PACKAGES_BUFFER_SIZE = 32; @@ -836,7 +843,7 @@ protected virtual bool AddPendingPackage(RTPHeader hdr, int localPort, IPEndPoin { _pendingPackagesBuffer.RemoveAt(0); } - _pendingPackagesBuffer.Add(new PendingPackages(hdr, localPort, remoteEndPoint, buffer, videoStream)); + _pendingPackagesBuffer.Add(new PendingPackages(hdr, localPort, remoteEndPoint, buffer.ToArray(), videoStream)); } return true; } @@ -994,6 +1001,8 @@ public void ProcessHeaderExtensions(RTPHeader header, IPEndPoint remoteEndPoint) }); } + public override string ToString() => $"{MediaType}[{Index}]"; + public MediaStream(RtpSessionConfig config, int index) { RtpSessionConfig = config; diff --git a/src/net/RTP/RTPChannel.cs b/src/net/RTP/RTPChannel.cs index bb6e5abfa..18c0f631b 100644 --- a/src/net/RTP/RTPChannel.cs +++ b/src/net/RTP/RTPChannel.cs @@ -18,14 +18,16 @@ //----------------------------------------------------------------------------- using System; +using System.Buffers; using System.Net; using System.Net.Sockets; +using System.Threading.Tasks; using Microsoft.Extensions.Logging; using SIPSorcery.Sys; namespace SIPSorcery.Net { - public delegate void PacketReceivedDelegate(UdpReceiver receiver, int localPort, IPEndPoint remoteEndPoint, byte[] packet); + public delegate void PacketReceivedDelegate(UdpReceiver receiver, int localPort, IPEndPoint remoteEndPoint, ReadOnlySpan packet); /// /// A basic UDP socket manager. The RTP channel may need both an RTP and Control socket. This class encapsulates @@ -50,11 +52,11 @@ public class UdpReceiver protected static ILogger logger = Log.Logger; protected readonly Socket m_socket; - protected byte[] m_recvBuffer; + protected readonly byte[] m_recvBuffer; protected bool m_isClosed; protected bool m_isRunningReceive; - protected IPEndPoint m_localEndPoint; - protected AddressFamily m_addressFamily; + protected readonly IPEndPoint m_localEndPoint; + protected readonly AddressFamily m_addressFamily; public virtual bool IsClosed { @@ -104,8 +106,11 @@ public UdpReceiver(Socket socket, int mtu = RECEIVE_BUFFER_SIZE) m_localEndPoint = m_socket.LocalEndPoint as IPEndPoint; m_recvBuffer = new byte[mtu]; m_addressFamily = m_socket.LocalEndPoint.AddressFamily; + endReceiveFrom = EndReceiveFrom; } + static readonly IPEndPoint IPv4AnyEndPoint = new(IPAddress.Any, 0); + static readonly IPEndPoint IPv6AnyEndPoint = new(IPAddress.IPv6Any, 0); /// /// Starts the receive. This method returns immediately. An event will be fired in the corresponding "End" event to /// return any data received. @@ -125,8 +130,38 @@ public virtual void BeginReceiveFrom() try { m_isRunningReceive = true; - EndPoint recvEndPoint = m_addressFamily == AddressFamily.InterNetwork ? new IPEndPoint(IPAddress.Any, 0) : new IPEndPoint(IPAddress.IPv6Any, 0); - m_socket.BeginReceiveFrom(m_recvBuffer, 0, m_recvBuffer.Length, SocketFlags.None, ref recvEndPoint, EndReceiveFrom, null); + EndPoint recvEndPoint = m_addressFamily == AddressFamily.InterNetwork ? IPv4AnyEndPoint : IPv6AnyEndPoint; +#if FALSE // NET6_0_OR_GREATER bandwidth test falters at some point if this is enabled + var recive = m_socket.ReceiveFromAsync(m_recvBuffer.AsMemory(), SocketFlags.None, recvEndPoint); + if (recive.IsCompleted) + { + try + { + var result = recive.GetAwaiter().GetResult(); + EndReceiveFrom(result); + } + catch (Exception excp) + { + EndReceiveFrom(excp); + } + } + else + { + recive.AsTask().ContinueWith(t => + { + try + { + EndReceiveFrom(t.GetAwaiter().GetResult()); + } + catch (Exception excp) + { + EndReceiveFrom(excp); + } + }); + } +#else + m_socket.BeginReceiveFrom(m_recvBuffer, 0, m_recvBuffer.Length, SocketFlags.None, ref recvEndPoint, endReceiveFrom, null); +#endif } // Thrown when socket is closed. Can be safely ignored. // This exception can be thrown in response to an ICMP packet. The problem is the ICMP packet can be a false positive. @@ -153,6 +188,31 @@ public virtual void BeginReceiveFrom() } } +#if NET6_0_OR_GREATER + protected virtual void EndReceiveFrom(SocketReceiveFromResult result) + { + OnBytesRead(result.RemoteEndPoint, result.ReceivedBytes); + + try + { + Drain(); + } + catch (Exception error) + { + EndReceiveFrom(error); + } + finally + { + m_isRunningReceive = false; + if (!m_isClosed) + { + BeginReceiveFrom(); + } + } + } +#endif + + readonly AsyncCallback endReceiveFrom; /// /// Handler for end of the begin receive call. /// @@ -164,58 +224,35 @@ protected virtual void EndReceiveFrom(IAsyncResult ar) // When socket is closed the object will be disposed of in the middle of a receive. if (!m_isClosed) { - EndPoint remoteEP = m_addressFamily == AddressFamily.InterNetwork ? new IPEndPoint(IPAddress.Any, 0) : new IPEndPoint(IPAddress.IPv6Any, 0); + EndPoint remoteEP = m_addressFamily == AddressFamily.InterNetwork ? IPv4AnyEndPoint : IPv6AnyEndPoint; int bytesRead = m_socket.EndReceiveFrom(ar, ref remoteEP); - - if (bytesRead > 0) - { - // During experiments IPPacketInformation wasn't getting set on Linux. Without it the local IP address - // cannot be determined when a listener was bound to IPAddress.Any (or IPv6 equivalent). If the caller - // is relying on getting the local IP address on Linux then something may fail. - //if (packetInfo != null && packetInfo.Address != null) - //{ - // localEndPoint = new IPEndPoint(packetInfo.Address, localEndPoint.Port); - //} - - byte[] packetBuffer = new byte[bytesRead]; - // TODO: When .NET Framework support is dropped switch to using a slice instead of a copy. - Buffer.BlockCopy(m_recvBuffer, 0, packetBuffer, 0, bytesRead); - CallOnPacketReceivedCallback(m_localEndPoint.Port, remoteEP as IPEndPoint, packetBuffer); - } + OnBytesRead(remoteEP, bytesRead); } - // If there is still data available it should be read now. This is more efficient than calling - // BeginReceiveFrom which will incur the overhead of creating the callback and then immediately firing it. - // It also avoids the situation where if the application cannot keep up with the network then BeginReceiveFrom - // will be called synchronously (if data is available it calls the callback method immediately) which can - // create a very nasty stack. - if (!m_isClosed && m_socket.Available > 0) + Drain(); + } + catch (Exception error) + { + EndReceiveFrom(error); + } + finally + { + m_isRunningReceive = false; + if (!m_isClosed) { - while (!m_isClosed && m_socket.Available > 0) - { - EndPoint remoteEP = m_addressFamily == AddressFamily.InterNetwork ? new IPEndPoint(IPAddress.Any, 0) : new IPEndPoint(IPAddress.IPv6Any, 0); - int bytesReadSync = m_socket.ReceiveFrom(m_recvBuffer, 0, m_recvBuffer.Length, SocketFlags.None, ref remoteEP); - - if (bytesReadSync > 0) - { - byte[] packetBufferSync = new byte[bytesReadSync]; - // TODO: When .NET Framework support is dropped switch to using a slice instead of a copy. - Buffer.BlockCopy(m_recvBuffer, 0, packetBufferSync, 0, bytesReadSync); - CallOnPacketReceivedCallback(m_localEndPoint.Port, remoteEP as IPEndPoint, packetBufferSync); - } - else - { - break; - } - } + BeginReceiveFrom(); } } - catch (SocketException resetSockExcp) when (resetSockExcp.SocketErrorCode == SocketError.ConnectionReset) + } + + private void EndReceiveFrom(Exception excp) + { + switch(excp) { + case SocketException resetSockExcp when (resetSockExcp.SocketErrorCode == SocketError.ConnectionReset): // Thrown when close is called on a socket from this end. Safe to ignore. - } - catch (SocketException sockExcp) - { + break; + case SocketException sockExcp: // Socket errors do not trigger a close. The reason being that there are genuine situations that can cause them during // normal RTP operation. For example: // - the RTP connection may start sending before the remote socket starts listening, @@ -226,20 +263,68 @@ protected virtual void EndReceiveFrom(IAsyncResult ar) // BeginReceive before any packets have been exchanged. This means it's not safe to close if BeginReceive gets an ICMP // error since the remote party may not have initialised their socket yet. logger.LogWarning(sockExcp, $"SocketException UdpReceiver.EndReceiveFrom ({sockExcp.SocketErrorCode}). {sockExcp.Message}"); - } - catch (ObjectDisposedException) // Thrown when socket is closed. Can be safely ignored. - { } - catch (Exception excp) - { + break; + case ObjectDisposedException: // Thrown when socket is closed. Can be safely ignored. + break; + case AggregateException: + foreach (var innerExcp in (excp as AggregateException).InnerExceptions) + { + EndReceiveFrom(innerExcp); + } + break; + default: logger.LogError($"Exception UdpReceiver.EndReceiveFrom. {excp}"); Close(excp.Message); + break; } - finally + } + + private void OnBytesRead(EndPoint remoteEP, int bytesRead) + { + if (bytesRead > 0) { - m_isRunningReceive = false; - if (!m_isClosed) + // During experiments IPPacketInformation wasn't getting set on Linux. Without it the local IP address + // cannot be determined when a listener was bound to IPAddress.Any (or IPv6 equivalent). If the caller + // is relying on getting the local IP address on Linux then something may fail. + //if (packetInfo != null && packetInfo.Address != null) + //{ + // localEndPoint = new IPEndPoint(packetInfo.Address, localEndPoint.Port); + //} + + if (bytesRead < 256 * 1024) { - BeginReceiveFrom(); + Span packetBuffer = stackalloc byte[bytesRead]; + CallOnPacketReceivedCallback(m_localEndPoint.Port, remoteEP as IPEndPoint, m_recvBuffer.AsSpan().Slice(0, bytesRead)); + } + else + { + logger.LogCritical("UDP packet received was larger than 256KB and was ignored."); + } + } + } + + private void Drain() + { + // If there is still data available it should be read now. This is more efficient than calling + // BeginReceiveFrom which will incur the overhead of creating the callback and then immediately firing it. + // It also avoids the situation where if the application cannot keep up with the network then BeginReceiveFrom + // will be called synchronously (if data is available it calls the callback method immediately) which can + // create a very nasty stack. + while (!m_isClosed && m_socket.Available > 0) + { + EndPoint remoteEP = m_addressFamily == AddressFamily.InterNetwork ? IPv4AnyEndPoint : IPv6AnyEndPoint; + int bytesReadSync = m_socket.ReceiveFrom(m_recvBuffer, 0, m_recvBuffer.Length, SocketFlags.None, ref remoteEP); + + if (bytesReadSync > 0) + { + byte[] packetBufferSync = new byte[bytesReadSync]; + // TODO: When .NET Framework support is dropped switch to using a slice instead of a copy. + Buffer.BlockCopy(m_recvBuffer, 0, packetBufferSync, 0, bytesReadSync); + CallOnPacketReceivedCallback(m_localEndPoint.Port, remoteEP as IPEndPoint, packetBufferSync); + } + else + { + break; } } } @@ -258,7 +343,7 @@ public virtual void Close(string reason) } } - protected virtual void CallOnPacketReceivedCallback(int localPort, IPEndPoint remoteEndPoint, byte[] packet) + protected virtual void CallOnPacketReceivedCallback(int localPort, IPEndPoint remoteEndPoint, ReadOnlySpan packet) { OnPacketReceived?.Invoke(this, localPort, remoteEndPoint, packet); } @@ -277,6 +362,7 @@ public enum RTPChannelSocketsEnum /// public class RTPChannel : IDisposable { + private static readonly bool isMono = Type.GetType("Mono.Runtime") != null; private static ILogger logger = Log.Logger; protected UdpReceiver m_rtpReceiver; private Socket m_controlSocket; @@ -343,8 +429,10 @@ public bool IsClosed get { return m_isClosed; } } - public event Action OnRTPDataReceived; - public event Action OnControlDataReceived; + public delegate void DataReceivedDelegate(int localPort, IPEndPoint remoteEndPoint, ReadOnlySpan buffer); + + public event DataReceivedDelegate OnRTPDataReceived; + public event DataReceivedDelegate OnControlDataReceived; public event Action OnClosed; /// @@ -375,6 +463,7 @@ public RTPChannel(bool createControlSocket, IPAddress bindAddress, int bindPort RTPPort = RTPLocalEndPoint.Port; ControlLocalEndPoint = (m_controlSocket != null) ? m_controlSocket.LocalEndPoint as IPEndPoint : null; ControlPort = (m_controlSocket != null) ? ControlLocalEndPoint.Port : 0; + endSendTo = EndSendTo; } /// @@ -460,9 +549,10 @@ public void Close(string reason) /// The socket to send on. Can be the RTP or Control socket. /// The destination end point to send to. /// The data to send. + /// If supplied and given an exception returns true, disables further error processing. /// The result of initiating the send. This result does not reflect anything about /// whether the remote party received the packet or not. - public virtual SocketError Send(RTPChannelSocketsEnum sendOn, IPEndPoint dstEndPoint, byte[] buffer) + public virtual SocketError Send(RTPChannelSocketsEnum sendOn, IPEndPoint dstEndPoint, ReadOnlySpan buffer, Func? onFailure = null) { if (m_isClosed) { @@ -504,7 +594,7 @@ public virtual SocketError Send(RTPChannelSocketsEnum sendOn, IPEndPoint dstEndP } //Prevent Send to IPV4 while socket is IPV6 (Mono Error) - if (dstEndPoint.AddressFamily == AddressFamily.InterNetwork && sendSocket.AddressFamily != dstEndPoint.AddressFamily) + if (isMono && dstEndPoint.AddressFamily == AddressFamily.InterNetwork && sendSocket.AddressFamily != dstEndPoint.AddressFamily) { dstEndPoint = new IPEndPoint(dstEndPoint.Address.MapToIPv6(), dstEndPoint.Port); } @@ -515,7 +605,49 @@ public virtual SocketError Send(RTPChannelSocketsEnum sendOn, IPEndPoint dstEndP m_rtpReceiver.BeginReceiveFrom(); } - sendSocket.BeginSendTo(buffer, 0, buffer.Length, SocketFlags.None, dstEndPoint, EndSendTo, sendSocket); +#if NET6_0_OR_GREATER + var tmp = ArrayPool.Shared.Rent(buffer.Length); + buffer.CopyTo(tmp); + ValueTask send; + try + { + send = sendSocket.SendToAsync(tmp.AsMemory(0, buffer.Length), SocketFlags.None, dstEndPoint); + } + catch + { + ArrayPool.Shared.Return(tmp); + throw; + } + if (send.IsCompleted) + { + try + { + send.GetAwaiter().GetResult(); + } + catch (Exception excp) + { + EndSendTo(excp, dstEndPoint, sendSocket, onFailure); + } + finally + { + ArrayPool.Shared.Return(tmp); + } + } + else + { + send.AsTask().ContinueWith(t => + { + ArrayPool.Shared.Return(tmp); + if (t.IsFaulted) + { + EndSendTo(t.Exception, dstEndPoint, sendSocket, onFailure); + } + }); + } + +#else + sendSocket.BeginSendTo(buffer.ToArray(), 0, buffer.Length, SocketFlags.None, dstEndPoint, endSendTo, sendSocket); +#endif return SocketError.Success; } catch (ObjectDisposedException) // Thrown when socket is closed. Can be safely ignored. @@ -534,31 +666,60 @@ public virtual SocketError Send(RTPChannelSocketsEnum sendOn, IPEndPoint dstEndP } } + readonly AsyncCallback endSendTo; /// /// Ends an async send on one of the channel's sockets. /// /// The async result to complete the send with. private void EndSendTo(IAsyncResult ar) { + Socket sendSocket = (Socket)ar.AsyncState; try { - Socket sendSocket = (Socket)ar.AsyncState; int bytesSent = sendSocket.EndSendTo(ar); } - catch (SocketException sockExcp) + catch (Exception excp) { + EndSendTo(excp, socket: sendSocket); + } + } + + private void EndSendTo(Exception exception, IPEndPoint? endPoint = null, Socket? socket = null, Func? onFailure = null) + { + string kind = socket switch + { + null => "Unknown", + { } m when m == m_controlSocket => "Control", + _ => "RTP", + }; + if (onFailure?.Invoke(exception) == true) + { + return; + } + switch (exception) + { + case SocketException sockExcp: // Socket errors do not trigger a close. The reason being that there are genuine situations that can cause them during // normal RTP operation. For example: // - the RTP connection may start sending before the remote socket starts listening, // - an on hold, transfer, etc. operation can change the RTP end point which could result in socket errors from the old // or new socket during the transition. - logger.LogWarning(sockExcp, $"SocketException RTPChannel EndSendTo ({sockExcp.ErrorCode}). {sockExcp.Message}"); - } - catch (ObjectDisposedException) // Thrown when socket is closed. Can be safely ignored. - { } - catch (Exception excp) - { - logger.LogError($"Exception RTPChannel EndSendTo. {excp.Message}"); + logger.LogWarning("SocketException RTPChannel EndSendTo {Kind} {EndPoint} ({ErrorCode}). {Message}", kind, endPoint, sockExcp.ErrorCode, sockExcp.Message); + break; + + case ObjectDisposedException: // Thrown when socket is closed. Can be safely ignored. + break; + + case AggregateException aggExcp: + foreach (var innerExcp in aggExcp.InnerExceptions) + { + EndSendTo(innerExcp, endPoint, socket, onFailure); + } + break; + + default: + logger.LogError("Exception RTPChannel EndSendTo. {Kind} {EndPoint} {Message}", kind, endPoint, exception.Message); + break; } } @@ -569,9 +730,9 @@ private void EndSendTo(IAsyncResult ar) /// The local port it was received on. /// The remote end point of the sender. /// The raw packet received (note this may not be RTP if other protocols are being multiplexed). - protected virtual void OnRTPPacketReceived(UdpReceiver receiver, int localPort, IPEndPoint remoteEndPoint, byte[] packet) + protected virtual void OnRTPPacketReceived(UdpReceiver receiver, int localPort, IPEndPoint remoteEndPoint, ReadOnlySpan packet) { - if (packet?.Length > 0) + if (packet.Length > 0) { LastRtpDestination = remoteEndPoint; OnRTPDataReceived?.Invoke(localPort, remoteEndPoint, packet); @@ -585,7 +746,7 @@ protected virtual void OnRTPPacketReceived(UdpReceiver receiver, int localPort, /// The local port it was received on. /// The remote end point of the sender. /// The raw packet received which should always be an RTCP packet. - private void OnControlPacketReceived(UdpReceiver receiver, int localPort, IPEndPoint remoteEndPoint, byte[] packet) + private void OnControlPacketReceived(UdpReceiver receiver, int localPort, IPEndPoint remoteEndPoint, ReadOnlySpan packet) { LastControlDestination = remoteEndPoint; OnControlDataReceived?.Invoke(localPort, remoteEndPoint, packet); diff --git a/src/net/RTP/RTPHeader.cs b/src/net/RTP/RTPHeader.cs index bb56a9d7a..afb1ab0ba 100644 --- a/src/net/RTP/RTPHeader.cs +++ b/src/net/RTP/RTPHeader.cs @@ -64,29 +64,17 @@ public RTPHeader() /// Extract and load the RTP header from an RTP packet. /// /// - public RTPHeader(byte[] packet) + public RTPHeader(ReadOnlySpan packet) { if (packet.Length < MIN_HEADER_LEN) { throw new ApplicationException("The packet did not contain the minimum number of bytes for an RTP header packet."); } - UInt16 firstWord = BitConverter.ToUInt16(packet, 0); - - if (BitConverter.IsLittleEndian) - { - firstWord = NetConvert.DoReverseEndian(firstWord); - SequenceNumber = NetConvert.DoReverseEndian(BitConverter.ToUInt16(packet, 2)); - Timestamp = NetConvert.DoReverseEndian(BitConverter.ToUInt32(packet, 4)); - SyncSource = NetConvert.DoReverseEndian(BitConverter.ToUInt32(packet, 8)); - } - else - { - SequenceNumber = BitConverter.ToUInt16(packet, 2); - Timestamp = BitConverter.ToUInt32(packet, 4); - SyncSource = BitConverter.ToUInt32(packet, 8); - } - + UInt16 firstWord = BinaryPrimitives.ReadUInt16BigEndian(packet); + SequenceNumber = BinaryPrimitives.ReadUInt16BigEndian(packet.Slice(2)); + Timestamp = BinaryPrimitives.ReadUInt32BigEndian(packet.Slice(4)); + SyncSource = BinaryPrimitives.ReadUInt32BigEndian(packet.Slice(8)); Version = firstWord >> 14; PaddingFlag = (firstWord >> 13) & 0x1; @@ -101,25 +89,16 @@ public RTPHeader(byte[] packet) if (HeaderExtensionFlag == 1 && (packet.Length >= (headerAndCSRCLength + 4))) { - if (BitConverter.IsLittleEndian) - { - ExtensionProfile = NetConvert.DoReverseEndian(BitConverter.ToUInt16(packet, 12 + 4 * CSRCCount)); - headerExtensionLength += 2; - ExtensionLength = NetConvert.DoReverseEndian(BitConverter.ToUInt16(packet, 14 + 4 * CSRCCount)); - headerExtensionLength += 2 + ExtensionLength * 4; - } - else - { - ExtensionProfile = BitConverter.ToUInt16(packet, 12 + 4 * CSRCCount); - headerExtensionLength += 2; - ExtensionLength = BitConverter.ToUInt16(packet, 14 + 4 * CSRCCount); - headerExtensionLength += 2 + ExtensionLength * 4; - } + + ExtensionProfile = BinaryPrimitives.ReadUInt16BigEndian(packet.Slice(12 + 4 * CSRCCount)); + headerExtensionLength += 2; + ExtensionLength = BinaryPrimitives.ReadUInt16BigEndian(packet.Slice(14 + 4 * CSRCCount)); + headerExtensionLength += 2 + ExtensionLength * 4; if (ExtensionLength > 0 && packet.Length >= (headerAndCSRCLength + 4 + ExtensionLength * 4)) { ExtensionPayload = new byte[ExtensionLength * 4]; - Buffer.BlockCopy(packet, headerAndCSRCLength + 4, ExtensionPayload, 0, ExtensionLength * 4); + packet.Slice(headerAndCSRCLength + 4, ExtensionPayload.Length).CopyTo(ExtensionPayload); } } diff --git a/src/net/RTP/RTPPacket.cs b/src/net/RTP/RTPPacket.cs index 4da11bf74..5cdd3caac 100644 --- a/src/net/RTP/RTPPacket.cs +++ b/src/net/RTP/RTPPacket.cs @@ -34,11 +34,11 @@ public RTPPacket(int payloadSize) Payload = new byte[payloadSize]; } - public RTPPacket(byte[] packet) + public RTPPacket(ReadOnlySpan packet) { Header = new RTPHeader(packet); Payload = new byte[Header.PayloadSize]; - Array.Copy(packet, Header.Length, Payload, 0, Payload.Length); + packet.Slice(Header.Length, Header.PayloadSize).CopyTo(Payload); } public byte[] GetBytes() diff --git a/src/net/RTP/RTPSession.cs b/src/net/RTP/RTPSession.cs index a9da214e1..2469d3122 100644 --- a/src/net/RTP/RTPSession.cs +++ b/src/net/RTP/RTPSession.cs @@ -20,6 +20,7 @@ //----------------------------------------------------------------------------- using System; +using System.Buffers.Binary; using System.Collections.Generic; using System.Linq; using System.Net; @@ -33,7 +34,7 @@ namespace SIPSorcery.Net { - public delegate int ProtectRtpPacket(byte[] payload, int length, out int outputBufferLength); + public delegate int ProtectRtpPacket(Span payload, int length, out int outputBufferLength); public enum SetDescriptionResultEnum { @@ -333,17 +334,19 @@ public bool IsSecureContextReady() /// public int MaxReconstructedVideoFrameSize { get => VideoStream.MaxReconstructedVideoFrameSize; set => VideoStream.MaxReconstructedVideoFrameSize = value; } + Once isClosed; /// /// Indicates whether the session has been closed. Once a session is closed it cannot /// be restarted. /// - public bool IsClosed { get; private set; } + public bool IsClosed => isClosed.HasOccurred; + Once isStarted; /// /// Indicates whether the session has been started. Starting a session tells the RTP /// socket to start receiving, /// - public bool IsStarted { get; private set; } + public bool IsStarted => isStarted.HasOccurred; /// /// Indicates whether this session is using audio. @@ -1923,11 +1926,8 @@ protected List GetMediaStreams() /// public virtual Task Start() { - if (!IsStarted) + if (isStarted.TryMarkOccurred()) { - IsStarted = true; - - foreach (var audioStream in AudioStreamList) { if (audioStream.HasAudio && audioStream.RtcpSession != null && audioStream.LocalTrack.StreamStatus != MediaStreamStatusEnum.Inactive) @@ -1997,11 +1997,8 @@ public Task SendDtmfEvent(RTPEvent rtpEvent, CancellationToken cancellationToken /// public virtual void Close(string reason) { - if (!IsClosed) + if (isClosed.TryMarkOccurred()) { - IsClosed = true; - - foreach (var audioStream in AudioStreamList) { if (audioStream != null) @@ -2043,7 +2040,7 @@ public virtual void Close(string reason) } } - protected void OnReceive(int localPort, IPEndPoint remoteEndPoint, byte[] buffer) + protected void OnReceive(int localPort, IPEndPoint remoteEndPoint, ReadOnlySpan buffer) { if (remoteEndPoint.Address.IsIPv4MappedToIPv6) { @@ -2053,7 +2050,7 @@ protected void OnReceive(int localPort, IPEndPoint remoteEndPoint, byte[] buffer } // Quick sanity check on whether this is not an RTP or RTCP packet. - if (buffer?.Length > RTPHeader.MIN_HEADER_LEN && buffer[0] >= 128 && buffer[0] <= 191) + if (buffer.Length > RTPHeader.MIN_HEADER_LEN && buffer[0] >= 128 && buffer[0] <= 191) { if ((rtpSessionConfig.IsSecure || rtpSessionConfig.UseSdpCryptoNegotiation) && !IsSecureContextReady()) { @@ -2082,23 +2079,17 @@ protected void OnReceive(int localPort, IPEndPoint remoteEndPoint, byte[] buffer } } - private void OnReceiveRTCPPacket(int localPort, IPEndPoint remoteEndPoint, byte[] buffer) + private void OnReceiveRTCPPacket(int localPort, IPEndPoint remoteEndPoint, ReadOnlySpan bufferRO) { //logger.LogDebug($"RTCP packet received from {remoteEndPoint} {buffer.HexStr()}"); #region RTCP packet. + Span buffer = stackalloc byte[bufferRO.Length]; + bufferRO.CopyTo(buffer); // Get the SSRC in order to be able to figure out which media type // This will let us choose the apropriate unprotect methods - uint ssrc; - if (BitConverter.IsLittleEndian) - { - ssrc = NetConvert.DoReverseEndian(BitConverter.ToUInt32(buffer, 4)); - } - else - { - ssrc = BitConverter.ToUInt32(buffer, 4); - } + uint ssrc = BinaryPrimitives.ReadUInt32BigEndian(buffer.Slice(4)); MediaStream mediaStream = GetMediaStream(ssrc); if (mediaStream != null) @@ -2114,7 +2105,7 @@ private void OnReceiveRTCPPacket(int localPort, IPEndPoint remoteEndPoint, byte[ } else { - buffer = buffer.Take(outBufLen).ToArray(); + buffer = buffer.Slice(0, outBufLen); } } } @@ -2124,7 +2115,6 @@ private void OnReceiveRTCPPacket(int localPort, IPEndPoint remoteEndPoint, byte[ } var rtcpPkt = new RTCPCompoundPacket(buffer); - if (rtcpPkt != null) { mediaStream = GetMediaStream(rtcpPkt); if (rtcpPkt.Bye != null) @@ -2182,19 +2172,15 @@ private void OnReceiveRTCPPacket(int localPort, IPEndPoint remoteEndPoint, byte[ } } } - else - { - logger.LogWarning("Failed to parse RTCP compound report."); - } #endregion } - private void OnReceiveRTPPacket(int localPort, IPEndPoint remoteEndPoint, byte[] buffer) + private void OnReceiveRTPPacket(int localPort, IPEndPoint remoteEndPoint, ReadOnlySpan bufferRO) { if (!IsClosed) { - var hdr = new RTPHeader(buffer); + var hdr = new RTPHeader(bufferRO); MediaStream mediaStream = GetMediaStream(hdr.SyncSource); @@ -2212,10 +2198,14 @@ private void OnReceiveRTPPacket(int localPort, IPEndPoint remoteEndPoint, byte[] hdr.ReceivedTime = DateTime.Now; if (mediaStream.MediaType == SDPMediaTypesEnum.audio) { + Span buffer = stackalloc byte[bufferRO.Length]; + bufferRO.CopyTo(buffer); mediaStream.OnReceiveRTPPacket(hdr, localPort, remoteEndPoint, buffer, null); } else if (mediaStream.MediaType == SDPMediaTypesEnum.video) { + Span buffer = stackalloc byte[bufferRO.Length]; + bufferRO.CopyTo(buffer); mediaStream.OnReceiveRTPPacket(hdr, localPort, remoteEndPoint, buffer, mediaStream as VideoStream); } } @@ -2315,6 +2305,7 @@ private MediaStream GetMediaStream(uint ssrc) { if (!HasVideo) { + logger.LogDebug("An RTP packet with SSRC {ssrc} force matched to the only audio stream.", ssrc); return AudioStream; } } @@ -2322,10 +2313,17 @@ private MediaStream GetMediaStream(uint ssrc) { if (HasVideo) { + logger.LogDebug("An RTP packet with SSRC {ssrc} force matched to the only video stream.", ssrc); return VideoStream; } } + if (ssrc == 1 || ssrc == RTCP_RR_NOSTREAM_SSRC) + { + logger.LogDebug("An RTP packet with SSRC {ssrc} force matched to the primary stream {Stream}.", ssrc, PrimaryStream); + return PrimaryStream; + } + return null; } diff --git a/src/net/SCTP/Chunks/SctpChunk.cs b/src/net/SCTP/Chunks/SctpChunk.cs index 6f324a6c4..b96ee0338 100644 --- a/src/net/SCTP/Chunks/SctpChunk.cs +++ b/src/net/SCTP/Chunks/SctpChunk.cs @@ -132,7 +132,7 @@ public SctpChunkType? KnownType { get { - if (Enum.IsDefined(typeof(SctpChunkType), ChunkType)) + if (((SctpChunkType)ChunkType).IsDefined()) { return (SctpChunkType)ChunkType; } @@ -147,7 +147,7 @@ public SctpChunkType? KnownType /// Records any unrecognised parameters received from the remote peer and are classified /// as needing to be reported. These can be sent back to the remote peer if needed. /// - public List UnrecognizedPeerParameters = new List(); + public IReadOnlyList UnrecognizedPeerParameters = Array.Empty(); public SctpChunk(SctpChunkType chunkType, byte chunkFlags = 0x00) { @@ -185,11 +185,11 @@ public virtual ushort GetChunkLength(bool padded) /// The buffer holding the serialised chunk. /// The position in the buffer that indicates the start of the chunk. /// The chunk length value. - public ushort ParseFirstWord(byte[] buffer, int posn) + public ushort ParseFirstWord(ReadOnlySpan buffer, int posn) { ChunkType = buffer[posn]; ChunkFlags = buffer[posn + 1]; - ushort chunkLength = NetConvert.ParseUInt16(buffer, posn + 2); + ushort chunkLength = NetConvert.ParseUInt16(buffer.Slice(posn + 2)); if (chunkLength > 0 && buffer.Length < posn + chunkLength) { @@ -208,12 +208,14 @@ public ushort ParseFirstWord(byte[] buffer, int posn) /// /// The buffer to write the chunk header to. /// The position in the buffer to write at. - /// The padded length of this chunk. - protected void WriteChunkHeader(byte[] buffer, int posn) + /// Unpadded length of this chunk. + protected ushort WriteChunkHeader(Span buffer, int posn) { buffer[posn] = ChunkType; buffer[posn + 1] = ChunkFlags; - NetConvert.ToBuffer(GetChunkLength(false), buffer, posn + 2); + ushort length = GetChunkLength(false); + NetConvert.ToBuffer(length, buffer, posn + 2); + return length; } /// @@ -225,18 +227,25 @@ protected void WriteChunkHeader(byte[] buffer, int posn) /// must have the required space already allocated. /// The position in the buffer to write to. /// The number of bytes, including padding, written to the buffer. - public virtual ushort WriteTo(byte[] buffer, int posn) + public virtual ushort WriteTo(Span buffer, int posn) { WriteChunkHeader(buffer, posn); if (ChunkValue?.Length > 0) { - Buffer.BlockCopy(ChunkValue, 0, buffer, posn + SCTP_CHUNK_HEADER_LENGTH, ChunkValue.Length); + ChunkValue.CopyTo(buffer.Slice(posn + SCTP_CHUNK_HEADER_LENGTH)); } return GetChunkLength(true); } + internal SctpChunkView View() + { + byte[] bytes = new byte[GetChunkLength(padded: true)]; + ushort written = WriteTo(bytes, 0); + return new SctpChunkView(bytes.AsSpan().Slice(0, written)); + } + /// /// Handler for processing an unrecognised chunk parameter. /// @@ -255,12 +264,14 @@ public bool GotUnrecognisedParameter(SctpTlvChunkParameter chunkParameter) break; case SctpUnrecognisedParameterActions.StopAndReport: stop = true; - UnrecognizedPeerParameters.Add(chunkParameter); + UnrecognizedPeerParameters = UnrecognizedPeerParameters as List ?? []; + ((List)UnrecognizedPeerParameters).Add(chunkParameter); break; case SctpUnrecognisedParameterActions.Skip: break; case SctpUnrecognisedParameterActions.SkipAndReport: - UnrecognizedPeerParameters.Add(chunkParameter); + UnrecognizedPeerParameters = UnrecognizedPeerParameters as List ?? []; + ((List)UnrecognizedPeerParameters).Add(chunkParameter); break; } @@ -276,14 +287,14 @@ public bool GotUnrecognisedParameter(SctpTlvChunkParameter chunkParameter) /// The buffer holding the serialised chunk. /// The position to start parsing at. /// An SCTP chunk instance. - public static SctpChunk ParseBaseChunk(byte[] buffer, int posn) + public static SctpChunk ParseBaseChunk(ReadOnlySpan buffer, int posn) { var chunk = new SctpChunk(); ushort chunkLength = chunk.ParseFirstWord(buffer, posn); if (chunkLength > SCTP_CHUNK_HEADER_LENGTH) { chunk.ChunkValue = new byte[chunkLength - SCTP_CHUNK_HEADER_LENGTH]; - Buffer.BlockCopy(buffer, posn + SCTP_CHUNK_HEADER_LENGTH, chunk.ChunkValue, 0, chunk.ChunkValue.Length); + buffer.Slice(posn + SCTP_CHUNK_HEADER_LENGTH, chunk.ChunkValue.Length).CopyTo(chunk.ChunkValue); } return chunk; @@ -298,7 +309,7 @@ public static SctpChunk ParseBaseChunk(byte[] buffer, int posn) /// parameters from. /// The length of the TLV chunk parameters in the buffer. /// A list of chunk parameters. Can be empty. - public static IEnumerable GetParameters(byte[] buffer, int posn, int length) + public static void GetParameters(ReadOnlySpan buffer, int posn, int length, Func onParam) { int paramPosn = posn; @@ -306,7 +317,10 @@ public static IEnumerable GetParameters(byte[] buffer, in { var chunkParam = SctpTlvChunkParameter.ParseTlvParameter(buffer, paramPosn); - yield return chunkParam; + if (!onParam(chunkParam)) + { + break; + } paramPosn += chunkParam.GetParameterLength(true); } @@ -318,7 +332,7 @@ public static IEnumerable GetParameters(byte[] buffer, in /// The buffer holding the serialised chunk. /// The position to start parsing at. /// An SCTP chunk instance. - public static SctpChunk Parse(byte[] buffer, int posn) + public static SctpChunk Parse(ReadOnlySpan buffer, int posn) { if (buffer.Length < posn + SCTP_CHUNK_HEADER_LENGTH) { @@ -327,7 +341,7 @@ public static SctpChunk Parse(byte[] buffer, int posn) byte chunkType = buffer[posn]; - if (Enum.IsDefined(typeof(SctpChunkType), chunkType)) + if (((SctpChunkType)chunkType).IsDefined()) { switch ((SctpChunkType)chunkType) { @@ -369,7 +383,7 @@ public static SctpChunk Parse(byte[] buffer, int posn) /// The start position of the serialised chunk. /// If true the length field will be padded to a 4 byte boundary. /// The padded length of the serialised chunk. - public static uint GetChunkLengthFromHeader(byte[] buffer, int posn, bool padded) + public static uint GetChunkLengthFromHeader(ReadOnlySpan buffer, int posn, bool padded) { ushort len = NetConvert.ParseUInt16(buffer, posn + 2); return (padded) ? SctpPadding.PadTo4ByteBoundary(len) : len; @@ -389,11 +403,10 @@ public static SctpUnrecognisedChunkActions GetUnrecognisedChunkAction(ushort chu /// The buffer containing the chunk. /// The position in the buffer that the unrecognised chunk starts. /// A new buffer containing a copy of the chunk. - public static byte[] CopyUnrecognisedChunk(byte[] buffer, int posn) + public static ReadOnlySpan CopyUnrecognisedChunk(ReadOnlySpan buffer, int posn) { - byte[] unrecognised = new byte[SctpChunk.GetChunkLengthFromHeader(buffer, posn, true)]; - Buffer.BlockCopy(buffer, posn, unrecognised, 0, unrecognised.Length); - return unrecognised; + uint length = GetChunkLengthFromHeader(buffer, posn, true); + return buffer.Slice(posn, checked((int)length)); } } } diff --git a/src/net/SCTP/Chunks/SctpDataChunk.cs b/src/net/SCTP/Chunks/SctpDataChunk.cs index 2b2a15a48..79a3565f4 100644 --- a/src/net/SCTP/Chunks/SctpDataChunk.cs +++ b/src/net/SCTP/Chunks/SctpDataChunk.cs @@ -18,11 +18,13 @@ //----------------------------------------------------------------------------- using System; +using System.Buffers; +using System.Diagnostics; using SIPSorcery.Sys; namespace SIPSorcery.Net { - public class SctpDataChunk : SctpChunk + public class SctpDataChunk : SctpChunk, IDisposable { /// /// An empty data chunk. The main use is to indicate a DATA chunk has @@ -78,13 +80,22 @@ public class SctpDataChunk : SctpChunk /// public uint PPID; + BorrowedArray userData; /// /// This is the payload user data. /// - public byte[] UserData; + public Span UserData => userData; + public int UserDataLength => userData.Length; + internal struct Timestamp + { + readonly long ticks; + Timestamp(long ticks) => this.ticks = ticks; + public readonly double Milliseconds => ticks / (Stopwatch.Frequency / 1000.0); + public static Timestamp Now => new(Stopwatch.GetTimestamp()); + } // These properties are used by the data sender. - internal DateTime LastSentAt; + internal Timestamp LastSentAt; internal int SendCount; private SctpDataChunk() @@ -113,9 +124,9 @@ public SctpDataChunk( ushort streamID, ushort seqnum, uint ppid, - byte[] data) : base(SctpChunkType.DATA) + ReadOnlySpan data) : base(SctpChunkType.DATA) { - if (data == null || data.Length == 0) + if (data.Length == 0) { throw new ArgumentNullException("data", "The SctpDataChunk data parameter cannot be empty."); } @@ -127,7 +138,7 @@ public SctpDataChunk( StreamID = streamID; StreamSeqNum = seqnum; PPID = ppid; - UserData = data; + userData.Set(data); ChunkFlags = (byte)( (Unordered ? 0x04 : 0x0) + @@ -143,7 +154,7 @@ public SctpDataChunk( public override ushort GetChunkLength(bool padded) { ushort len = SCTP_CHUNK_HEADER_LENGTH + FIXED_PARAMETERS_LENGTH; - len += (ushort)(UserData != null ? UserData.Length : 0); + len += (ushort)userData.Length; return (padded) ? SctpPadding.PadTo4ByteBoundary(len) : len; } @@ -154,9 +165,9 @@ public override ushort GetChunkLength(bool padded) /// must have the required space already allocated. /// The position in the buffer to write to. /// The number of bytes, including padding, written to the buffer. - public override ushort WriteTo(byte[] buffer, int posn) + public override ushort WriteTo(Span buffer, int posn) { - WriteChunkHeader(buffer, posn); + ushort length = WriteChunkHeader(buffer, posn); // Write fixed parameters. int startPosn = posn + SCTP_CHUNK_HEADER_LENGTH; @@ -168,17 +179,14 @@ public override ushort WriteTo(byte[] buffer, int posn) int userDataPosn = startPosn + FIXED_PARAMETERS_LENGTH; - if (UserData != null) - { - Buffer.BlockCopy(UserData, 0, buffer, userDataPosn, UserData.Length); - } + userData.DataMayBeEmpty.CopyTo(buffer.Slice(userDataPosn)); - return GetChunkLength(true); + return SctpPadding.PadTo4ByteBoundary(length); } public bool IsEmpty() { - return UserData == null; + return userData.IsNull(); } /// @@ -186,7 +194,7 @@ public bool IsEmpty() /// /// The buffer holding the serialised chunk. /// The position to start parsing at. - public static SctpDataChunk ParseChunk(byte[] buffer, int posn) + public static SctpDataChunk ParseChunk(ReadOnlySpan buffer, int posn, ArrayPool pool = null) { var dataChunk = new SctpDataChunk(); ushort chunkLen = dataChunk.ParseFirstWord(buffer, posn); @@ -212,11 +220,32 @@ public static SctpDataChunk ParseChunk(byte[] buffer, int posn) if (userDataLen > 0) { - dataChunk.UserData = new byte[userDataLen]; - Buffer.BlockCopy(buffer, userDataPosn, dataChunk.UserData, 0, dataChunk.UserData.Length); + if (pool != null) + { + dataChunk.userData.Set(buffer.Slice(userDataPosn, userDataLen), pool); + } + else + { + dataChunk.userData.Set(new byte[userDataLen]); + buffer.Slice(userDataPosn, userDataLen).CopyTo(dataChunk.UserData); + } } return dataChunk; } + + public void Dispose() + { + userData.Dispose(); + } + + [Flags] + public enum Flags: byte + { + None = 0x00, + Unordered = 0x04, + Beginning = 0x02, + Ending = 0x01 + } } } diff --git a/src/net/SCTP/Chunks/SctpErrorCauses.cs b/src/net/SCTP/Chunks/SctpErrorCauses.cs index 8159d1984..54c86ba2c 100644 --- a/src/net/SCTP/Chunks/SctpErrorCauses.cs +++ b/src/net/SCTP/Chunks/SctpErrorCauses.cs @@ -19,6 +19,7 @@ //----------------------------------------------------------------------------- using System; +using System.Buffers.Binary; using System.Collections.Generic; using System.Text; using SIPSorcery.Sys; @@ -49,7 +50,7 @@ public interface ISctpErrorCause { SctpErrorCauseCode CauseCode { get; } ushort GetErrorCauseLength(bool padded); - int WriteTo(byte[] buffer, int posn); + int WriteTo(Span buffer, int posn); } /// @@ -87,7 +88,7 @@ public SctpCauseOnlyError(SctpErrorCauseCode causeCode) public ushort GetErrorCauseLength(bool padded) => ERROR_CAUSE_LENGTH; - public int WriteTo(byte[] buffer, int posn) + public int WriteTo(Span buffer, int posn) { NetConvert.ToBuffer((ushort)CauseCode, buffer, posn); NetConvert.ToBuffer(ERROR_CAUSE_LENGTH, buffer, posn + 2); @@ -115,7 +116,7 @@ public struct SctpErrorInvalidStreamIdentifier : ISctpErrorCause public ushort GetErrorCauseLength(bool padded) => ERROR_CAUSE_LENGTH; - public int WriteTo(byte[] buffer, int posn) + public int WriteTo(Span buffer, int posn) { NetConvert.ToBuffer((ushort)CauseCode, buffer, posn); NetConvert.ToBuffer(ERROR_CAUSE_LENGTH, buffer, posn + 2); @@ -143,7 +144,7 @@ public ushort GetErrorCauseLength(bool padded) return padded ? SctpPadding.PadTo4ByteBoundary(len) : len; } - public int WriteTo(byte[] buffer, int posn) + public int WriteTo(Span buffer, int posn) { var len = GetErrorCauseLength(true); NetConvert.ToBuffer((ushort)CauseCode, buffer, posn); @@ -180,7 +181,7 @@ public struct SctpErrorStaleCookieError : ISctpErrorCause public ushort GetErrorCauseLength(bool padded) => ERROR_CAUSE_LENGTH; - public int WriteTo(byte[] buffer, int posn) + public int WriteTo(Span buffer, int posn) { NetConvert.ToBuffer((ushort)CauseCode, buffer, posn); NetConvert.ToBuffer(ERROR_CAUSE_LENGTH, buffer, posn + 2); @@ -214,15 +215,12 @@ public ushort GetErrorCauseLength(bool padded) return padded ? SctpPadding.PadTo4ByteBoundary(len) : len; } - public int WriteTo(byte[] buffer, int posn) + public int WriteTo(Span buffer, int posn) { var len = GetErrorCauseLength(true); NetConvert.ToBuffer((ushort)CauseCode, buffer, posn); NetConvert.ToBuffer(len, buffer, posn + 2); - if (UnresolvableAddress != null) - { - Buffer.BlockCopy(UnresolvableAddress, 0, buffer, posn + 4, UnresolvableAddress.Length); - } + UnresolvableAddress?.CopyTo(buffer.Slice(posn + 4)); return len; } } @@ -251,15 +249,12 @@ public ushort GetErrorCauseLength(bool padded) return padded ? SctpPadding.PadTo4ByteBoundary(len) : len; } - public int WriteTo(byte[] buffer, int posn) + public int WriteTo(Span buffer, int posn) { var len = GetErrorCauseLength(true); NetConvert.ToBuffer((ushort)CauseCode, buffer, posn); NetConvert.ToBuffer(len, buffer, posn + 2); - if (UnrecognizedChunk != null) - { - Buffer.BlockCopy(UnrecognizedChunk, 0, buffer, posn + 4, UnrecognizedChunk.Length); - } + UnrecognizedChunk?.CopyTo(buffer.Slice(posn + 4)); return len; } } @@ -292,15 +287,12 @@ public ushort GetErrorCauseLength(bool padded) return padded ? SctpPadding.PadTo4ByteBoundary(len) : len; } - public int WriteTo(byte[] buffer, int posn) + public int WriteTo(Span buffer, int posn) { var len = GetErrorCauseLength(true); NetConvert.ToBuffer((ushort)CauseCode, buffer, posn); NetConvert.ToBuffer(len, buffer, posn + 2); - if (UnrecognizedParameters != null) - { - Buffer.BlockCopy(UnrecognizedParameters, 0, buffer, posn + 4, UnrecognizedParameters.Length); - } + UnrecognizedParameters?.CopyTo(buffer.Slice(posn + 4)); return len; } } @@ -326,7 +318,7 @@ public struct SctpErrorNoUserData : ISctpErrorCause public ushort GetErrorCauseLength(bool padded) => ERROR_CAUSE_LENGTH; - public int WriteTo(byte[] buffer, int posn) + public int WriteTo(Span buffer, int posn) { NetConvert.ToBuffer((ushort)CauseCode, buffer, posn); NetConvert.ToBuffer(ERROR_CAUSE_LENGTH, buffer, posn + 2); @@ -360,15 +352,12 @@ public ushort GetErrorCauseLength(bool padded) return padded ? SctpPadding.PadTo4ByteBoundary(len) : len; } - public int WriteTo(byte[] buffer, int posn) + public int WriteTo(Span buffer, int posn) { var len = GetErrorCauseLength(true); NetConvert.ToBuffer((ushort)CauseCode, buffer, posn); NetConvert.ToBuffer(len, buffer, posn + 2); - if (NewAddressTLVs != null) - { - Buffer.BlockCopy(NewAddressTLVs, 0, buffer, posn + 4, NewAddressTLVs.Length); - } + NewAddressTLVs?.CopyTo(buffer.Slice(posn + 4)); return len; } } @@ -395,7 +384,7 @@ public ushort GetErrorCauseLength(bool padded) return padded ? SctpPadding.PadTo4ByteBoundary(len) : len; } - public int WriteTo(byte[] buffer, int posn) + public int WriteTo(Span buffer, int posn) { var len = GetErrorCauseLength(true); NetConvert.ToBuffer((ushort)CauseCode, buffer, posn); @@ -403,7 +392,7 @@ public int WriteTo(byte[] buffer, int posn) if (!string.IsNullOrEmpty(AbortReason)) { var reasonBuffer = Encoding.UTF8.GetBytes(AbortReason); - Buffer.BlockCopy(reasonBuffer, 0, buffer, posn + 4, reasonBuffer.Length); + reasonBuffer.CopyTo(buffer.Slice(posn + 4)); } return len; } @@ -432,7 +421,7 @@ public ushort GetErrorCauseLength(bool padded) return padded ? SctpPadding.PadTo4ByteBoundary(len) : len; } - public int WriteTo(byte[] buffer, int posn) + public int WriteTo(Span buffer, int posn) { var len = GetErrorCauseLength(true); NetConvert.ToBuffer((ushort)CauseCode, buffer, posn); @@ -440,7 +429,7 @@ public int WriteTo(byte[] buffer, int posn) if (!string.IsNullOrEmpty(AdditionalInformation)) { var reasonBuffer = Encoding.UTF8.GetBytes(AdditionalInformation); - Buffer.BlockCopy(reasonBuffer, 0, buffer, posn + 4, reasonBuffer.Length); + reasonBuffer.CopyTo(buffer.Slice(posn + 4)); } return len; } diff --git a/src/net/SCTP/Chunks/SctpErrorChunk.cs b/src/net/SCTP/Chunks/SctpErrorChunk.cs index c79a5ae62..819b92439 100644 --- a/src/net/SCTP/Chunks/SctpErrorChunk.cs +++ b/src/net/SCTP/Chunks/SctpErrorChunk.cs @@ -107,7 +107,7 @@ public override ushort GetChunkLength(bool padded) /// must have the required space already allocated. /// The position in the buffer to write to. /// The number of bytes, including padding, written to the buffer. - public override ushort WriteTo(byte[] buffer, int posn) + public override ushort WriteTo(Span buffer, int posn) { WriteChunkHeader(buffer, posn); if (ErrorCauses != null && ErrorCauses.Count > 0) @@ -126,7 +126,7 @@ public override ushort WriteTo(byte[] buffer, int posn) /// /// The buffer holding the serialised chunk. /// The position to start parsing at. - public static SctpErrorChunk ParseChunk(byte[] buffer, int posn, bool isAbort) + public static SctpErrorChunk ParseChunk(ReadOnlySpan buffer, int posn, bool isAbort) { var errorChunk = (isAbort) ? new SctpAbortChunk(false) : new SctpErrorChunk(); ushort chunkLen = errorChunk.ParseFirstWord(buffer, posn); @@ -138,7 +138,7 @@ public static SctpErrorChunk ParseChunk(byte[] buffer, int posn, bool isAbort) { bool stopProcessing = false; - foreach (var varParam in GetParameters(buffer, paramPosn, paramsBufferLength)) + GetParameters(buffer, paramPosn, paramsBufferLength, varParam => { switch (varParam.ParameterType) { @@ -220,9 +220,10 @@ public static SctpErrorChunk ParseChunk(byte[] buffer, int posn, bool isAbort) { logger.LogWarning($"SCTP unrecognised parameter {varParam.ParameterType} for chunk type {SctpChunkType.ERROR} " + "indicated no further chunks should be processed."); - break; + return false; } - } + return true; + }); } return errorChunk; diff --git a/src/net/SCTP/Chunks/SctpInitChunk.cs b/src/net/SCTP/Chunks/SctpInitChunk.cs index 9c6983e05..219bfb90a 100644 --- a/src/net/SCTP/Chunks/SctpInitChunk.cs +++ b/src/net/SCTP/Chunks/SctpInitChunk.cs @@ -286,7 +286,7 @@ public override ushort GetChunkLength(bool padded) /// must have the required space already allocated. /// The position in the buffer to write to. /// The number of bytes, including padding, written to the buffer. - public override ushort WriteTo(byte[] buffer, int posn) + public override ushort WriteTo(Span buffer, int posn) { WriteChunkHeader(buffer, posn); @@ -319,7 +319,7 @@ public override ushort WriteTo(byte[] buffer, int posn) /// /// The buffer holding the serialised chunk. /// The position to start parsing at. - public static SctpInitChunk ParseChunk(byte[] buffer, int posn) + public static SctpInitChunk ParseChunk(ReadOnlySpan buffer, int posn) { var initChunk = new SctpInitChunk(); ushort chunkLen = initChunk.ParseFirstWord(buffer, posn); @@ -339,7 +339,7 @@ public static SctpInitChunk ParseChunk(byte[] buffer, int posn) { bool stopProcessing = false; - foreach (var varParam in GetParameters(buffer, paramPosn, paramsBufferLength)) + GetParameters(buffer, paramPosn, paramsBufferLength, varParam => { switch (varParam.ParameterType) { @@ -399,9 +399,10 @@ public static SctpInitChunk ParseChunk(byte[] buffer, int posn) { logger.LogWarning($"SCTP unrecognised parameter {varParam.ParameterType} for chunk type {initChunk.KnownType} " + "indicated no further chunks should be processed."); - break; + return false; } - } + return true; + }); } return initChunk; diff --git a/src/net/SCTP/Chunks/SctpSackChunk.cs b/src/net/SCTP/Chunks/SctpSackChunk.cs index 313170db6..d7c26e1d7 100644 --- a/src/net/SCTP/Chunks/SctpSackChunk.cs +++ b/src/net/SCTP/Chunks/SctpSackChunk.cs @@ -17,8 +17,10 @@ // BSD 3-Clause "New" or "Revised" License, see included LICENSE.md file. //----------------------------------------------------------------------------- -using System.Collections.Generic; +using System; using SIPSorcery.Sys; +using Small.Collections; +using TypeNum; namespace SIPSorcery.Net { @@ -31,8 +33,8 @@ namespace SIPSorcery.Net public class SctpSackChunk : SctpChunk { public const int FIXED_PARAMETERS_LENGTH = 12; - private const int GAP_REPORT_LENGTH = 4; - private const int DUPLICATE_TSN_LENGTH = 4; + internal const int GAP_REPORT_LENGTH = 4; + internal const int DUPLICATE_TSN_LENGTH = 4; /// /// This parameter contains the TSN of the last chunk received in @@ -50,13 +52,13 @@ public class SctpSackChunk : SctpChunk /// The gap ACK blocks. Each entry represents a gap in the forward out of order /// TSNs received. /// - public List GapAckBlocks = new List(); + public SmallList, SctpTsnGapBlock> GapAckBlocks = new (); /// /// Indicates the number of times a TSN was received in duplicate /// since the last SACK was sent. /// - public List DuplicateTSN = new List(); + public SmallList, uint> DuplicateTSN = new(); private SctpSackChunk() : base(SctpChunkType.SACK) { } @@ -95,7 +97,7 @@ public override ushort GetChunkLength(bool padded) /// must have the required space already allocated. /// The position in the buffer to write to. /// The number of bytes, including padding, written to the buffer. - public override ushort WriteTo(byte[] buffer, int posn) + public override ushort WriteTo(Span buffer, int posn) { WriteChunkHeader(buffer, posn); @@ -129,7 +131,7 @@ public override ushort WriteTo(byte[] buffer, int posn) /// /// The buffer holding the serialised chunk. /// The position to start parsing at. - public static SctpSackChunk ParseChunk(byte[] buffer, int posn) + public static SctpSackChunk ParseChunk(ReadOnlySpan buffer, int posn) { var sackChunk = new SctpSackChunk(); ushort chunkLen = sackChunk.ParseFirstWord(buffer, posn); diff --git a/src/net/SCTP/Chunks/SctpShutdownChunk.cs b/src/net/SCTP/Chunks/SctpShutdownChunk.cs index 61cdb0a31..954cc3980 100644 --- a/src/net/SCTP/Chunks/SctpShutdownChunk.cs +++ b/src/net/SCTP/Chunks/SctpShutdownChunk.cs @@ -17,6 +17,7 @@ // BSD 3-Clause "New" or "Revised" License, see included LICENSE.md file. //----------------------------------------------------------------------------- +using System; using SIPSorcery.Sys; namespace SIPSorcery.Net @@ -67,7 +68,7 @@ public override ushort GetChunkLength(bool padded) /// must have the required space already allocated. /// The position in the buffer to write to. /// The number of bytes, including padding, written to the buffer. - public override ushort WriteTo(byte[] buffer, int posn) + public override ushort WriteTo(Span buffer, int posn) { WriteChunkHeader(buffer, posn); NetConvert.ToBuffer(CumulativeTsnAck.GetValueOrDefault(), buffer, posn + SCTP_CHUNK_HEADER_LENGTH); @@ -79,7 +80,7 @@ public override ushort WriteTo(byte[] buffer, int posn) /// /// The buffer holding the serialised chunk. /// The position to start parsing at. - public static SctpShutdownChunk ParseChunk(byte[] buffer, int posn) + public static SctpShutdownChunk ParseChunk(ReadOnlySpan buffer, int posn) { var shutdownChunk = new SctpShutdownChunk(); shutdownChunk.CumulativeTsnAck = NetConvert.ParseUInt32(buffer, posn + SCTP_CHUNK_HEADER_LENGTH); diff --git a/src/net/SCTP/Chunks/SctpTlvChunkParameter.cs b/src/net/SCTP/Chunks/SctpTlvChunkParameter.cs index 12f5e8f3c..7e0d1c46f 100644 --- a/src/net/SCTP/Chunks/SctpTlvChunkParameter.cs +++ b/src/net/SCTP/Chunks/SctpTlvChunkParameter.cs @@ -157,7 +157,7 @@ public virtual ushort GetParameterLength(bool padded) /// /// The buffer to write the chunk parameter header to. /// The position in the buffer to write at. - protected void WriteParameterHeader(byte[] buffer, int posn) + protected void WriteParameterHeader(Span buffer, int posn) { NetConvert.ToBuffer(ParameterType, buffer, posn); NetConvert.ToBuffer(GetParameterLength(false), buffer, posn + 2); @@ -172,13 +172,13 @@ protected void WriteParameterHeader(byte[] buffer, int posn) /// must have the required space already allocated. /// The position in the buffer to write to. /// The number of bytes, including padding, written to the buffer. - public virtual int WriteTo(byte[] buffer, int posn) + public virtual int WriteTo(Span buffer, int posn) { WriteParameterHeader(buffer, posn); if (ParameterValue?.Length > 0) { - Buffer.BlockCopy(ParameterValue, 0, buffer, posn + SCTP_PARAMETER_HEADER_LENGTH, ParameterValue.Length); + ParameterValue.CopyTo(buffer.Slice(posn + SCTP_PARAMETER_HEADER_LENGTH)); } return GetParameterLength(true); @@ -201,16 +201,29 @@ public byte[] GetBytes() /// /// The buffer holding the serialised chunk parameter. /// The position in the buffer that indicates the start of the chunk parameter. - public ushort ParseFirstWord(byte[] buffer, int posn) + public ushort ParseFirstWord(ReadOnlySpan buffer, int posn) { - ParameterType = NetConvert.ParseUInt16(buffer, posn); - ushort paramLen = NetConvert.ParseUInt16(buffer, posn + 2); + ushort len = ParseFirstWord(buffer.Slice(posn), out ushort type); + ParameterType = type; + return len; + } + + /// + /// The first 32 bits of all chunk parameters represent the type and length. This method + /// parses those fields and sets them on the current instance. + /// + /// The buffer holding the serialised chunk parameter. + /// The position in the buffer that indicates the start of the chunk parameter. + public static ushort ParseFirstWord(ReadOnlySpan buffer, out ushort type) + { + type = NetConvert.ParseUInt16(buffer, 0); + ushort paramLen = NetConvert.ParseUInt16(buffer, 2); - if (paramLen > 0 && buffer.Length < posn + paramLen) + if (paramLen > 0 && buffer.Length < paramLen) { // The buffer was not big enough to supply the specified chunk parameter. int bytesRequired = paramLen; - int bytesAvailable = buffer.Length - posn; + int bytesAvailable = buffer.Length; throw new ApplicationException($"The SCTP chunk parameter buffer was too short. " + $"Required {bytesRequired} bytes but only {bytesAvailable} available."); } @@ -224,7 +237,7 @@ public ushort ParseFirstWord(byte[] buffer, int posn) /// The buffer holding the serialised TLV chunk parameter. /// The position to start parsing at. /// An SCTP TLV chunk parameter instance. - public static SctpTlvChunkParameter ParseTlvParameter(byte[] buffer, int posn) + public static SctpTlvChunkParameter ParseTlvParameter(ReadOnlySpan buffer, int posn) { if (buffer.Length < posn + SCTP_PARAMETER_HEADER_LENGTH) { @@ -236,8 +249,7 @@ public static SctpTlvChunkParameter ParseTlvParameter(byte[] buffer, int posn) if (paramLen > SCTP_PARAMETER_HEADER_LENGTH) { tlvParam.ParameterValue = new byte[paramLen - SCTP_PARAMETER_HEADER_LENGTH]; - Buffer.BlockCopy(buffer, posn + SCTP_PARAMETER_HEADER_LENGTH, tlvParam.ParameterValue, - 0, tlvParam.ParameterValue.Length); + buffer.Slice(posn + SCTP_PARAMETER_HEADER_LENGTH, tlvParam.ParameterValue.Length).CopyTo(tlvParam.ParameterValue); } return tlvParam; } diff --git a/src/net/SCTP/SctpAssociation.cs b/src/net/SCTP/SctpAssociation.cs index e823cc776..4ed479675 100644 --- a/src/net/SCTP/SctpAssociation.cs +++ b/src/net/SCTP/SctpAssociation.cs @@ -366,38 +366,42 @@ internal void InitRemoteProperties( /// SCTP Association State Diagram: /// https://tools.ietf.org/html/rfc4960#section-4 /// - internal void OnPacketReceived(SctpPacket packet) + internal void OnPacketReceived(SctpPacketView packet) { if (_wasAborted) { - logger.LogWarning($"SCTP packet received but association has been aborted, ignoring."); + logger.LogWarning("SCTP packet received but association has been aborted, ignoring."); } else if (packet.Header.VerificationTag != VerificationTag) { - logger.LogWarning($"SCTP packet dropped due to wrong verification tag, expected " + - $"{VerificationTag} got {packet.Header.VerificationTag}."); + logger.LogWarning("SCTP packet dropped due to wrong verification tag, expected {Expected} got {Actual}.", + VerificationTag, packet.Header.VerificationTag); } else if (!_sctpTransport.IsPortAgnostic && packet.Header.DestinationPort != _sctpSourcePort) { - logger.LogWarning($"SCTP packet dropped due to wrong SCTP destination port, expected " + - $"{_sctpSourcePort} got {packet.Header.DestinationPort}."); + logger.LogWarning("SCTP packet dropped due to wrong SCTP destination port, expected {Expected} got {Actual}.", + _sctpSourcePort, packet.Header.DestinationPort); } else if (!_sctpTransport.IsPortAgnostic && packet.Header.SourcePort != _sctpDestinationPort) { - logger.LogWarning($"SCTP packet dropped due to wrong SCTP source port, expected " + - $"{_sctpDestinationPort} got {packet.Header.SourcePort}."); + logger.LogWarning("SCTP packet dropped due to wrong SCTP source port, expected {Expected} got {Actual}.", + _sctpDestinationPort, packet.Header.SourcePort); } else { - foreach (var chunk in packet.Chunks) + for (int chunkIndex = 0; chunkIndex < packet.ChunkCount; chunkIndex++) { - var chunkType = (SctpChunkType)chunk.ChunkType; + var chunk = packet[chunkIndex]; + var chunkType = chunk.Type; switch (chunkType) { case SctpChunkType.ABORT: - string abortReason = (chunk as SctpAbortChunk).GetAbortReason(); - logger.LogWarning($"SCTP packet ABORT chunk received from remote party, reason {abortReason}."); + var abortChunk = (SctpAbortChunk)chunk.AsChunk(); + string abortReason = abortChunk.GetAbortReason(); + var logLevel = abortChunk.ErrorCauses.Any(cause => cause.CauseCode == SctpErrorCauseCode.UserInitiatedAbort) + ? LogLevel.Debug : LogLevel.Warning; + logger.LogWarning("SCTP packet ABORT chunk received from remote party, reason {Message}.", abortReason); _wasAborted = true; OnAbortReceived?.Invoke(abortReason); break; @@ -423,19 +427,20 @@ internal void OnPacketReceived(SctpPacket packet) break; case SctpChunkType.DATA: - var dataChunk = chunk as SctpDataChunk; + var dataChunk = chunk; - if (dataChunk.UserData == null || dataChunk.UserData.Length == 0) + if (dataChunk.UserData.Length == 0) { // Fatal condition: // - If an endpoint receives a DATA chunk with no user data (i.e., the // Length field is set to 16), it MUST send an ABORT with error cause // set to "No User Data". (RFC4960 pg. 80) - Abort(new SctpErrorNoUserData { TSN = (chunk as SctpDataChunk).TSN }); + Abort(new SctpErrorNoUserData { TSN = dataChunk.TSN }); } else { - logger.LogTrace($"SCTP data chunk received on ID {ID} with TSN {dataChunk.TSN}, payload length {dataChunk.UserData.Length}, flags {dataChunk.ChunkFlags:X2}."); + logger.LogTrace("SCTP data chunk received on ID {ID} with TSN {TSN}, payload length {Length}, flags {Flags}.", + ID, dataChunk.TSN, dataChunk.UserData.Length, dataChunk.Flags); // A received data chunk can result in multiple data frames becoming available. // For example if a stream has out of order frames already received and the next @@ -451,23 +456,24 @@ internal void OnPacketReceived(SctpPacket packet) foreach (var frame in sortedFrames) { OnData?.Invoke(frame); + frame.Dispose(); } } break; case SctpChunkType.ERROR: - var errorChunk = chunk as SctpErrorChunk; - foreach (var err in errorChunk.ErrorCauses) + foreach (var code in chunk.GetErrorCodes()) { - logger.LogWarning($"SCTP error {err.CauseCode}."); + logger.LogWarning("SCTP error {Code}.", code); } break; case SctpChunkType.HEARTBEAT: // The HEARTBEAT ACK sends back the same chunk but with the type changed. - chunk.ChunkType = (byte)SctpChunkType.HEARTBEAT_ACK; - SendChunk(chunk); + var ack = chunk.AsChunk(); + ack.ChunkType = (byte)SctpChunkType.HEARTBEAT_ACK; + SendChunk(ack); break; case var ct when ct == SctpChunkType.INIT_ACK && State != SctpAssociationState.CookieWait: @@ -484,7 +490,7 @@ internal void OnPacketReceived(SctpPacket packet) _t1Init = null; } - var initAckChunk = chunk as SctpInitChunk; + var initAckChunk = (SctpInitChunk)chunk.AsChunk(); if (initAckChunk.InitiateTag == 0 || initAckChunk.NumberInboundStreams == 0 || @@ -530,11 +536,11 @@ internal void OnPacketReceived(SctpPacket packet) break; case var ct when ct == SctpChunkType.INIT_ACK && State != SctpAssociationState.CookieWait: - logger.LogWarning($"SCTP association received INIT_ACK chunk in wrong state of {State}, ignoring."); + logger.LogWarning("SCTP association received INIT_ACK chunk in wrong state of {State}, ignoring.", State); break; case SctpChunkType.SACK: - _dataSender.GotSack(chunk as SctpSackChunk); + _dataSender.GotSack(chunk); break; case var ct when ct == SctpChunkType.SHUTDOWN && State == SctpAssociationState.Established: @@ -560,7 +566,7 @@ internal void OnPacketReceived(SctpPacket packet) break; default: - logger.LogWarning($"SCTP association no rule for {chunkType} in state of {State}."); + logger.LogWarning("SCTP association no rule for {chunkType} in state of {State}.", chunkType, State); break; } } @@ -589,7 +595,7 @@ public void SendData(ushort streamID, uint ppid, string message) /// The stream ID to sent the data on. /// The payload protocol ID for the data. /// The byte data to send. - public void SendData(ushort streamID, uint ppid, byte[] data) + public void SendData(ushort streamID, uint ppid, ReadOnlySpan data) { if (_wasAborted) { @@ -707,8 +713,7 @@ private void SendInit() SetState(SctpAssociationState.CookieWait); - byte[] buffer = init.GetBytes(); - _sctpTransport.Send(ID, buffer, 0, buffer.Length); + SendPacket(init); _t1Init = new Timer(T1InitTimerExpired, init, T1_INIT_TIMER_MILLISECONDS, T1_INIT_TIMER_MILLISECONDS); } @@ -729,9 +734,7 @@ internal void SendChunk(SctpChunk chunk) pkt.AddChunk(chunk); - byte[] buffer = pkt.GetBytes(); - - _sctpTransport.Send(ID, buffer, 0, buffer.Length); + SendPacket(pkt); } } @@ -743,8 +746,7 @@ private void SendPacket(SctpPacket pkt) { if (!_wasAborted) { - byte[] buffer = pkt.GetBytes(); - _sctpTransport.Send(ID, buffer, 0, buffer.Length); + _sctpTransport.Send(ID, pkt); } } diff --git a/src/net/SCTP/SctpChunkView.cs b/src/net/SCTP/SctpChunkView.cs new file mode 100644 index 000000000..ba296bf3a --- /dev/null +++ b/src/net/SCTP/SctpChunkView.cs @@ -0,0 +1,214 @@ +using System; +using System.Buffers.Binary; +using System.Runtime.InteropServices; +using Microsoft.Extensions.Logging; + +using SIPSorcery.Sys; + +using static SIPSorcery.Net.SctpChunkType; + +namespace SIPSorcery.Net; + +public readonly ref struct SctpChunkView +{ + static readonly ILogger logger = LogFactory.CreateLogger(); + + readonly ReadOnlySpan buffer; + + public ReadOnlySpan Buffer => buffer; + + public SctpChunkType Type => (SctpChunkType)buffer[0]; + public SctpDataChunk.Flags Flags => (SctpDataChunk.Flags)buffer[1]; + public bool Unordered => (Flags & SctpDataChunk.Flags.Unordered) != default; + public bool Beginning => (Flags & SctpDataChunk.Flags.Beginning) != default; + public bool Ending => (Flags & SctpDataChunk.Flags.Ending) != default; + public ushort Length => BinaryPrimitives.ReadUInt16BigEndian(buffer.Slice(2, 2)); + public ReadOnlySpan Value => buffer.Slice(SctpChunk.SCTP_CHUNK_HEADER_LENGTH, Length - SctpChunk.SCTP_CHUNK_HEADER_LENGTH); + + #region Data Chunk + public uint TSN => BinaryPrimitives.ReadUInt32BigEndian(Value); + public ushort StreamID => BinaryPrimitives.ReadUInt16BigEndian(Value.Slice(4, 2)); + public ushort StreamSeqNum => BinaryPrimitives.ReadUInt16BigEndian(Value.Slice(6, 2)); + public uint PPID => BinaryPrimitives.ReadUInt32BigEndian(Value.Slice(8, 4)); + public ReadOnlySpan UserData + { + get + { + int pos = SctpChunk.SCTP_CHUNK_HEADER_LENGTH + SctpDataChunk.FIXED_PARAMETERS_LENGTH; + int len = Length - pos; + return buffer.Slice(pos, len); + } + } + #endregion Data Chunk + + #region SACK Chunk + public uint CumulativeTsnAck => BinaryPrimitives.ReadUInt32BigEndian(Value); + public uint ARwnd => BinaryPrimitives.ReadUInt32BigEndian(Value.Slice(4, 4)); + public ushort NumGapAckBlocks => BinaryPrimitives.ReadUInt16BigEndian(Value.Slice(8, 2)); + public ushort NumDuplicateTSNs => BinaryPrimitives.ReadUInt16BigEndian(Value.Slice(10, 2)); + public SctpTsnGapBlock GetTsnGapBlock(int index) + { + int posn = SctpSackChunk.FIXED_PARAMETERS_LENGTH + + (index * SctpSackChunk.GAP_REPORT_LENGTH); + return new SctpTsnGapBlock + { + Start = BinaryPrimitives.ReadUInt16BigEndian(Value.Slice(posn, 2)), + End = BinaryPrimitives.ReadUInt16BigEndian(Value.Slice(posn + 2, 2)) + }; + } + public ReadOnlySpan GapAckBlocks + => Value.Slice(SctpSackChunk.FIXED_PARAMETERS_LENGTH, NumGapAckBlocks * SctpSackChunk.GAP_REPORT_LENGTH); + public uint GetDuplicateTSN(int index) + { + int posn = SctpSackChunk.FIXED_PARAMETERS_LENGTH + + (NumGapAckBlocks * SctpSackChunk.GAP_REPORT_LENGTH) + + (index * SctpSackChunk.DUPLICATE_TSN_LENGTH); + return BinaryPrimitives.ReadUInt32BigEndian(Value.Slice(posn)); + } + #endregion SACK Chunk + + #region Init Chunk + public uint InitiateTag => BinaryPrimitives.ReadUInt32BigEndian(Value); + public uint ARwndInit => BinaryPrimitives.ReadUInt32BigEndian(Value.Slice(4, 4)); + public ushort NumberInboundStreams => BinaryPrimitives.ReadUInt16BigEndian(Value.Slice(8, 2)); + public ushort NumberOutboundStreams => BinaryPrimitives.ReadUInt16BigEndian(Value.Slice(10, 2)); + public uint InitialTSN => BinaryPrimitives.ReadUInt32BigEndian(Value.Slice(12, 4)); + #endregion Init Chunk + + public ErrorCodeEnumerable GetErrorCodes() + { + int paramsBufferLength = Length - SctpChunk.SCTP_CHUNK_HEADER_LENGTH; + int paramPosn = 0; + var paramsBuffer = Value.Slice(paramPosn, paramsBufferLength); + return new ErrorCodeEnumerable(paramsBuffer); + } + + public readonly ref struct ErrorCodeEnumerable(ReadOnlySpan paramsBuffer) + { + readonly ReadOnlySpan originalParamsBuffer = paramsBuffer; + public ErrorCodeEnumerator GetEnumerator() => new (originalParamsBuffer); + } + + public ref struct ErrorCodeEnumerator(ReadOnlySpan paramsBuffer) + { + readonly ReadOnlySpan originalParamsBuffer = paramsBuffer; + ReadOnlySpan paramsBuffer; + bool started; + + public readonly SctpErrorCauseCode Current + { + get + { + SctpTlvChunkParameter.ParseFirstWord(paramsBuffer, out var type); + return (SctpErrorCauseCode)type; + } + } + + public bool MoveNext() + { + if (started) + { + int length = SctpTlvChunkParameter.ParseFirstWord(paramsBuffer, out var type); + paramsBuffer = paramsBuffer.Slice(length); + } + else + { + started = true; + paramsBuffer = originalParamsBuffer; + } + return !paramsBuffer.IsEmpty; + } + public void Dispose() { } + + public void Reset() + { + throw new NotSupportedException(); + } + } + + public SctpChunkView(ReadOnlySpan buffer) + { + if (buffer.Length < 4) + { + throw new ArgumentException("Buffer too short to be a valid SCTP chunk."); + } + this.buffer = buffer; + if (!Type.IsDefined()) + { + throw new ArgumentException($"Unknown chunk type {Type}"); + } + + _ = Type switch + { + ABORT => ValidateError(isAbort: true), + ERROR => ValidateError(isAbort: false), + DATA => ValidateData(), + SACK => ValidateSack(), + COOKIE_ACK or COOKIE_ECHO + or HEARTBEAT or HEARTBEAT_ACK + or SHUTDOWN_ACK or SHUTDOWN_COMPLETE + => ValidateBase(), + INIT or INIT_ACK => ValidateInit(), + SHUTDOWN => ValidateShutdown(), + _ => ValidateUnknownBase(), + }; + } + + public SctpChunk AsChunk() => SctpChunk.Parse(buffer, 0); + + bool ValidateShutdown() + { + ValidateBase(); + return true; + } + + bool ValidateInit() + { + ValidateBase(); + return true; + } + + bool ValidateSack() + { + int gapAckSize = NumGapAckBlocks * SctpSackChunk.GAP_REPORT_LENGTH; + int dupTsnSize = NumDuplicateTSNs * SctpSackChunk.DUPLICATE_TSN_LENGTH; + int expectedLength = SctpSackChunk.FIXED_PARAMETERS_LENGTH + gapAckSize + dupTsnSize; + if (Length < expectedLength) + { + throw new ApplicationException($"SCTP SACK chunk length {Length} does not match expected length {expectedLength}."); + } + return true; + } + + bool ValidateBase() + { + if (Length > buffer.Length) + { + throw new ArgumentException("Buffer too short to be a valid SCTP chunk."); + } + return true; + } + + bool ValidateData() + { + ValidateBase(); + if (Length < SctpDataChunk.FIXED_PARAMETERS_LENGTH) + { + throw new ApplicationException($"SCTP data chunk cannot be parsed as buffer too short for fixed parameter fields."); + } + return true; + } + + bool ValidateError(bool isAbort) + { + ValidateBase(); + SctpErrorChunk.ParseChunk(buffer, 0, isAbort: isAbort); + return true; + } + + bool ValidateUnknownBase() + { + logger.LogDebug("TODO: Implement parsing logic for well known chunk type {Type}.", Type); + return ValidateBase(); + } +} diff --git a/src/net/SCTP/SctpDataReceiver.cs b/src/net/SCTP/SctpDataReceiver.cs index 8106e42b0..4aa93e889 100644 --- a/src/net/SCTP/SctpDataReceiver.cs +++ b/src/net/SCTP/SctpDataReceiver.cs @@ -15,41 +15,47 @@ //----------------------------------------------------------------------------- using System; -using System.Collections.Concurrent; +using System.Buffers; +using System.Buffers.Binary; using System.Collections.Generic; -using System.Linq; -using System.Text; using Microsoft.Extensions.Logging; +using SIPSorcery.Sys; +using Small.Collections; +using TypeNum; + namespace SIPSorcery.Net { - public struct SctpDataFrame + public struct SctpDataFrame : IDisposable { - public static SctpDataFrame Empty = new SctpDataFrame(); + public static SctpDataFrame Empty => default; public bool Unordered; public ushort StreamID; public ushort StreamSeqNum; public uint PPID; - public byte[] UserData; + BorrowedArray userData; + public readonly ReadOnlySpan UserData => userData.Data; /// The stream ID of the chunk. /// The stream sequence number of the chunk. Will be 0 for unordered streams. /// The payload protocol ID for the chunk. - /// The chunk data. - public SctpDataFrame(bool unordered, ushort streamID, ushort streamSeqNum, uint ppid, byte[] userData) + public SctpDataFrame(bool unordered, ushort streamID, ushort streamSeqNum, uint ppid) { Unordered = unordered; StreamID = streamID; StreamSeqNum = streamSeqNum; PPID = ppid; - UserData = userData; } - public bool IsEmpty() + public void SetUserData(ReadOnlySpan userData) { - return UserData == null; + this.userData.Set(userData); } + + public void Dispose() => userData.Dispose(); + + public readonly bool IsEmpty() => userData.IsNull(); } public struct SctpTsnGapBlock @@ -69,6 +75,12 @@ public struct SctpTsnGapBlock /// DATA chunk received in this Gap Ack Block. /// public ushort End; + + public static SctpTsnGapBlock Read(ReadOnlySpan bytes) => new() + { + Start = BinaryPrimitives.ReadUInt16BigEndian(bytes), + End = BinaryPrimitives.ReadUInt16BigEndian(bytes.Slice(2)), + }; } /// @@ -203,35 +215,44 @@ public void SetInitialTSN(uint tsn) /// or more new frames will be returned otherwise an empty frame is returned. Multiple /// frames may be returned if this chunk is part of a stream and was received out /// or order. For unordered chunks the list will always have a single entry. - public List OnDataChunk(SctpDataChunk dataChunk) + public List OnDataChunk(SctpChunkView dataChunk) { + if (dataChunk.Type != SctpChunkType.DATA) + { + throw new ArgumentException($"An attempt was made to process a {dataChunk.Type} chunk as a DATA chunk."); + } + var sortedFrames = new List(); var frame = SctpDataFrame.Empty; if (_inOrderReceiveCount == 0 && GetDistance(_initialTSN, dataChunk.TSN) > _windowSize) { - logger.LogWarning($"SCTP data receiver received a data chunk with a {dataChunk.TSN} " + - $"TSN when the initial TSN was {_initialTSN} and a " + - $"window size of {_windowSize}, ignoring."); + logger.LogWarning("SCTP data receiver received a data chunk with a {TSN} " + + "TSN when the initial TSN was {InitialTSN} and a " + + "window size of {Size}, ignoring.", + dataChunk.TSN, _initialTSN, _windowSize); } else if (_inOrderReceiveCount > 0 && GetDistance(_lastInOrderTSN, dataChunk.TSN) > _windowSize) { - logger.LogWarning($"SCTP data receiver received a data chunk with a {dataChunk.TSN} " + - $"TSN when the expected TSN was {_lastInOrderTSN + 1} and a " + - $"window size of {_windowSize}, ignoring."); + logger.LogWarning("SCTP data receiver received a data chunk with a {TSN} " + + "TSN when the expected TSN was {LastInOrderTSN} and a " + + "window size of {Size}, ignoring.", + dataChunk.TSN, _lastInOrderTSN + 1, _windowSize); } else if (_inOrderReceiveCount > 0 && !IsNewer(_lastInOrderTSN, dataChunk.TSN)) { - logger.LogWarning($"SCTP data receiver received an old data chunk with {dataChunk.TSN} " + - $"TSN when the expected TSN was {_lastInOrderTSN + 1}, ignoring."); + logger.LogDebug("SCTP received an old data chunk with {TSN} " + + "TSN when the expected TSN was {LastInOrderTSN}, ignoring.", + dataChunk.TSN, _lastInOrderTSN + 1); } else if (!_forwardTSN.ContainsKey(dataChunk.TSN)) { - logger.LogTrace($"SCTP receiver got data chunk with TSN {dataChunk.TSN}, " + - $"last in order TSN {_lastInOrderTSN}, in order receive count {_inOrderReceiveCount}."); + logger.LogTrace("SCTP receiver got data chunk with TSN {TSN}, " + + "last in order TSN {LastInOrderTSN}, in order receive count {InOrderReceiveCount}.", + dataChunk.TSN, _lastInOrderTSN, _inOrderReceiveCount); bool processFrame = true; @@ -263,7 +284,8 @@ public List OnDataChunk(SctpDataChunk dataChunk) outOfOrder.Count >= MAXIMUM_OUTOFORDER_FRAMES) { // Stream is nearing capacity, only chunks that advance _lastInOrderTSN can be accepted. - logger.LogWarning($"Stream {dataChunk.StreamID} is at buffer capacity. Rejected out-of-order data chunk TSN {dataChunk.TSN}."); + logger.LogWarning("Stream {StreamID} is at buffer capacity. Rejected out-of-order data chunk TSN {TSN}.", + dataChunk.StreamID, dataChunk.TSN); processFrame = false; } else @@ -276,32 +298,34 @@ public List OnDataChunk(SctpDataChunk dataChunk) if (processFrame) { // Now go about processing the data chunk. - if (dataChunk.Begining && dataChunk.Ending) + if (dataChunk.Beginning && dataChunk.Ending) { // Single packet chunk. frame = new SctpDataFrame( dataChunk.Unordered, dataChunk.StreamID, dataChunk.StreamSeqNum, - dataChunk.PPID, - dataChunk.UserData); + dataChunk.PPID); + + frame.SetUserData(dataChunk.UserData); } else { + var tmp = SctpDataChunk.ParseChunk(dataChunk.Buffer, 0, ArrayPool.Shared); // This is a data chunk fragment. - _fragmentedChunks.Add(dataChunk.TSN, dataChunk); + _fragmentedChunks.Add(dataChunk.TSN, tmp); (var begin, var end) = GetChunkBeginAndEnd(_fragmentedChunks, dataChunk.TSN); if (begin != null && end != null) { - frame = GetFragmentedChunk(_fragmentedChunks, begin.Value, end.Value); + frame = ExtractFragmentedChunk(_fragmentedChunks, begin.Value, end.Value); } } } } else { - logger.LogTrace($"SCTP duplicate TSN received for {dataChunk.TSN}."); + logger.LogTrace("SCTP duplicate TSN received for {TSN}.", dataChunk.TSN); if (!_duplicateTSN.ContainsKey(dataChunk.TSN)) { _duplicateTSN.Add(dataChunk.TSN, 1); @@ -339,7 +363,7 @@ public SctpSackChunk GetSackChunk() { SctpSackChunk sack = new SctpSackChunk(_lastInOrderTSN, _receiveWindow); sack.GapAckBlocks = GetForwardTSNGaps(); - sack.DuplicateTSN = _duplicateTSN.Keys.ToList(); + sack.DuplicateTSN.AddRange(_duplicateTSN.Keys.GetEnumerator()); return sack; } else @@ -354,9 +378,9 @@ public SctpSackChunk GetSackChunk() /// TSNs have not yet been received. /// /// A list of TSN gap blocks. An empty list means there are no gaps. - internal List GetForwardTSNGaps() + internal SmallList, SctpTsnGapBlock> GetForwardTSNGaps() { - List gaps = new List(); + var gaps = new SmallList, SctpTsnGapBlock>(); // Can't create gap reports until the initial DATA chunk has been received. if (_inOrderReceiveCount > 0) @@ -508,28 +532,30 @@ private List ProcessStreamFrame(SctpDataFrame frame) /// The dictionary containing the chunk fragments. /// The beginning TSN for the fragment. /// The end TSN for the fragment. - private SctpDataFrame GetFragmentedChunk(Dictionary fragments, uint beginTSN, uint endTSN) + private SctpDataFrame ExtractFragmentedChunk(Dictionary fragments, uint beginTSN, uint endTSN) { unchecked { - byte[] full = new byte[MAX_FRAME_SIZE]; + Span full = stackalloc byte[MAX_FRAME_SIZE]; int posn = 0; var beginChunk = fragments[beginTSN]; - var frame = new SctpDataFrame(beginChunk.Unordered, beginChunk.StreamID, beginChunk.StreamSeqNum, beginChunk.PPID, full); uint afterEndTSN = endTSN + 1; uint tsn = beginTSN; while (tsn != afterEndTSN) { - var fragment = fragments[tsn].UserData; - Buffer.BlockCopy(fragment, 0, full, posn, fragment.Length); - posn += fragment.Length; + var fragment = fragments[tsn]; + var fragmentData = fragment.UserData; + fragmentData.CopyTo(full.Slice(posn)); + posn += fragmentData.Length; fragments.Remove(tsn); + fragment.Dispose(); tsn++; } - frame.UserData = frame.UserData.Take(posn).ToArray(); + var frame = new SctpDataFrame(beginChunk.Unordered, beginChunk.StreamID, beginChunk.StreamSeqNum, beginChunk.PPID); + frame.SetUserData(full.Slice(0, posn)); return frame; } diff --git a/src/net/SCTP/SctpDataSender.cs b/src/net/SCTP/SctpDataSender.cs index dd814bf84..e71a36098 100644 --- a/src/net/SCTP/SctpDataSender.cs +++ b/src/net/SCTP/SctpDataSender.cs @@ -74,12 +74,14 @@ public class SctpDataSender private uint _initialTSN; private bool _gotFirstSACK; private bool _isStarted; - private bool _isClosed; + private Once _closed; private int _lastAckedDataChunkSize; - private bool _inRetransmitMode; - private bool _inFastRecoveryMode; + private OnOff _inRetransmitMode; + private OnOff _inFastRecoveryMode; + /// Only ever accessed inside private uint _fastRecoveryExitPoint; private ManualResetEventSlim _senderMre = new ManualResetEventSlim(); + private readonly ManualResetEventSlim _queueSpaceAvailable = new ManualResetEventSlim(initialState: true); /// /// Congestion control window (cwnd, in bytes), which is adjusted by @@ -127,11 +129,10 @@ public class SctpDataSender /// /// A count of the bytes currently in-flight to the remote peer. /// - internal uint _outstandingBytes => - (uint)(_unconfirmedChunks.Sum(x => x.Value.UserData.Length)); + internal int _outstandingBytes => _unconfirmedChunks.Sum(x => x.Value.UserDataLength); /// - /// The TSN that the remote peer has acknowledged. + /// The TSN that the remote peer has acknowledged. Only ever accessed inside /// private uint _cumulativeAckTSN; @@ -141,6 +142,8 @@ public class SctpDataSender /// private Dictionary _streamSeqnums = new Dictionary(); + public int MaxSendQueueCount => 128; +#warning this must be rewritten to use BlockingQueue /// /// Queue to hold SCTP frames that are waiting to be sent to the remote peer. /// @@ -160,12 +163,13 @@ public class SctpDataSender /// /// The total size (in bytes) of queued user data that will be sent to the peer. /// - public ulong BufferedAmount => (ulong)_sendQueue.Sum(x => x.UserData?.Length ?? 0); + public ulong BufferedAmount => (ulong)_sendQueue.Sum(x => x.UserDataLength); + int tsn; /// /// The Transaction Sequence Number (TSN) that will be used in the next DATA chunk sent. /// - public uint TSN { get; internal set; } + public uint TSN => unchecked((uint)Interlocked.CompareExchange(ref tsn, 0, 0)); public SctpDataSender( string associationID, @@ -178,7 +182,7 @@ public SctpDataSender( _sendDataChunk = sendDataChunk; _defaultMTU = defaultMTU > 0 ? defaultMTU : DEFAULT_SCTP_MTU; _initialTSN = initialTSN; - TSN = initialTSN; + tsn = unchecked((int)initialTSN); _initialRemoteARwnd = remoteARwnd; _receiverWindow = remoteARwnd; @@ -198,14 +202,12 @@ public void SetReceiverWindow(uint remoteARwnd) /// Handler for SACK chunks received from the remote peer. /// /// The SACK chunk from the remote peer. - public void GotSack(SctpSackChunk sack) + public void GotSack(SctpChunkView sack) { - if (sack != null) { - if (_inRetransmitMode) + if (_inRetransmitMode.TryTurnOff()) { - logger.LogTrace($"SCTP sender exiting retransmit mode."); - _inRetransmitMode = false; + logger.LogDebug("SCTP sender exiting retransmit mode."); } unchecked @@ -222,7 +224,7 @@ public void GotSack(SctpSackChunk sack) UpdateRoundTripTime(result); } - _lastAckedDataChunkSize = result.UserData.Length; + Interlocked.Exchange(ref _lastAckedDataChunkSize, result.UserDataLength); } if (!_gotFirstSACK) @@ -230,7 +232,8 @@ public void GotSack(SctpSackChunk sack) if (SctpDataReceiver.GetDistance(_initialTSN, sack.CumulativeTsnAck) < maxTSNDistance && SctpDataReceiver.IsNewerOrEqual(_initialTSN, sack.CumulativeTsnAck)) { - logger.LogTrace($"SCTP first SACK remote peer TSN ACK {sack.CumulativeTsnAck} next sender TSN {TSN}, arwnd {sack.ARwnd} (gap reports {sack.GapAckBlocks.Count})."); + logger.LogTrace("SCTP first SACK remote peer TSN ACK {CumulativeTsnAck} next sender TSN {TSN}, arwnd {ARwnd} (gap reports {NumGapAckBlocks}).", + sack.CumulativeTsnAck, TSN, sack.ARwnd, sack.NumGapAckBlocks); _gotFirstSACK = true; _cumulativeAckTSN = _initialTSN; RemoveAckedUnconfirmedChunks(sack.CumulativeTsnAck); @@ -242,42 +245,47 @@ public void GotSack(SctpSackChunk sack) { if (SctpDataReceiver.GetDistance(_cumulativeAckTSN, sack.CumulativeTsnAck) > maxTSNDistance) { - logger.LogWarning($"SCTP SACK TSN from remote peer of {sack.CumulativeTsnAck} was too distant from the expected {_cumulativeAckTSN}, ignoring."); + logger.LogWarning("SCTP SACK TSN from remote peer of {PeerCumulativeTsnAck} was too distant from the expected {ExpectedCumulativeAckTSN}, ignoring.", + sack.CumulativeTsnAck, _cumulativeAckTSN); processGapReports = false; } else if (!SctpDataReceiver.IsNewer(_cumulativeAckTSN, sack.CumulativeTsnAck)) { - logger.LogWarning($"SCTP SACK TSN from remote peer of {sack.CumulativeTsnAck} was behind expected {_cumulativeAckTSN}, ignoring."); + logger.LogWarning("SCTP SACK TSN from remote peer of {PeerCumulativeTsnAck} was behind expected {ExpectedCumulativeAckTSN}, ignoring.", + sack.CumulativeTsnAck, _cumulativeAckTSN); processGapReports = false; } else { - logger.LogTrace($"SCTP SACK remote peer TSN ACK {sack.CumulativeTsnAck}, next sender TSN {TSN}, arwnd {sack.ARwnd} (gap reports {sack.GapAckBlocks.Count})."); + logger.LogTrace("SCTP SACK remote peer TSN ACK {CumulativeTsnAck}, next sender TSN {TSN}, arwnd {ARwnd} (gap reports {NumGapAckBlocks}).", + sack.CumulativeTsnAck, TSN, sack.ARwnd, sack.NumGapAckBlocks); RemoveAckedUnconfirmedChunks(sack.CumulativeTsnAck); } } else { - logger.LogTrace($"SCTP SACK remote peer TSN ACK no change {_cumulativeAckTSN}, next sender TSN {TSN}, arwnd {sack.ARwnd} (gap reports {sack.GapAckBlocks.Count})."); + logger.LogTrace("SCTP SACK remote peer TSN ACK no change {CumulativeAckTSN}, next sender TSN {TSN}, arwnd {ARwnd} (gap reports {NumGapAckBlocks}).", + _cumulativeAckTSN, TSN, sack.ARwnd, sack.NumGapAckBlocks); RemoveAckedUnconfirmedChunks(sack.CumulativeTsnAck); } } - if (sack.DuplicateTSN.Count > 0) + if (sack.NumDuplicateTSNs > 0) { // The remote is reporting that we have sent a duplicate TSN. // This is probably because a SACK chunk was dropped. // Ensure that we stop sending the duplicate. - foreach (uint duplicateTSN in sack.DuplicateTSN) + for (int tsnIndex = 0; tsnIndex < sack.NumDuplicateTSNs; tsnIndex++) { - _unconfirmedChunks.TryRemove(duplicateTSN, out _); + uint duplicateTSN = sack.GetDuplicateTSN(tsnIndex); + RemoveUnconfirmedChunk(duplicateTSN); _missingChunks.TryRemove(duplicateTSN, out _); } } // Check gap reports. Only process them if the cumulative ACK TSN was acceptable. - if (processGapReports && sack.GapAckBlocks.Count > 0) + if (processGapReports && sack.NumGapAckBlocks > 0) { bool didIncrementCumAckTSN = SctpDataReceiver.IsNewer(cumAckTSNBeforeSackProcessing, _cumulativeAckTSN); ProcessGapReports(sack.GapAckBlocks, maxTSNDistance, didIncrementCumAckTSN); @@ -285,15 +293,15 @@ public void GotSack(SctpSackChunk sack) // rfc4960 6.2.1 D iv // If the Cumulative TSN Ack matches or exceeds the Fast Recovery exitpoint(Section 7.2.4), Fast Recovery is exited. - if (_inFastRecoveryMode && SctpDataReceiver.IsNewerOrEqual(_fastRecoveryExitPoint, _cumulativeAckTSN)) + if (SctpDataReceiver.IsNewerOrEqual(_fastRecoveryExitPoint, _cumulativeAckTSN) && _inFastRecoveryMode.TryTurnOff()) { - logger.LogTrace($"SCTP sender exiting fast recovery at TSN {_fastRecoveryExitPoint}"); - _inFastRecoveryMode = false; + logger.LogTrace("SCTP sender exiting fast recovery at TSN {TSN}", _fastRecoveryExitPoint); } } - _receiverWindow = CalculateReceiverWindow(sack.ARwnd); - _congestionWindow = CalculateCongestionWindow(_lastAckedDataChunkSize); + var outstandingBytes = _outstandingBytes; + _receiverWindow = CalculateReceiverWindow(sack.ARwnd, outstandingBytes: (uint)outstandingBytes); + _congestionWindow = CalculateCongestionWindow(InterlockedEx.Read(ref _lastAckedDataChunkSize), outstandingBytes: (uint)outstandingBytes); // SACK's will normally allow more data to be sent. _senderMre.Set(); @@ -306,10 +314,20 @@ public void GotSack(SctpSackChunk sack) /// The stream ID to sent the data on. /// The payload protocol ID for the data. /// The byte data to send. - public void SendData(ushort streamID, uint ppid, byte[] data) + public void SendData(ushort streamID, uint ppid, ReadOnlySpan data) { + // combined spin/lock wait + while (!_queueSpaceAvailable.Wait(TimeSpan.FromMilliseconds(10)) && _sendQueue.Count > MaxSendQueueCount) + { + + } + lock (_sendQueue) { + if (_closed.HasOccurred) + { + return; + } ushort seqnum = 0; if (_streamSeqnums.ContainsKey(streamID)) @@ -330,10 +348,6 @@ public void SendData(ushort streamID, uint ppid, byte[] data) int offset = (index == 0) ? 0 : (index * _defaultMTU); int payloadLength = (offset + _defaultMTU < data.Length) ? _defaultMTU : data.Length - offset; - // Future TODO: Replace with slice when System.Memory is introduced as a dependency. - byte[] payload = new byte[payloadLength]; - Buffer.BlockCopy(data, offset, payload, 0, payloadLength); - bool isBegining = index == 0; bool isEnd = ((offset + payloadLength) >= data.Length) ? true : false; @@ -345,11 +359,16 @@ public void SendData(ushort streamID, uint ppid, byte[] data) streamID, seqnum, ppid, - payload); + data.Slice(offset, payloadLength)); _sendQueue.Enqueue(dataChunk); - TSN = (TSN == UInt32.MaxValue) ? 0 : TSN + 1; + Interlocked.Increment(ref tsn); + } + + if (_sendQueue.Count > MaxSendQueueCount) + { + _queueSpaceAvailable.Reset(); } _senderMre.Set(); @@ -365,7 +384,11 @@ public void StartSending() if (!_isStarted) { _isStarted = true; - var sendThread = new Thread(DoSend); + var sendThread = new Thread(DoSend) + { + IsBackground = true, + Name = $"{nameof(SctpDataSender)}-{_associationID}", + }; sendThread.IsBackground = true; sendThread.Start(); } @@ -376,7 +399,18 @@ public void StartSending() /// public void Close() { - _isClosed = true; + _closed.TryMarkOccurred(); + foreach (var chunk in _unconfirmedChunks) + { + chunk.Value.Dispose(); + } + lock (_sendQueue) + { + foreach (var chunk in _sendQueue) + { + chunk.Dispose(); + } + } } /// @@ -388,7 +422,7 @@ public void Close() /// ACK'ed TSN. If this distance gets exceeded by a gap report then it's likely something has been /// miscalculated. /// If true, processing of the SACK incremented the - private void ProcessGapReports(List sackGapBlocks, uint maxTSNDistance, bool didSackIncrementTSN) + private void ProcessGapReports(ReadOnlySpan sackGapBlocks, uint maxTSNDistance, bool didSackIncrementTSN) { uint lastAckTSN = _cumulativeAckTSN; @@ -399,23 +433,26 @@ private void ProcessGapReports(List sackGapBlocks, uint maxTSND unchecked { // Parse the gap report to identify missing chunks that have now been acknowledged in the gap report - foreach (var block in sackGapBlocks) + for(int index = 0; index < sackGapBlocks.Length; index += SctpSackChunk.GAP_REPORT_LENGTH) { + var block = SctpTsnGapBlock.Read(sackGapBlocks.Slice(index)); for (ushort offset = block.Start; offset <= block.End; offset++) { uint goodTSN = _cumulativeAckTSN + offset; _missingChunks.TryRemove(goodTSN, out _); - if (_unconfirmedChunks.TryRemove(goodTSN, out _)) + if (_unconfirmedChunks.TryRemove(goodTSN, out var chunk)) { - logger.LogTrace($"SCTP acknowledged data chunk receipt in gap report for TSN {goodTSN}"); + logger.LogTrace("SCTP acknowledged data chunk receipt in gap report for TSN {TSN}", goodTSN); highestTsnNewlyAcknowledged = goodTSN; + chunk.Dispose(); } } } - foreach (var gapBlock in sackGapBlocks) + for (int index = 0; index < sackGapBlocks.Length; index += SctpSackChunk.GAP_REPORT_LENGTH) { + var gapBlock = SctpTsnGapBlock.Read(sackGapBlocks.Slice(index)); uint goodTSNStart = _cumulativeAckTSN + gapBlock.Start; if (SctpDataReceiver.GetDistance(lastAckTSN, goodTSNStart) > maxTSNDistance) @@ -457,7 +494,7 @@ private void ProcessGapReports(List sackGapBlocks, uint maxTSND else if ( // If an endpoint is in Fast Recovery and a SACK arrives that advances the Cumulative TSN Ack // Point, the miss indications are incremented for all TSNs reported missing in the SACK. - (_inFastRecoveryMode && didSackIncrementTSN) || + (_inFastRecoveryMode.IsOn() && didSackIncrementTSN) || // For each incoming SACK, miss indications are incremented only // for missing TSNs prior to the highest TSN newly acknowledged in the SACK. SctpDataReceiver.IsNewer(missingTSN, highestTsnNewlyAcknowledged)) @@ -467,16 +504,16 @@ private void ProcessGapReports(List sackGapBlocks, uint maxTSND // rfc 7.2.4: When the third consecutive miss indication is received for a TSN(s), the data sender shall do the following... if (missCount + 1 == 3) { - if (!_inFastRecoveryMode) // RFC4960 7.2.4 (2) + if (_inFastRecoveryMode.TryTurnOn()) // RFC4960 7.2.4 (2) { - _inFastRecoveryMode = true; // mark the highest outstanding TSN as the Fast Recovery exit point - _fastRecoveryExitPoint = _cumulativeAckTSN + sackGapBlocks.Last().End; + var last = SctpTsnGapBlock.Read(sackGapBlocks.Slice(sackGapBlocks.Length - SctpSackChunk.GAP_REPORT_LENGTH)); + _fastRecoveryExitPoint = _cumulativeAckTSN + last.End; - logger.LogTrace($"SCTP sender entering fast recovery mode due to missing TSN {missingTSN}. Fast recovery exit point {_fastRecoveryExitPoint}."); + logger.LogDebug($"SCTP sender entering fast recovery mode due to missing TSN {missingTSN}. Fast recovery exit point {_fastRecoveryExitPoint}."); // RFC4960 7.2.3 _slowStartThreshold = (uint)Math.Max(_congestionWindow / 2, 4 * _defaultMTU); - _congestionWindow = _defaultMTU; + _congestionWindow = _slowStartThreshold; } } } @@ -497,12 +534,13 @@ private void ProcessGapReports(List sackGapBlocks, uint maxTSND /// The acknowledged TSN received from in a SACK from the remote peer. private void RemoveAckedUnconfirmedChunks(uint sackTSN) { - logger.LogTrace($"SCTP data sender removing unconfirmed chunks cumulative ACK TSN {_cumulativeAckTSN}, SACK TSN {sackTSN}."); + logger.LogTrace("SCTP data sender removing unconfirmed chunks cumulative ACK TSN {CumulativeAckTSN}, SACK TSN {SackTSN}.", + _cumulativeAckTSN, sackTSN); if (_cumulativeAckTSN == sackTSN) { // This is normal for the first SACK received. - _unconfirmedChunks.TryRemove(_cumulativeAckTSN, out _); + RemoveUnconfirmedChunk(_cumulativeAckTSN); _missingChunks.TryRemove(_cumulativeAckTSN, out _); } else @@ -512,7 +550,7 @@ private void RemoveAckedUnconfirmedChunks(uint sackTSN) for (uint offset = 0; offset <= SctpDataReceiver.GetDistance(_cumulativeAckTSN, sackTSN); offset++) { uint ackd = _cumulativeAckTSN + offset; - _unconfirmedChunks.TryRemove(ackd, out _); + RemoveUnconfirmedChunk(ackd); _missingChunks.TryRemove(ackd, out _); } _cumulativeAckTSN = sackTSN; @@ -520,27 +558,35 @@ private void RemoveAckedUnconfirmedChunks(uint sackTSN) } } + private void RemoveUnconfirmedChunk(uint tsn) + { + if (_unconfirmedChunks.TryRemove(tsn, out var chunk)) + { + chunk.Dispose(); + } + } + /// /// Worker thread to process the send and retransmit queues. /// private void DoSend(object state) { - logger.LogDebug($"SCTP association data send thread started for association {_associationID}."); + logger.LogDebug("SCTP association data send thread started for association {ID}.", _associationID); - while (!_isClosed) + while (!_closed.HasOccurred) { - var outstandingBytes = _outstandingBytes; + var outstandingBytes = (uint)_outstandingBytes; // DateTime.Now calls have been a tiny bit expensive in the past so get a small saving by only // calling once per loop. - DateTime now = DateTime.Now; + var now = SctpDataChunk.Timestamp.Now; - int burstSize = (_inRetransmitMode || _inFastRecoveryMode || _congestionWindow < outstandingBytes || _receiverWindow == 0) ? 1 : MAX_BURST; + int burstSize = (_inRetransmitMode.IsOn() || _inFastRecoveryMode.IsOn() || _congestionWindow < outstandingBytes || _receiverWindow == 0) ? 1 : MAX_BURST; int chunksSent = 0; //logger.LogTrace($"SCTP sender burst size {burstSize}, in retransmit mode {_inRetransmitMode}, cwnd {_congestionWindow}, arwnd {_receiverWindow}."); // Missing chunks from a SACK gap report take priority. - if (_missingChunks.Count > 0) + if (!_missingChunks.IsEmpty) { foreach (var missing in _missingChunks) { @@ -551,8 +597,9 @@ private void DoSend(object state) missingChunk.LastSentAt = now; missingChunk.SendCount += 1; - logger.LogTrace($"SCTP resending missing data chunk for TSN {missingChunk.TSN}, data length {missingChunk.UserData.Length}, " + - $"flags {missingChunk.ChunkFlags:X2}, send count {missingChunk.SendCount}."); + logger.LogTrace("SCTP resending missing data chunk for TSN {TSN}, data length {Length}, " + + "flags {Flags:X2}, send count {Count}.", + missingChunk.TSN, missingChunk.UserDataLength, missingChunk.ChunkFlags, missingChunk.SendCount); _sendDataChunk(missingChunk); chunksSent++; @@ -567,30 +614,42 @@ private void DoSend(object state) } // Check if there are any unconfirmed transactions that are due for a retransmit. - if (chunksSent < burstSize && _unconfirmedChunks.Count > 0) + if (chunksSent < burstSize && !_unconfirmedChunks.IsEmpty) { - foreach (var chunk in _unconfirmedChunks.Values - .Where(x => now.Subtract(x.LastSentAt).TotalMilliseconds > (_hasRoundTripTime ? _rto : _rtoInitialMilliseconds)) - .Take(burstSize - chunksSent)) + int taken = 0, send = burstSize - chunksSent; + foreach (var entry in _unconfirmedChunks) { - chunk.LastSentAt = DateTime.Now; + var chunk = entry.Value; + if (now.Milliseconds - chunk.LastSentAt.Milliseconds <= (_hasRoundTripTime ? _rto : _rtoInitialMilliseconds)) + { + continue; + } + if (taken >= send) + { + break; + } + taken++; + + chunk.LastSentAt = SctpDataChunk.Timestamp.Now; chunk.SendCount += 1; - logger.LogTrace($"SCTP retransmitting data chunk for TSN {chunk.TSN}, data length {chunk.UserData.Length}, " + - $"flags {chunk.ChunkFlags:X2}, send count {chunk.SendCount}."); + logger.LogTrace("SCTP retransmitting data chunk for TSN {TSN}, data length {Length}, " + + "flags {Flags:X2}, send count {Count}.", + chunk.TSN, chunk.UserDataLength, chunk.ChunkFlags, chunk.SendCount); _sendDataChunk(chunk); chunksSent++; - - if (!_inRetransmitMode) + + if (_inRetransmitMode.TryTurnOn()) { - logger.LogTrace($"SCTP sender entering retransmit mode."); - _inRetransmitMode = true; + logger.LogDebug("SCTP sender entering retransmit mode."); // When the T3-rtx timer expires on an address, SCTP should perform slow start. // RFC4960 7.2.3 _slowStartThreshold = (uint)Math.Max(_congestionWindow / 2, 4 * _defaultMTU); - _congestionWindow = _defaultMTU; + // did not clarify, but I believe entering retransmit mode is NOT the same + // as T3-rtx timer expiring. Will just use regular halving formula here. + _congestionWindow = _slowStartThreshold; // For the destination address for which the timer expires, set RTO <- RTO * 2("back off the timer") // RFC4960 6.3.3 E2 @@ -605,18 +664,29 @@ private void DoSend(object state) // if it has cwnd or more bytes of data outstanding to that transport address. // Send any new data chunks that have not yet been sent. - if (chunksSent < burstSize && _sendQueue.Count > 0 && _congestionWindow > outstandingBytes) + if (chunksSent < burstSize && !_sendQueue.IsEmpty && _congestionWindow > outstandingBytes) { while (chunksSent < burstSize && _sendQueue.TryDequeue(out var dataChunk)) { - dataChunk.LastSentAt = DateTime.Now; + dataChunk.LastSentAt = SctpDataChunk.Timestamp.Now; dataChunk.SendCount = 1; - logger.LogTrace($"SCTP sending data chunk for TSN {dataChunk.TSN}, data length {dataChunk.UserData.Length}, " + - $"flags {dataChunk.ChunkFlags:X2}, send count {dataChunk.SendCount}."); + logger.LogTrace("SCTP sending data chunk for TSN {TSN}, data length {Length}, " + + "flags {Flags:X2}, send count {Count}.", + dataChunk.TSN, dataChunk.UserDataLength, dataChunk.ChunkFlags, dataChunk.SendCount); - _unconfirmedChunks.TryAdd(dataChunk.TSN, dataChunk); - _sendDataChunk(dataChunk); + if (_unconfirmedChunks.TryAdd(dataChunk.TSN, dataChunk)) + { + _sendDataChunk(dataChunk); + } + else + { + logger.LogDebug("SCTP duplicate TSN {TSN} detected in send queue.", dataChunk.TSN); + } + if (_sendQueue.Count < MaxSendQueueCount) + { + _queueSpaceAvailable.Set(); + } chunksSent++; } } @@ -631,7 +701,7 @@ private void DoSend(object state) _senderMre.Wait(wait); } - logger.LogDebug($"SCTP association data send thread stopped for association {_associationID}."); + logger.LogDebug("SCTP association data send thread stopped for association {ID}.", _associationID); } /// @@ -639,9 +709,9 @@ private void DoSend(object state) /// private int GetSendWaitMilliseconds() { - if (_sendQueue.Count > 0 || _missingChunks.Count > 0) + if (!_sendQueue.IsEmpty || !_missingChunks.IsEmpty) { - if (_receiverWindow > 0 && _congestionWindow > _outstandingBytes) + if (_receiverWindow > 0 && _congestionWindow > (uint)_outstandingBytes) { return _burstPeriodMilliseconds; } @@ -650,7 +720,7 @@ private int GetSendWaitMilliseconds() return _rtoMinimumMilliseconds; } } - else if (_unconfirmedChunks.Count > 0) + else if (!_unconfirmedChunks.IsEmpty) { return (int)(_hasRoundTripTime ? _rto : _rtoInitialMilliseconds); } @@ -674,7 +744,7 @@ private void UpdateRoundTripTime(SctpDataChunk acknowledgedChunk) return; } - var rttMilliseconds = (DateTime.Now - acknowledgedChunk.LastSentAt).TotalMilliseconds; + var rttMilliseconds = SctpDataChunk.Timestamp.Now.Milliseconds - acknowledgedChunk.LastSentAt.Milliseconds; if (!_hasRoundTripTime) { @@ -707,9 +777,9 @@ private void UpdateRoundTripTime(SctpDataChunk acknowledgedChunk) /// See https://tools.ietf.org/html/rfc4960#section-6.2.1. /// /// The new value to use for the receiver window. - private uint CalculateReceiverWindow(uint advertisedReceiveWindow) + private uint CalculateReceiverWindow(uint advertisedReceiveWindow, uint outstandingBytes) { - return (advertisedReceiveWindow > _outstandingBytes) ? advertisedReceiveWindow - _outstandingBytes : 0; + return (advertisedReceiveWindow > outstandingBytes) ? advertisedReceiveWindow - outstandingBytes : 0; } /// @@ -717,20 +787,21 @@ private uint CalculateReceiverWindow(uint advertisedReceiveWindow) /// /// The size of last ACK'ed DATA chunk. /// A congestion window value. - private uint CalculateCongestionWindow(int lastAckDataChunkSize) + private uint CalculateCongestionWindow(int lastAckDataChunkSize, uint outstandingBytes) { - if (_congestionWindow < _slowStartThreshold) + if (_congestionWindow <= _slowStartThreshold) { // In Slow-Start mode, see RFC4960 7.2.1. - - if (_congestionWindow < _outstandingBytes) + // Updated to RFC9260 7.2.1 + if (_congestionWindow <= outstandingBytes && !_inFastRecoveryMode.IsOn()) { // When cwnd is less than or equal to ssthresh, an SCTP endpoint MUST // use the slow - start algorithm to increase cwnd only if the current // congestion window is being fully utilized. uint increasedCwnd = (uint)(_congestionWindow + Math.Min(lastAckDataChunkSize, _defaultMTU)); - logger.LogTrace($"SCTP sender congestion window in slow-start increased from {_congestionWindow} to {increasedCwnd}."); + logger.LogTrace("SCTP sender congestion window in slow-start increased from {Original} to {Increased}.", + _congestionWindow, increasedCwnd); return increasedCwnd; } @@ -743,9 +814,10 @@ private uint CalculateCongestionWindow(int lastAckDataChunkSize) { // In Congestion Avoidance mode, see RFC4960 7.2.2. - if (_congestionWindow < _outstandingBytes) + if (_congestionWindow <= outstandingBytes) { - logger.LogTrace($"SCTP sender congestion window in congestion avoidance increased from {_congestionWindow} to {_congestionWindow + _defaultMTU}."); + logger.LogTrace("SCTP sender congestion window in congestion avoidance increased from {Original} to {Increased}.", + _congestionWindow, _congestionWindow + _defaultMTU); return _congestionWindow + _defaultMTU; } diff --git a/src/net/SCTP/SctpHeader.cs b/src/net/SCTP/SctpHeader.cs index 569778d12..8c5611ae8 100644 --- a/src/net/SCTP/SctpHeader.cs +++ b/src/net/SCTP/SctpHeader.cs @@ -61,13 +61,25 @@ public void WriteToBuffer(byte[] buffer, int posn) NetConvert.ToBuffer(VerificationTag, buffer, posn + 4); } + /// + /// Serialises the header to a pre-allocated buffer. + /// + /// The buffer to write the SCTP header bytes to. It + /// must have the required space already allocated. + public readonly void WriteToBuffer(Span buffer) + { + NetConvert.ToBuffer(SourcePort, buffer, 0); + NetConvert.ToBuffer(DestinationPort, buffer, 2); + NetConvert.ToBuffer(VerificationTag, buffer, 4); + } + /// /// Parses the an SCTP header from a buffer. /// /// The buffer to parse the SCTP header from. /// The position in the buffer to start parsing the header from. /// A new SCTPHeaer instance. - public static SctpHeader Parse(byte[] buffer, int posn) + public static SctpHeader Parse(ReadOnlySpan buffer) { if (buffer.Length < SCTP_HEADER_LENGTH) { @@ -76,10 +88,10 @@ public static SctpHeader Parse(byte[] buffer, int posn) SctpHeader header = new SctpHeader(); - header.SourcePort = NetConvert.ParseUInt16(buffer, posn); - header.DestinationPort = NetConvert.ParseUInt16(buffer, posn + 2); - header.VerificationTag = NetConvert.ParseUInt32(buffer, posn + 4); - header.Checksum = NetConvert.ParseUInt32(buffer, posn + 8); + header.SourcePort = NetConvert.ParseUInt16(buffer); + header.DestinationPort = NetConvert.ParseUInt16(buffer.Slice(2)); + header.VerificationTag = NetConvert.ParseUInt32(buffer.Slice(4)); + header.Checksum = NetConvert.ParseUInt32(buffer.Slice(8)); return header; } diff --git a/src/net/SCTP/SctpPacket.cs b/src/net/SCTP/SctpPacket.cs index de6e675cf..865717abe 100644 --- a/src/net/SCTP/SctpPacket.cs +++ b/src/net/SCTP/SctpPacket.cs @@ -20,6 +20,11 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +#if NET6_0_OR_GREATER +using System.Runtime.Intrinsics.X86; +#endif using Microsoft.Extensions.Logging; using SIPSorcery.Sys; @@ -54,6 +59,30 @@ public static uint Calculate(byte[] buffer, int offset, int length) } return crc ^ 0xffffffff; } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static uint Calculate(ReadOnlySpan span) + { + uint crc = ~0u; +#if NET6_0_OR_GREATER + if (Sse42.X64.IsSupported) + { + int strides = span.Length / 8; + int stridedBytes = strides * 8; + ReadOnlySpan ulongs = MemoryMarshal.Cast(span[..stridedBytes]); + for (int i = 0; i < strides; i++) + { + crc = (uint)Sse42.X64.Crc32(crc, ulongs[i]); + } + span = span[stridedBytes..]; + } +#endif + for (int i = 0; i < span.Length; i++) + { + crc = _table[(crc ^ span[i]) & 0xff] ^ crc >> 8; + } + return crc ^ 0xffffffff; + } } /// @@ -91,10 +120,14 @@ public class SctpPacket /// A list of the blobs for chunks that weren't recognised when parsing /// a received packet. /// - public List UnrecognisedChunks; + public IReadOnlyList UnrecognisedChunks; - private SctpPacket() - { } + private SctpPacket(SctpHeader header, List chunks, IReadOnlyList unrecognizedChunks) + { + Header = header; + Chunks = chunks; + UnrecognisedChunks = unrecognizedChunks; + } /// /// Creates a new SCTP packet instance. @@ -115,7 +148,7 @@ public SctpPacket( }; Chunks = new List(); - UnrecognisedChunks = new List(); + UnrecognisedChunks = Array.Empty(); } /// @@ -142,6 +175,32 @@ public byte[] GetBytes() return buffer; } + public int GetBytes(Span buffer) + { + int chunksLength = Chunks.Sum(x => x.GetChunkLength(true)); + int totalSize = SctpHeader.SCTP_HEADER_LENGTH + chunksLength; + if (buffer.Length < totalSize) + { + return -totalSize; + } + + buffer = buffer.Slice(0, totalSize); + + Header.WriteToBuffer(buffer); + + int writePosn = SctpHeader.SCTP_HEADER_LENGTH; + foreach (var chunk in Chunks) + { + writePosn += chunk.WriteTo(buffer, writePosn); + } + + NetConvert.ToBuffer(0U, buffer, CHECKSUM_BUFFER_POSITION); + uint checksum = CRC32C.Calculate(buffer); + NetConvert.ToBuffer(NetConvert.EndianFlip(checksum), buffer, CHECKSUM_BUFFER_POSITION); + + return totalSize; + } + /// /// Adds a new chunk to send with an outgoing packet. /// @@ -157,28 +216,15 @@ public void AddChunk(SctpChunk chunk) /// The buffer holding the serialised packet. /// The position in the buffer of the packet. /// The length of the serialised packet in the buffer. - public static SctpPacket Parse(byte[] buffer, int offset, int length) - { - var pkt = new SctpPacket(); - pkt.Header = SctpHeader.Parse(buffer, offset); - (pkt.Chunks, pkt.UnrecognisedChunks) = ParseChunks(buffer, offset, length); - - return pkt; - } - - /// - /// Parses the chunks from a serialised SCTP packet. - /// - /// The buffer holding the serialised packet. - /// The position in the buffer of the packet. - /// The length of the serialised packet in the buffer. - /// The lsit of parsed chunks and a list of unrecognised chunks that were not de-serialised. - private static (List chunks, List unrecognisedChunks) ParseChunks(byte[] buffer, int offset, int length) + public static SctpPacket Parse(ReadOnlySpan buffer) { + var header = SctpHeader.Parse(buffer); List chunks = new List(); - List unrecognisedChunks = new List(); + // avoid allocation + IReadOnlyList unrecognisedChunks = Array.Empty(); - int posn = offset + SctpHeader.SCTP_HEADER_LENGTH; + int posn = SctpHeader.SCTP_HEADER_LENGTH; + int length = buffer.Length; bool stop = false; @@ -186,7 +232,7 @@ private static (List chunks, List unrecognisedChunks) ParseCh { byte chunkType = buffer[posn]; - if (Enum.IsDefined(typeof(SctpChunkType), chunkType)) + if (((SctpChunkType)chunkType).IsDefined()) { var chunk = SctpChunk.Parse(buffer, posn); chunks.Add(chunk); @@ -200,12 +246,14 @@ private static (List chunks, List unrecognisedChunks) ParseCh break; case SctpUnrecognisedChunkActions.StopAndReport: stop = true; - unrecognisedChunks.Add(SctpChunk.CopyUnrecognisedChunk(buffer, posn)); + unrecognisedChunks = unrecognisedChunks as List ?? []; + ((List)unrecognisedChunks).Add(SctpChunk.CopyUnrecognisedChunk(buffer, posn).ToArray()); break; case SctpUnrecognisedChunkActions.Skip: break; case SctpUnrecognisedChunkActions.SkipAndReport: - unrecognisedChunks.Add(SctpChunk.CopyUnrecognisedChunk(buffer, posn)); + unrecognisedChunks = unrecognisedChunks as List ?? []; + ((List)unrecognisedChunks).Add(SctpChunk.CopyUnrecognisedChunk(buffer, posn).ToArray()); break; } } @@ -219,7 +267,7 @@ private static (List chunks, List unrecognisedChunks) ParseCh posn += (int)SctpChunk.GetChunkLengthFromHeader(buffer, posn, true); } - return (chunks, unrecognisedChunks); + return new SctpPacket(header, chunks, unrecognisedChunks); } /// @@ -229,14 +277,16 @@ private static (List chunks, List unrecognisedChunks) ParseCh /// The start position in the buffer. /// The length of the packet in the buffer. /// True if the checksum was valid, false if not. - public static bool VerifyChecksum(byte[] buffer, int posn, int length) + public static bool VerifyChecksum(ReadOnlySpan bufferRO) { - uint origChecksum = NetConvert.ParseUInt32(buffer, posn + CHECKSUM_BUFFER_POSITION); - NetConvert.ToBuffer(0U, buffer, posn + CHECKSUM_BUFFER_POSITION); - uint calcChecksum = CRC32C.Calculate(buffer, posn, length); + uint origChecksum = NetConvert.ParseUInt32(bufferRO, CHECKSUM_BUFFER_POSITION); + Span buffer = stackalloc byte[bufferRO.Length]; + bufferRO.CopyTo(buffer); + NetConvert.ToBuffer(0U, buffer, CHECKSUM_BUFFER_POSITION); + uint calcChecksum = CRC32C.Calculate(buffer); // Put the original checksum back. - NetConvert.ToBuffer(origChecksum, buffer, posn + CHECKSUM_BUFFER_POSITION); + NetConvert.ToBuffer(origChecksum, buffer, CHECKSUM_BUFFER_POSITION); return origChecksum == NetConvert.EndianFlip(calcChecksum); } @@ -246,27 +296,27 @@ public static bool VerifyChecksum(byte[] buffer, int posn, int length) /// a pre-flight check to be carried out before de-serialising the whole buffer. /// /// The buffer holding the serialised packet. - /// The start position in the buffer. - /// The length of the packet in the buffer. /// The verification tag for the serialised SCTP packet. - public static uint GetVerificationTag(byte[] buffer, int posn, int length) + public static uint GetVerificationTag(ReadOnlySpan buffer) { - return NetConvert.ParseUInt32(buffer, posn + VERIFICATIONTAG_BUFFER_POSITION); + return NetConvert.ParseUInt32(buffer, VERIFICATIONTAG_BUFFER_POSITION); } /// /// Performs verification checks on a serialised SCTP packet. /// /// The buffer holding the serialised packet. - /// The start position in the buffer. - /// The length of the packet in the buffer. /// The required verification tag for the serialised /// packet. This should match the verification tag supplied by the remote party. /// True if the packet is valid, false if not. - public static bool IsValid(byte[] buffer, int posn, int length, uint requiredTag) + public static bool IsValid(ReadOnlySpan buffer, uint requiredTag) + { + return GetVerificationTag(buffer) == requiredTag && VerifyChecksum(buffer); + } + + public SctpChunk GetChunk(SctpChunkType chunkType) { - return GetVerificationTag(buffer, posn, length) == requiredTag && - VerifyChecksum(buffer, posn, length); + return Chunks.Single(x => x.ChunkType == (byte)chunkType); } } } diff --git a/src/net/SCTP/SctpPacketView.cs b/src/net/SCTP/SctpPacketView.cs new file mode 100644 index 000000000..0fd4b4b0e --- /dev/null +++ b/src/net/SCTP/SctpPacketView.cs @@ -0,0 +1,125 @@ +using System; +using System.Collections.Generic; + +using Microsoft.Extensions.Logging; + +using SIPSorcery.Sys; + +using Small.Collections; + +using TypeNum; + +namespace SIPSorcery.Net; + +public readonly ref struct SctpPacketView +{ + static readonly ILogger logger = LogFactory.CreateLogger(); + readonly ReadOnlySpan buffer; + readonly SmallList, Chunk> chunks; + readonly SmallList, Chunk> unrecognized; + + public readonly SctpHeader Header => SctpHeader.Parse(buffer); + public int ChunkCount => chunks.Count; + public SctpChunkView this[int index] => chunks[index].View(buffer); + public SctpChunkView GetChunk(SctpChunkType type) + { + bool found = false; + SctpChunkView result = default; + for (int i = 0; i < chunks.Count; i++) + { + var chunk = chunks[i]; + var view = chunk.View(buffer); + if (view.Type == type) + { + if (found) + { + throw new InvalidOperationException($"Multiple {type} chunks found."); + } + + result = view; + found = true; + } + } + return found ? result : throw new KeyNotFoundException(); + } + public bool Has(SctpChunkType type) + { + for (int i = 0; i < chunks.Count; i++) + { + var chunk = chunks[i]; + var view = chunk.View(buffer); + if (view.Type == type) + { + return true; + } + } + return false; + } + public int UnrecognizedChunkCount => unrecognized.Count; + + public static SctpPacketView Parse(ReadOnlySpan buffer) + { + var chunks = new SmallList, Chunk>(); + var unrecognized = new SmallList, Chunk>(); + int posn = SctpHeader.SCTP_HEADER_LENGTH; + + bool stop = false; + + while (posn < buffer.Length) + { + byte chunkType = buffer[posn]; + int chunkLength = (int)SctpChunk.GetChunkLengthFromHeader(buffer, posn, true); + var chunk = new Chunk() { Offset = posn, Length = chunkLength }; + + if (((SctpChunkType)chunkType).IsDefined()) + { + chunk.View(buffer); + chunks.Add(chunk); + } + else + { + switch (SctpChunk.GetUnrecognisedChunkAction(chunkType)) + { + case SctpUnrecognisedChunkActions.Stop: + stop = true; + break; + case SctpUnrecognisedChunkActions.StopAndReport: + stop = true; + unrecognized.Add(chunk); + break; + case SctpUnrecognisedChunkActions.Skip: + break; + case SctpUnrecognisedChunkActions.SkipAndReport: + unrecognized.Add(chunk); + break; + } + } + + if (stop) + { + logger.LogWarning("SCTP unrecognised chunk type {Type} indicated no further chunks should be processed.", chunkType); + break; + } + + posn += chunkLength; + } + return new(buffer, chunks, unrecognized); + } + + public SctpPacket AsPacket() => SctpPacket.Parse(buffer); + + SctpPacketView(ReadOnlySpan buffer, SmallList, Chunk> chunks, SmallList, Chunk> unrecognized) + { + this.buffer = buffer; + this.chunks = chunks; + this.unrecognized = unrecognized; + } + + struct Chunk + { + public int Offset { get; set; } + public int Length { get; set; } + public SctpChunkView View(ReadOnlySpan buffer) + => new(buffer.Slice(Offset, Length)); + } +} diff --git a/src/net/SCTP/SctpTransport.cs b/src/net/SCTP/SctpTransport.cs index dd166c11f..17ddc8156 100644 --- a/src/net/SCTP/SctpTransport.cs +++ b/src/net/SCTP/SctpTransport.cs @@ -93,19 +93,19 @@ public abstract class SctpTransport /// public virtual bool IsPortAgnostic => false; - public abstract void Send(string associationID, byte[] buffer, int offset, int length); + public abstract void Send(string associationID, ReadOnlySpan buffer); static SctpTransport() { Crypto.GetRandomBytes(_hmacKey); } - protected void GotInit(SctpPacket initPacket, IPEndPoint remoteEndPoint) + protected void GotInit(SctpPacketView initPacket, IPEndPoint remoteEndPoint) { // INIT packets have specific processing rules in order to prevent resource exhaustion. // See Section 5 of RFC 4960 https://tools.ietf.org/html/rfc4960#section-5 "Association Initialization". - SctpInitChunk initChunk = initPacket.Chunks.Single(x => x.KnownType == SctpChunkType.INIT) as SctpInitChunk; + var initChunk = initPacket.GetChunk(SctpChunkType.INIT); if (initChunk.InitiateTag == 0 || initChunk.NumberInboundStreams == 0 || @@ -130,9 +130,27 @@ protected void GotInit(SctpPacket initPacket, IPEndPoint remoteEndPoint) } else { - var initAckPacket = GetInitAck(initPacket, remoteEndPoint); - var buffer = initAckPacket.GetBytes(); - Send(null, buffer, 0, buffer.Length); + var initAckPacket = GetInitAck(initPacket.AsPacket(), remoteEndPoint); + Send(null, initAckPacket); + } + } + + /// + /// Sends an SCTP packet to the remote peer. + /// + /// The packet to send. + internal void Send(string? ID, SctpPacket pkt) + { + Span span = stackalloc byte[4 * 1024]; + if (pkt.GetBytes(span) is { } size and >= 0) + { + Send(ID, span.Slice(0, size)); + } + else + { + System.Diagnostics.Debug.WriteLine("SCTP packet too large to send without allocation."); + byte[] buffer = pkt.GetBytes(); + Send(ID, buffer.AsSpan()); } } @@ -182,7 +200,7 @@ protected virtual SctpTransportCookie GetInitAckCookie( /// An SCTP packet with a single INIT ACK chunk. protected SctpPacket GetInitAck(SctpPacket initPacket, IPEndPoint remoteEP) { - SctpInitChunk initChunk = initPacket.Chunks.Single(x => x.KnownType == SctpChunkType.INIT) as SctpInitChunk; + var initChunk = (SctpInitChunk)initPacket.GetChunk(SctpChunkType.INIT); SctpPacket initAckPacket = new SctpPacket( initPacket.Header.DestinationPort, @@ -233,11 +251,11 @@ protected SctpPacket GetInitAck(SctpPacket initPacket, IPEndPoint remoteEP) /// The packet containing the COOKIE ECHO chunk received from the remote party. /// If the state cookie in the chunk is valid a new SCTP association will be returned. IF /// it's not valid an empty cookie will be returned and an error response gets sent to the peer. - protected SctpTransportCookie GetCookie(SctpPacket sctpPacket) + protected SctpTransportCookie GetCookie(SctpPacketView sctpPacket) { - var cookieEcho = sctpPacket.Chunks.Single(x => x.KnownType == SctpChunkType.COOKIE_ECHO); - var cookieBuffer = cookieEcho.ChunkValue; - var cookie = JSONParser.FromJson(Encoding.UTF8.GetString(cookieBuffer)); + var cookieEcho = sctpPacket.GetChunk(SctpChunkType.COOKIE_ECHO); + var cookieBuffer = cookieEcho.Value; + var cookie = JSONParser.FromJson(cookieBuffer.ToString(Encoding.UTF8)); logger.LogDebug($"Cookie: {cookie.ToJson()}"); @@ -277,9 +295,9 @@ protected SctpTransportCookie GetCookie(SctpPacket sctpPacket) /// /// The buffer holding the state cookie. /// True if the cookie is determined as valid, false if not. - protected string GetCookieHMAC(byte[] buffer) + protected string GetCookieHMAC(ReadOnlySpan buffer) { - var cookie = JSONParser.FromJson(Encoding.UTF8.GetString(buffer)); + var cookie = JSONParser.FromJson(buffer.ToString(Encoding.UTF8)); string hmacCalculated = null; cookie.HMAC = string.Empty; @@ -318,8 +336,7 @@ private void SendError( errorChunk.AddErrorCause(error); errorPacket.AddChunk(errorChunk); - var buffer = errorPacket.GetBytes(); - Send(null, buffer, 0, buffer.Length); + Send(null, errorPacket); } /// diff --git a/src/net/SCTP/SctpUdpTransport.cs b/src/net/SCTP/SctpUdpTransport.cs index 9daaa19fa..7b0594790 100644 --- a/src/net/SCTP/SctpUdpTransport.cs +++ b/src/net/SCTP/SctpUdpTransport.cs @@ -18,6 +18,7 @@ //----------------------------------------------------------------------------- using System; +using System.Buffers; using System.Collections.Concurrent; using System.Linq; using System.Net; @@ -65,24 +66,24 @@ public SctpUdpTransport(int udpEncapPort = 0, PortRange portRange = null) /// The local port the packet was received on. /// The remote end point the packet was received from. /// A buffer containing the packet. - private void OnEncapsulationSocketPacketReceived(UdpReceiver receiver, int localPort, IPEndPoint remoteEndPoint, byte[] packet) + private void OnEncapsulationSocketPacketReceived(UdpReceiver receiver, int localPort, IPEndPoint remoteEndPoint, ReadOnlySpan packet) { try { - if (!SctpPacket.VerifyChecksum(packet, 0, packet.Length)) + if (!SctpPacket.VerifyChecksum(packet)) { logger.LogWarning($"SCTP packet from UDP {remoteEndPoint} dropped due to invalid checksum."); } else { - var sctpPacket = SctpPacket.Parse(packet, 0, packet.Length); + var sctpPacket = SctpPacketView.Parse(packet); // Process packet. if (sctpPacket.Header.VerificationTag == 0) { GotInit(sctpPacket, remoteEndPoint); } - else if (sctpPacket.Chunks.Any(x => x.KnownType == SctpChunkType.COOKIE_ECHO)) + else if (sctpPacket.Has(SctpChunkType.COOKIE_ECHO)) { // The COOKIE ECHO chunk is the 3rd step in the SCTP handshake when the remote party has // requested a new association be created. @@ -100,7 +101,7 @@ private void OnEncapsulationSocketPacketReceived(UdpReceiver receiver, int local if (_associations.TryAdd(association.ID, association)) { - if (sctpPacket.Chunks.Count > 1) + if (sctpPacket.ChunkCount > 1) { // There could be DATA chunks after the COOKIE ECHO chunk. association.OnPacketReceived(sctpPacket); @@ -134,11 +135,20 @@ private void OnEncapsulationSocketClosed(string reason) logger.LogInformation($"SCTP transport encapsulation receiver closed with reason: {reason}."); } - public override void Send(string associationID, byte[] buffer, int offset, int length) + public override void Send(string associationID, ReadOnlySpan span) { if (_associations.TryGetValue(associationID, out var assoc)) { - _udpEncapSocket.SendTo(buffer, offset, length, SocketFlags.None, assoc.Destination); + byte[] buffer = ArrayPool.Shared.Rent(span.Length); + span.CopyTo(buffer); + try + { + _udpEncapSocket.SendTo(buffer, 0, span.Length, SocketFlags.None, assoc.Destination); + } + finally + { + ArrayPool.Shared.Return(buffer); + } } } diff --git a/src/net/STUN/STUNAttributes/STUNAttribute.cs b/src/net/STUN/STUNAttributes/STUNAttribute.cs index 6fc997a65..b57143be0 100644 --- a/src/net/STUN/STUNAttributes/STUNAttribute.cs +++ b/src/net/STUN/STUNAttributes/STUNAttribute.cs @@ -157,17 +157,21 @@ public STUNAttribute(STUNAttributeTypesEnum attributeType, ulong value) Value = NetConvert.GetBytes(value); } - public static List ParseMessageAttributes(byte[] buffer, int startIndex, int endIndex) => ParseMessageAttributes(buffer, startIndex, endIndex, null); - - public static List ParseMessageAttributes(byte[] buffer, int startIndex, int endIndex, STUNHeader header) + public static bool TryParseMessageAttributes(List attributes, ReadOnlySpan buffer, int startIndex, int endIndex, STUNHeader? header) { - if (buffer != null && buffer.Length > startIndex && buffer.Length >= endIndex) + if (buffer.Length > startIndex && buffer.Length >= endIndex) { - List attributes = new List(); int startAttIndex = startIndex; while (startAttIndex < endIndex) { + int remainingBytes = endIndex - startAttIndex; + if (remainingBytes < 4) + { + logger.LogWarning("The remaining number of bytes in the STUN message was less than the minimum attribute length 4. Remaining bytes: {RemainingBytes}.", remainingBytes); + break; + } + UInt16 stunAttributeType = NetConvert.ParseUInt16(buffer, startAttIndex); UInt16 stunAttributeLength = NetConvert.ParseUInt16(buffer, startAttIndex + 2); byte[] stunAttributeValue = null; @@ -183,7 +187,7 @@ public static List ParseMessageAttributes(byte[] buffer, int star else { stunAttributeValue = new byte[stunAttributeLength]; - Buffer.BlockCopy(buffer, startAttIndex + 4, stunAttributeValue, 0, stunAttributeLength); + buffer.Slice(startAttIndex + 4, stunAttributeLength).CopyTo(stunAttributeValue); } } @@ -225,11 +229,12 @@ public static List ParseMessageAttributes(byte[] buffer, int star startAttIndex = startAttIndex + 4 + stunAttributeLength + padding; } - return attributes; + return true; } else { - return null; + logger.LogWarning("Bad STUN attribute parse request. Start: {Start}; End: {End}; Length: {Length}.", startIndex, endIndex, buffer.Length); + return false; } } diff --git a/src/net/STUN/STUNHeader.cs b/src/net/STUN/STUNHeader.cs index 3633b532a..9ea66ba1a 100644 --- a/src/net/STUN/STUNHeader.cs +++ b/src/net/STUN/STUNHeader.cs @@ -72,6 +72,7 @@ //----------------------------------------------------------------------------- using System; +using System.Buffers.Binary; using System.Text; using SIPSorcery.Sys; @@ -164,30 +165,24 @@ public static STUNHeader ParseSTUNHeader(byte[] buffer) return ParseSTUNHeader(new ArraySegment(buffer, 0, buffer.Length)); } - public static STUNHeader ParseSTUNHeader(ArraySegment bufferSegment) + public static STUNHeader ParseSTUNHeader(ReadOnlySpan bufferSegment) { - var startIndex = bufferSegment.Offset; - if ((bufferSegment.Array[startIndex] & STUN_INITIAL_BYTE_MASK) != 0) + var startIndex = 0; + if ((bufferSegment[startIndex] & STUN_INITIAL_BYTE_MASK) != 0) { throw new ApplicationException("The STUN header did not begin with 0x00."); } - if (bufferSegment != null && bufferSegment.Count > 0 && bufferSegment.Count >= STUN_HEADER_LENGTH) + if (bufferSegment != null && bufferSegment.Length > 0 && bufferSegment.Length >= STUN_HEADER_LENGTH) { STUNHeader stunHeader = new STUNHeader(); - UInt16 stunTypeValue = BitConverter.ToUInt16(bufferSegment.Array, startIndex); - UInt16 stunMessageLength = BitConverter.ToUInt16(bufferSegment.Array, startIndex + 2);; - - if (BitConverter.IsLittleEndian) - { - stunTypeValue = NetConvert.DoReverseEndian(stunTypeValue); - stunMessageLength = NetConvert.DoReverseEndian(stunMessageLength); - } + UInt16 stunTypeValue = BinaryPrimitives.ReadUInt16BigEndian(bufferSegment.Slice(startIndex)); + UInt16 stunMessageLength = BinaryPrimitives.ReadUInt16BigEndian(bufferSegment.Slice(startIndex + 2)); stunHeader.MessageType = STUNMessageTypes.GetSTUNMessageTypeForId(stunTypeValue); stunHeader.MessageLength = stunMessageLength; - Buffer.BlockCopy(bufferSegment.Array, startIndex + 8, stunHeader.TransactionId, 0, TRANSACTION_ID_LENGTH); + bufferSegment.Slice(startIndex + 8, TRANSACTION_ID_LENGTH).CopyTo(stunHeader.TransactionId); return stunHeader; } diff --git a/src/net/STUN/STUNMessage.cs b/src/net/STUN/STUNMessage.cs index e6ed1a02f..d5c1279aa 100644 --- a/src/net/STUN/STUNMessage.cs +++ b/src/net/STUN/STUNMessage.cs @@ -90,17 +90,17 @@ public void AddXORAddressAttribute(STUNAttributeTypesEnum addressType, IPAddress Attributes.Add(xorAddressAttribute); } - public static STUNMessage ParseSTUNMessage(byte[] buffer, int bufferLength) + public static STUNMessage ParseSTUNMessage(ReadOnlySpan buffer, int bufferLength) { if (buffer != null && buffer.Length > 0 && buffer.Length >= bufferLength) { STUNMessage stunMessage = new STUNMessage(); - stunMessage._receivedBuffer = buffer.Take(bufferLength).ToArray(); + stunMessage._receivedBuffer = buffer.Slice(0, bufferLength).ToArray(); stunMessage.Header = STUNHeader.ParseSTUNHeader(buffer); if (stunMessage.Header.MessageLength > 0) { - stunMessage.Attributes = STUNAttribute.ParseMessageAttributes(buffer, STUNHeader.STUN_HEADER_LENGTH, bufferLength, stunMessage.Header); + STUNAttribute.TryParseMessageAttributes(stunMessage.Attributes, buffer, STUNHeader.STUN_HEADER_LENGTH, bufferLength, stunMessage.Header); } if (stunMessage.Attributes.Count > 0 && stunMessage.Attributes.Last().AttributeType == STUNAttributeTypesEnum.FingerPrint) @@ -108,7 +108,7 @@ public static STUNMessage ParseSTUNMessage(byte[] buffer, int bufferLength) // Check fingerprint. var fingerprintAttribute = stunMessage.Attributes.Last(); - var input = buffer.Take(buffer.Length - STUNAttribute.STUNATTRIBUTE_HEADER_LENGTH - FINGERPRINT_ATTRIBUTE_CRC32_LENGTH).ToArray(); + var input = buffer.Slice(0, buffer.Length - STUNAttribute.STUNATTRIBUTE_HEADER_LENGTH - FINGERPRINT_ATTRIBUTE_CRC32_LENGTH); uint crc = Crc32.Compute(input) ^ FINGERPRINT_XOR; byte[] fingerPrint = (BitConverter.IsLittleEndian) ? BitConverter.GetBytes(NetConvert.DoReverseEndian(crc)) : BitConverter.GetBytes(crc); diff --git a/src/net/WebRTC/DCEP.cs b/src/net/WebRTC/DCEP.cs index 952cc093c..593d8a63e 100644 --- a/src/net/WebRTC/DCEP.cs +++ b/src/net/WebRTC/DCEP.cs @@ -137,7 +137,7 @@ public struct DataChannelOpenMessage /// The buffer to parse the message from. /// The position in the buffer to start parsing from. /// A new DCEP open message instance. - public static DataChannelOpenMessage Parse(byte[] buffer, int posn) + public static DataChannelOpenMessage Parse(ReadOnlySpan buffer, int posn) { if (buffer.Length < DCEP_OPEN_FIXED_PARAMETERS_LENGTH) { @@ -156,12 +156,12 @@ public static DataChannelOpenMessage Parse(byte[] buffer, int posn) if (labelLength > 0) { - dcepOpen.Label = Encoding.UTF8.GetString(buffer, 12, labelLength); + dcepOpen.Label = buffer.Slice(12, labelLength).ToString(Encoding.UTF8); } if (protocolLength > 0) { - dcepOpen.Protocol = Encoding.UTF8.GetString(buffer, 12 + labelLength, protocolLength); + dcepOpen.Protocol = buffer.Slice(12 + labelLength, protocolLength).ToString(Encoding.UTF8); } return dcepOpen; diff --git a/src/net/WebRTC/IRTCDataChannel.cs b/src/net/WebRTC/IRTCDataChannel.cs index b2c2e3618..d26526add 100644 --- a/src/net/WebRTC/IRTCDataChannel.cs +++ b/src/net/WebRTC/IRTCDataChannel.cs @@ -36,7 +36,7 @@ namespace SIPSorcery.Net { - public delegate void OnDataChannelMessageDelegate(RTCDataChannel dc, DataChannelPayloadProtocols protocol, byte[] data); + public delegate void OnDataChannelMessageDelegate(RTCDataChannel dc, DataChannelPayloadProtocols protocol, ReadOnlySpan data); public enum RTCDataChannelState { @@ -161,7 +161,7 @@ interface IRTCDataChannel string binaryType { get; set; } void send(string data); - void send(byte[] data); + void send(ReadOnlySpan data); }; public class RTCDataChannelInit diff --git a/src/net/WebRTC/IRTCPeerConnection.cs b/src/net/WebRTC/IRTCPeerConnection.cs index e65c0a52c..43649d0e1 100644 --- a/src/net/WebRTC/IRTCPeerConnection.cs +++ b/src/net/WebRTC/IRTCPeerConnection.cs @@ -24,6 +24,8 @@ using System.Net; using System.Security.Cryptography.X509Certificates; using System.Threading.Tasks; +using Org.BouncyCastle.Tls; +using Org.BouncyCastle.Tls.Crypto; using SIPSorcery.Sys; namespace SIPSorcery.Net @@ -232,7 +234,7 @@ public long expires public List getFingerprints() { - return new List { DtlsUtils.Fingerprint(Certificate) }; + return new List { DtlsUtils.Fingerprint(Org.BouncyCastle.Security.DotNetUtilities.FromX509Certificate(Certificate)) }; } } @@ -288,10 +290,7 @@ public class RTCConfiguration public RTCIceTransportPolicy iceTransportPolicy; public RTCBundlePolicy bundlePolicy; public RTCRtcpMuxPolicy rtcpMuxPolicy; -#pragma warning disable CS0618 // Type or member is obsolete - public List certificates; -#pragma warning restore CS0618 // Type or member is obsolete - public List certificates2; + public List certificates2; /// /// The Bouncy Castle DTLS logic enforces the use of Extended Master diff --git a/src/net/WebRTC/RTCDataChannel.cs b/src/net/WebRTC/RTCDataChannel.cs index e9e253287..2d5c17396 100644 --- a/src/net/WebRTC/RTCDataChannel.cs +++ b/src/net/WebRTC/RTCDataChannel.cs @@ -130,6 +130,7 @@ public void close() /// Sends a string data payload on the data channel. /// /// The string message to send. + /// SCTP transport is not connected. public void send(string message) { if (message != null && Encoding.UTF8.GetByteCount(message) > _transport.maxMessageSize) @@ -139,7 +140,7 @@ public void send(string message) } else if (_transport.state != RTCSctpTransportState.Connected) { - logger.LogWarning($"WebRTC data channel send failed due to SCTP transport in state {_transport.state}."); + throw new InvalidOperationException("SCTP transport is not connected."); } else { @@ -165,7 +166,8 @@ public void send(string message) /// Sends a binary data payload on the data channel. /// /// The data to send. - public void send(byte[] data) + /// SCTP transport is not connected. + public void send(ReadOnlySpan data) { if (data.Length > _transport.maxMessageSize) { @@ -174,13 +176,13 @@ public void send(byte[] data) } else if (_transport.state != RTCSctpTransportState.Connected) { - logger.LogWarning($"WebRTC data channel send failed due to SCTP transport in state {_transport.state}."); + throw new InvalidOperationException("SCTP transport is not connected."); } else { lock (this) { - if (data?.Length == 0) + if (data.Length == 0) { _transport.RTCSctpAssociation.SendData(id.GetValueOrDefault(), (uint)DataChannelPayloadProtocols.WebRTC_Binary_Empty, @@ -249,7 +251,7 @@ internal void SendDcepAck() /// /// Event handler for an SCTP data chunk being received for this data channel. /// - internal void GotData(ushort streamID, ushort streamSeqNum, uint ppID, byte[] data) + internal void GotData(ushort streamID, ushort streamSeqNum, uint ppID, ReadOnlySpan data) { //logger.LogTrace($"WebRTC data channel GotData stream ID {streamID}, stream seqnum {streamSeqNum}, ppid {ppID}, label {label}."); diff --git a/src/net/WebRTC/RTCPeerConnection.cs b/src/net/WebRTC/RTCPeerConnection.cs index 2ebff18d2..a2f064bf9 100644 --- a/src/net/WebRTC/RTCPeerConnection.cs +++ b/src/net/WebRTC/RTCPeerConnection.cs @@ -42,9 +42,10 @@ using System.Threading.Tasks; using Microsoft.Extensions.Logging; using SIPSorcery.net.RTP; -using Org.BouncyCastle.Crypto.Tls; using SIPSorcery.SIP.App; using SIPSorcery.Sys; +using Org.BouncyCastle.Tls; +using Org.BouncyCastle.Tls.Crypto.Impl.BC; namespace SIPSorcery.Net { @@ -181,8 +182,9 @@ public class RTCPeerConnection : RTPSession, IRTCPeerConnection readonly RTCDataChannelCollection dataChannels; public IReadOnlyCollection DataChannels => dataChannels; - private Org.BouncyCastle.Crypto.Tls.Certificate _dtlsCertificate; + private Org.BouncyCastle.Tls.Certificate _dtlsCertificate; private Org.BouncyCastle.Crypto.AsymmetricKeyParameter _dtlsPrivateKey; + private BcTlsCrypto _crypto; private DtlsSrtpTransport _dtlsHandle; private Task _iceGatheringTask; @@ -385,6 +387,7 @@ public RTCPeerConnection() : public RTCPeerConnection(RTCConfiguration configuration, int bindPort = 0, PortRange portRange = null, Boolean videoAsPrimary = false) : base(true, true, true, configuration?.X_BindAddress, bindPort, portRange) { + _crypto = new BcTlsCrypto(); dataChannels = new RTCDataChannelCollection(useEvenIds: () => _dtlsHandle.IsClient); if (_configuration != null && @@ -398,7 +401,7 @@ public RTCPeerConnection(RTCConfiguration configuration, int bindPort = 0, PortR { _configuration = configuration; - if (!InitializeCertificates(configuration) && !InitializeCertificates2(configuration)) + if (!InitializeCertificates2(configuration)) { logger.LogWarning("No DTLS certificate is provided in the configuration"); } @@ -416,7 +419,7 @@ public RTCPeerConnection(RTCConfiguration configuration, int bindPort = 0, PortR if (_dtlsCertificate == null) { // No certificate was provided so create a new self signed one. - (_dtlsCertificate, _dtlsPrivateKey) = DtlsUtils.CreateSelfSignedTlsCert(); + (_dtlsCertificate, _dtlsPrivateKey) = DtlsUtils.CreateSelfSignedTlsCert(_crypto); } DtlsCertificateFingerprint = DtlsUtils.Fingerprint(_dtlsCertificate); @@ -452,55 +455,6 @@ public RTCPeerConnection(RTCConfiguration configuration, int bindPort = 0, PortR _iceGatheringTask = Task.Run(_rtpIceChannel.StartGathering); } - private bool InitializeCertificates(RTCConfiguration configuration) - { - if (configuration.certificates == null || configuration.certificates.Count == 0) - { - return false; - } - - // Find the first certificate that has a usable private key. -#pragma warning disable CS0618 // Type or member is obsolete - RTCCertificate usableCert = null; -#pragma warning restore CS0618 // Type or member is obsolete - foreach (var cert in _configuration.certificates) - { - // Attempting to check that a certificate has an exportable private key. - // TODO: Does not seem to be a particularly reliable way of checking private key exportability. - if (cert.Certificate.HasPrivateKey) - { - //if (cert.Certificate.PrivateKey is RSACryptoServiceProvider) - //{ - // var rsa = cert.Certificate.PrivateKey as RSACryptoServiceProvider; - // if (!rsa.CspKeyContainerInfo.Exportable) - // { - // logger.LogWarning($"RTCPeerConnection was passed a certificate for {cert.Certificate.FriendlyName} with a non-exportable RSA private key."); - // } - // else - // { - // usableCert = cert; - // break; - // } - //} - //else - //{ - usableCert = cert; - break; - //} - } - } - - if (usableCert == null) - { - throw new ApplicationException("RTCPeerConnection was not able to find a certificate from the input configuration list with a usable private key."); - } - - _dtlsCertificate = DtlsUtils.LoadCertificateChain(usableCert.Certificate); - _dtlsPrivateKey = DtlsUtils.LoadPrivateKeyResource(usableCert.Certificate); - - return true; - } - private bool InitializeCertificates2(RTCConfiguration configuration) { if (configuration.certificates2 == null || configuration.certificates2.Count == 0) @@ -508,7 +462,7 @@ private bool InitializeCertificates2(RTCConfiguration configuration) return false; } - _dtlsCertificate = new Certificate(new [] { configuration.certificates2[0].Certificate.CertificateStructure }); + _dtlsCertificate = new Certificate(new [] { new BcTlsCertificate(_crypto, configuration.certificates2[0].Certificate.CertificateStructure) }); _dtlsPrivateKey = configuration.certificates2[0].PrivateKey; return true; @@ -555,11 +509,14 @@ private async void IceConnectionStateChange(RTCIceConnectionState iceState) logger.LogInformation($"ICE connected to remote end point {connectedEP}."); bool disableDtlsExtendedMasterSecret = _configuration != null && _configuration.X_DisableExtendedMasterSecretKey; + + + _dtlsHandle = new DtlsSrtpTransport( IceRole == IceRolesEnum.active ? - new DtlsSrtpClient(_dtlsCertificate, _dtlsPrivateKey) + new DtlsSrtpClient(_crypto, _dtlsCertificate, _dtlsPrivateKey) { ForceUseExtendedMasterSecret = !disableDtlsExtendedMasterSecret } : - (IDtlsSrtpPeer)new DtlsSrtpServer(_dtlsCertificate, _dtlsPrivateKey) + (IDtlsSrtpPeer)new DtlsSrtpServer(_crypto, _dtlsCertificate, _dtlsPrivateKey) { ForceUseExtendedMasterSecret = !disableDtlsExtendedMasterSecret } ); @@ -1372,7 +1329,7 @@ void AddIceCandidates(SDPMediaAnnouncement announcement) /// The local port on the RTP socket that received the packet. /// The remote end point the packet was received from. /// The data received. - private void OnRTPDataReceived(int localPort, IPEndPoint remoteEP, byte[] buffer) + private void OnRTPDataReceived(int localPort, IPEndPoint remoteEP, ReadOnlySpan buffer) { //logger.LogDebug($"RTP channel received a packet from {remoteEP}, {buffer?.Length} bytes."); @@ -1381,11 +1338,11 @@ private void OnRTPDataReceived(int localPort, IPEndPoint remoteEP, byte[] buffer // Because DTLS packets can be fragmented and RTP/RTCP should never be use the RTP/RTCP // prefix to distinguish. - if (buffer?.Length > 0) + if (buffer.Length > 0) { try { - if (buffer?.Length > RTPHeader.MIN_HEADER_LEN && buffer[0] >= 128 && buffer[0] <= 191) + if (buffer.Length > RTPHeader.MIN_HEADER_LEN && buffer[0] >= 128 && buffer[0] <= 191) { // RTP/RTCP packet. base.OnReceive(localPort, remoteEP, buffer); @@ -1656,7 +1613,7 @@ private void OnSctpAssociationDataChunk(SctpDataFrame frame) /// When a data channel is requested an SCTP association is needed. This method attempts to /// initialise the association if it is not already available. /// - private async Task InitialiseSctpAssociation() + private async Task InitialiseSctpAssociation(CancellationToken cancel = default) { if (sctp.RTCSctpAssociation.State != SctpAssociationState.Established) { @@ -1678,7 +1635,7 @@ private async Task InitialiseSctpAssociation() DateTime startTime = DateTime.Now; - var completedTask = await Task.WhenAny(onSctpConnectedTcs.Task, Task.Delay(SCTP_ASSOCIATE_TIMEOUT_SECONDS * 1000)).ConfigureAwait(false); + var completedTask = await Task.WhenAny(onSctpConnectedTcs.Task, Task.Delay(SCTP_ASSOCIATE_TIMEOUT_SECONDS * 1000, cancel)).ConfigureAwait(false); if (sctp.state != RTCSctpTransportState.Connected) { @@ -1705,7 +1662,7 @@ private async Task InitialiseSctpAssociation() /// /// The label used to identify the data channel. /// The data channel created. - public async Task createDataChannel(string label, RTCDataChannelInit init = null) + public async Task createDataChannel(string label, RTCDataChannelInit init = null, CancellationToken cancel = default) { logger.LogDebug($"Data channel create request for label {label}."); @@ -1729,7 +1686,7 @@ public async Task createDataChannel(string label, RTCDataChannel if (sctp.RTCSctpAssociation == null || sctp.RTCSctpAssociation.State != SctpAssociationState.Established) { - await InitialiseSctpAssociation().ConfigureAwait(false); + await InitialiseSctpAssociation(cancel).ConfigureAwait(false); } dataChannels.AddActiveChannel(channel); @@ -1739,6 +1696,7 @@ public async Task createDataChannel(string label, RTCDataChannel TaskCompletionSource isopen = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); channel.onopen += () => isopen.TrySetResult(string.Empty); channel.onerror += (err) => isopen.TrySetResult(err); + using var _ = cancel.Register(() => isopen.TrySetResult("cancelled")); var error = await isopen.Task.ConfigureAwait(false); if (error != string.Empty) diff --git a/src/net/WebRTC/RTCSctpTransport.cs b/src/net/WebRTC/RTCSctpTransport.cs index c918e3e00..d5603e580 100644 --- a/src/net/WebRTC/RTCSctpTransport.cs +++ b/src/net/WebRTC/RTCSctpTransport.cs @@ -16,11 +16,12 @@ //----------------------------------------------------------------------------- using System; +using System.Buffers; using System.Linq; using System.Net.Sockets; using System.Threading; using Microsoft.Extensions.Logging; -using Org.BouncyCastle.Crypto.Tls; +using Org.BouncyCastle.Tls; using SIPSorcery.Sys; namespace SIPSorcery.Net @@ -106,8 +107,8 @@ public class RTCSctpTransport : SctpTransport /// public event Action OnStateChanged; - private bool _isStarted; - private bool _isClosed; + private Once _isStarted; + private Once _isClosed; private Thread _receiveThread; /// @@ -162,10 +163,8 @@ public void UpdateDestinationPort(ushort port) /// public void Start(DatagramTransport dtlsTransport, bool isDtlsClient) { - if (!_isStarted) + if (_isStarted.TryMarkOccurred()) { - _isStarted = true; - transport = dtlsTransport; IsDtlsClient = isDtlsClient; @@ -174,6 +173,10 @@ public void Start(DatagramTransport dtlsTransport, bool isDtlsClient) _receiveThread.IsBackground = true; _receiveThread.Start(); } + else + { + logger.LogWarning($"RTCSctpTransport for association {RTCSctpAssociation.ID} has already been started."); + } } /// @@ -194,7 +197,7 @@ public void Close() { RTCSctpAssociation?.Shutdown(); } - _isClosed = true; + _isClosed.TryMarkOccurred(); } /// @@ -264,13 +267,22 @@ protected override SctpTransportCookie GetInitAckCookie( /// private void DoReceive(object state) { - byte[] recvBuffer = new byte[SctpAssociation.DEFAULT_ADVERTISED_RECEIVE_WINDOW]; - - while (!_isClosed) +#if NET6_0_OR_GREATER + Span recvBuffer = stackalloc byte[checked((int)SctpAssociation.DEFAULT_ADVERTISED_RECEIVE_WINDOW)]; +#else + byte[] recvBufferArray = new byte[SctpAssociation.DEFAULT_ADVERTISED_RECEIVE_WINDOW]; + Span recvBuffer = recvBufferArray.AsSpan(); +#endif + + while (!_isClosed.HasOccurred) { try { - int bytesRead = transport.Receive(recvBuffer, 0, recvBuffer.Length, RECEIVE_TIMEOUT_MILLISECONDS); +#if NET6_0_OR_GREATER + int bytesRead = transport.Receive(recvBuffer, RECEIVE_TIMEOUT_MILLISECONDS); +#else + int bytesRead = transport.Receive(recvBufferArray, 0, recvBuffer.Length, RECEIVE_TIMEOUT_MILLISECONDS); +#endif if (bytesRead == DtlsSrtpTransport.DTLS_RETRANSMISSION_CODE) { @@ -280,22 +292,22 @@ private void DoReceive(object state) } else if (bytesRead > 0) { - if (!SctpPacket.VerifyChecksum(recvBuffer, 0, bytesRead)) + if (!SctpPacket.VerifyChecksum(recvBuffer.Slice(0, bytesRead))) { logger.LogWarning($"SCTP packet received on DTLS transport dropped due to invalid checksum."); } else { - var pkt = SctpPacket.Parse(recvBuffer, 0, bytesRead); + var pkt = SctpPacketView.Parse(recvBuffer.Slice(0, bytesRead)); - if (pkt.Chunks.Any(x => x.KnownType == SctpChunkType.INIT)) + if (pkt.Has(SctpChunkType.INIT)) { - var initChunk = pkt.Chunks.First(x => x.KnownType == SctpChunkType.INIT) as SctpInitChunk; + var initChunk = pkt.GetChunk(SctpChunkType.INIT); logger.LogDebug($"SCTP INIT packet received, initial tag {initChunk.InitiateTag}, initial TSN {initChunk.InitialTSN}."); GotInit(pkt, null); } - else if (pkt.Chunks.Any(x => x.KnownType == SctpChunkType.COOKIE_ECHO)) + else if (pkt.Has(SctpChunkType.COOKIE_ECHO)) { // The COOKIE ECHO chunk is the 3rd step in the SCTP handshake when the remote party has // requested a new association be created. @@ -309,7 +321,7 @@ private void DoReceive(object state) { RTCSctpAssociation.GotCookie(cookie); - if (pkt.Chunks.Count() > 1) + if (pkt.ChunkCount > 1) { // There could be DATA chunks after the COOKIE ECHO chunk. RTCSctpAssociation.OnPacketReceived(pkt); @@ -322,7 +334,7 @@ private void DoReceive(object state) } } } - else if (_isClosed) + else if (_isClosed.HasOccurred) { // The DTLS transport has been closed or is no longer available. logger.LogWarning($"SCTP the RTCSctpTransport DTLS transport returned an error."); @@ -347,7 +359,7 @@ private void DoReceive(object state) } } - if (!_isClosed) + if (!_isClosed.HasOccurred) { logger.LogWarning($"SCTP association {RTCSctpAssociation.ID} receive thread stopped."); } @@ -360,22 +372,32 @@ private void DoReceive(object state) /// to the remote party. /// /// Not used for the DTLS transport. - /// The buffer containing the data to send. - /// The position in the buffer to send from. - /// The number of bytes to send. - public override void Send(string associationID, byte[] buffer, int offset, int length) + public override void Send(string associationID, ReadOnlySpan data) { - if (length > maxMessageSize) + if (data.Length > maxMessageSize) { - throw new ApplicationException($"RTCSctpTransport was requested to send data of length {length} " + + throw new ApplicationException($"RTCSctpTransport was requested to send data of length {data.Length} " + $" that exceeded the maximum allowed message size of {maxMessageSize}."); } - if (!_isClosed) + if (!_isClosed.HasOccurred) { lock (transport) { - transport.Send(buffer, offset, length); +#if NET6_0_OR_GREATER + transport.Send(data); +#else + byte[] tmp = ArrayPool.Shared.Rent(data.Length); + try + { + data.CopyTo(tmp); + transport.Send(tmp, 0, data.Length); + } + finally + { + ArrayPool.Shared.Return(tmp); + } +#endif } } } diff --git a/src/sys/BorrowedArray.cs b/src/sys/BorrowedArray.cs new file mode 100644 index 000000000..fcbee7106 --- /dev/null +++ b/src/sys/BorrowedArray.cs @@ -0,0 +1,57 @@ +#nullable enable +using System; +using System.Buffers; + +namespace SIPSorcery.Sys +{ + internal struct BorrowedArray: IDisposable + { + byte[]? data; + int length; + ArrayPool? dataOwner; + + public readonly bool IsNull() => data == null; + public readonly Span Data => data.AsSpan(0, length); + public readonly Span DataMayBeEmpty + => data is { } array + ? array.AsSpan(0, Math.Min(length, array.Length)) + : []; + public readonly int Length => length; + + public static implicit operator Span(BorrowedArray borrowed) => borrowed.Data; + + public void Set(ReadOnlySpan data, ArrayPool pool) + { + if (this.data?.Length >= data.Length) + { + data.CopyTo(this.data); + length = data.Length; + return; + } + + Empty(); + dataOwner = pool; + this.data = pool.Rent(data.Length); + data.CopyTo(this.data); + length = data.Length; + } + + public void Set(ReadOnlySpan data) => Set(data, ArrayPool.Shared); + + public void Set(byte[] data) + { + Empty(); + this.data = data; + length = data.Length; + } + + void Empty() + { + dataOwner?.Return(data); + data = null; + dataOwner = null; + } + + public void Dispose() => Empty(); + } +} diff --git a/src/sys/CRC32.cs b/src/sys/CRC32.cs index 52df73bd5..861a13822 100644 --- a/src/sys/CRC32.cs +++ b/src/sys/CRC32.cs @@ -36,7 +36,7 @@ public override void Initialize() protected override void HashCore(byte[] buffer, int start, int length) { - hash = CalculateHash(table, hash, buffer, start, length); + hash = CalculateHash(table, hash, buffer.AsSpan().Slice(start, length)); } protected override byte[] HashFinal() @@ -51,19 +51,19 @@ public override int HashSize get { return 32; } } - public static UInt32 Compute(byte[] buffer) + public static UInt32 Compute(ReadOnlySpan buffer) { - return ~CalculateHash(InitializeTable(DefaultPolynomial), DefaultSeed, buffer, 0, buffer.Length); + return ~CalculateHash(InitializeTable(DefaultPolynomial), DefaultSeed, buffer); } public static UInt32 Compute(UInt32 seed, byte[] buffer) { - return ~CalculateHash(InitializeTable(DefaultPolynomial), seed, buffer, 0, buffer.Length); + return ~CalculateHash(InitializeTable(DefaultPolynomial), seed, buffer); } public static UInt32 Compute(UInt32 polynomial, UInt32 seed, byte[] buffer) { - return ~CalculateHash(InitializeTable(polynomial), seed, buffer, 0, buffer.Length); + return ~CalculateHash(InitializeTable(polynomial), seed, buffer); } private static UInt32[] InitializeTable(UInt32 polynomial) @@ -99,10 +99,10 @@ private static UInt32[] InitializeTable(UInt32 polynomial) return createTable; } - private static UInt32 CalculateHash(UInt32[] table, UInt32 seed, byte[] buffer, int start, int size) + private static UInt32 CalculateHash(UInt32[] table, UInt32 seed, ReadOnlySpan buffer) { UInt32 crc = seed; - for (int i = start; i < size; i++) + for (int i = 0; i < buffer.Length; i++) { unchecked { diff --git a/src/sys/CollectionExtensions.cs b/src/sys/CollectionExtensions.cs new file mode 100644 index 000000000..d049176c8 --- /dev/null +++ b/src/sys/CollectionExtensions.cs @@ -0,0 +1,23 @@ +#nullable enable + +using System.Collections.Generic; + +using Small.Collections; + +using TypeNum; + +namespace SIPSorcery.Sys; + +static class CollectionExtensions +{ + public static void AddRange(this SmallList list, Enumerator enumerator) + where TSize : unmanaged, INumeral + where T : unmanaged + where Enumerator : IEnumerator + { + while (enumerator.MoveNext()) + { + list.Add(enumerator.Current); + } + } +} diff --git a/src/sys/Crypto/Crypto.cs b/src/sys/Crypto/Crypto.cs index bd4a64a3e..24fc4fd78 100644 --- a/src/sys/Crypto/Crypto.cs +++ b/src/sys/Crypto/Crypto.cs @@ -371,6 +371,13 @@ public static string GetSHA256Hash(byte[] buffer) } } + /// + /// Gets the HSA256 hash of an arbitrary buffer. + /// + /// A hex string representing the hashed buffer. + public static string GetSHA256Hash(ReadOnlySpan buffer) + => GetSHA256Hash(buffer.ToArray()); + /// /// Attempts to load an X509 certificate from a Windows OS certificate store. /// diff --git a/src/sys/EnumExtensions.cs b/src/sys/EnumExtensions.cs new file mode 100644 index 000000000..745f7d7ae --- /dev/null +++ b/src/sys/EnumExtensions.cs @@ -0,0 +1,17 @@ +using System; + +namespace SIPSorcery.Sys +{ + static class EnumExtensions + { + public static bool IsDefined(this T value) where T : struct, Enum + { + return Array.IndexOf(EnumDefined.Values, value) >= 0; + } + + class EnumDefined where T : struct, Enum + { + public static readonly T[] Values = (T[])Enum.GetValues(typeof(T)); + } + } +} diff --git a/src/sys/InterlockedEx.cs b/src/sys/InterlockedEx.cs new file mode 100644 index 000000000..decfa5af1 --- /dev/null +++ b/src/sys/InterlockedEx.cs @@ -0,0 +1,20 @@ +using System.Threading; + +namespace SIPSorcery.Sys; + +static class InterlockedEx +{ + public static int Read(ref int location) => Interlocked.CompareExchange(ref location, 0, 0); + public unsafe static uint CompareExchange(ref uint location, uint value, uint comparand) +#if NET6_0_OR_GREATER + => Interlocked.CompareExchange(ref location, value: value, comparand: comparand); +#else + { + fixed (uint* ptr = &location) + { + return unchecked((uint)Interlocked.CompareExchange(ref *(int*)ptr, (int)value, (int)comparand)); + } + } +#endif + public unsafe static uint Read(ref uint location) => CompareExchange(ref location, 0, 0); +} diff --git a/src/sys/MemoryExtensions.cs b/src/sys/MemoryExtensions.cs new file mode 100644 index 000000000..c246b93fd --- /dev/null +++ b/src/sys/MemoryExtensions.cs @@ -0,0 +1,15 @@ +using System; + +namespace SIPSorcery.Sys +{ + static class MemoryExtensions + { + public static unsafe string ToString(this ReadOnlySpan buffer, System.Text.Encoding encoding) + { + fixed (byte* ptr = buffer) + { + return encoding.GetString(ptr, buffer.Length); + } + } + } +} diff --git a/src/sys/Net/NetConvert.cs b/src/sys/Net/NetConvert.cs index 8c223f9d4..9907b41c6 100644 --- a/src/sys/Net/NetConvert.cs +++ b/src/sys/Net/NetConvert.cs @@ -14,6 +14,7 @@ //----------------------------------------------------------------------------- using System; +using System.Buffers.Binary; using System.Linq; namespace SIPSorcery.Sys @@ -54,6 +55,26 @@ public static ushort ParseUInt16(byte[] buffer, int posn) return (ushort)(buffer[posn] << 8 | buffer[posn + 1]); } + /// + /// Parse a UInt16 from a network buffer using network byte order. + /// + /// The buffer to parse the value from. + /// The position in the buffer to start the parse from. + /// A UInt16 value. + public static ushort ParseUInt16(ReadOnlySpan buffer, int posn) + { + return (ushort)(buffer[posn] << 8 | buffer[posn + 1]); + } + + /// + /// Parse a UInt16 from a network buffer using network byte order. + /// + /// The buffer to parse the value from. + public static ushort ParseUInt16(ReadOnlySpan buffer) + { + return (ushort)(buffer[0] << 8 | buffer[1]); + } + /// /// Parse a UInt32 from a network buffer using network byte order. /// @@ -65,6 +86,28 @@ public static uint ParseUInt32(byte[] buffer, int posn) return (uint)(buffer[posn] << 24 | buffer[posn + 1] << 16 | buffer[posn + 2] << 8 | buffer[posn + 3]); } + /// + /// Parse a UInt32 from a network buffer using network byte order. + /// + /// The buffer to parse the value from. + /// The position in the buffer to start the parse from. + /// A UInt32 value. + public static uint ParseUInt32(ReadOnlySpan buffer, int posn) + { + return (uint)(buffer[posn] << 24 | buffer[posn + 1] << 16 | buffer[posn + 2] << 8 | buffer[posn + 3]); + } + + /// + /// Parse a UInt32 from a network buffer using network byte order. + /// + /// The buffer to parse the value from. + /// The position in the buffer to start the parse from. + /// A UInt32 value. + public static uint ParseUInt32(ReadOnlySpan buffer) + { + return (uint)(buffer[0] << 24 | buffer[1] << 16 | buffer[2] << 8 | buffer[3]); + } + /// /// Parse a UInt64 from a network buffer using network byte order. /// @@ -101,6 +144,17 @@ public static void ToBuffer(ushort val, byte[] buffer, int posn) buffer[posn + 1] = (byte)val; } + /// + /// Writes a UInt16 value to a network buffer using network byte order. + /// + /// The value to write to the buffer. + /// The buffer to write the value to. + /// The start position in the buffer to write the value at. + public static void ToBuffer(ushort val, Span buffer, int posn) + { + BinaryPrimitives.WriteUInt16BigEndian(buffer.Slice(posn), val); + } + /// /// Get a buffer representing the unsigned 16 bit integer in network /// byte (big endian) order. @@ -133,6 +187,17 @@ public static void ToBuffer(uint val, byte[] buffer, int posn) buffer[posn + 3] = (byte)val; } + /// + /// Writes a UInt16 value to a network buffer using network byte order. + /// + /// The value to write to the buffer. + /// The buffer to write the value to. + /// The start position in the buffer to write the value at. + public static void ToBuffer(uint val, Span buffer, int posn) + { + BinaryPrimitives.WriteUInt32BigEndian(buffer.Slice(posn), val); + } + /// /// Get a buffer representing the 32 bit unsigned integer in network /// byte (big endian) order. diff --git a/src/sys/OnOff.cs b/src/sys/OnOff.cs new file mode 100644 index 000000000..4dbb97cdf --- /dev/null +++ b/src/sys/OnOff.cs @@ -0,0 +1,13 @@ +using System.Threading; + +namespace SIPSorcery.Sys; + +/// A thread-safe struct that represents an on/off state. +struct OnOff +{ + int on; + + public bool TryTurnOn() => Interlocked.CompareExchange(ref on, 1, comparand: 0) == 0; + public bool TryTurnOff() => Interlocked.CompareExchange(ref on, 0, comparand: 1) == 1; + public bool IsOn() => Interlocked.CompareExchange(ref on, 0, 0) == 1; +} diff --git a/src/sys/Once.cs b/src/sys/Once.cs new file mode 100644 index 000000000..b10ef7320 --- /dev/null +++ b/src/sys/Once.cs @@ -0,0 +1,36 @@ +using System; +using System.Threading; + +namespace SIPSorcery.Sys +{ + /// + /// A thread-safe struct that represents a one-time event. + /// + struct Once + { + int occured; + + /// + /// Gets a value indicating whether the event has occurred or not. + /// + public bool HasOccurred => Interlocked.CompareExchange(ref this.occured, 0, 0) != 0; + + /// + /// Tries to mark the event as occurred and returns true if successful. + /// Returns false if the even has already been marked. + /// + public bool TryMarkOccurred() => Interlocked.CompareExchange(ref this.occured, 1, comparand: 0) == 0; + + /// + /// Marks the event as occurred and throws if it has already occurred. + /// + /// If the event has already occurred. + public void MarkOccurred() + { + if (!this.TryMarkOccurred()) + { + throw new InvalidOperationException("Can only be called once"); + } + } + } +} diff --git a/src/sys/TypeExtensions.cs b/src/sys/TypeExtensions.cs index b78c422e7..d4ec8207b 100644 --- a/src/sys/TypeExtensions.cs +++ b/src/sys/TypeExtensions.cs @@ -115,12 +115,25 @@ public static string Slice(this string s, char startDelimiter, char endDelimeter } } + public static string HexStr(this ReadOnlySpan buffer, char? separator = null) + { + return HexStr(buffer, buffer.Length, separator); + } + public static string HexStr(this Span buffer, char? separator = null) + { + return HexStr(buffer, buffer.Length, separator); + } + public static string HexStr(this byte[] buffer, char? separator = null) { - return buffer.HexStr(buffer.Length, separator); + return HexStr(buffer, buffer.Length, separator); } public static string HexStr(this byte[] buffer, int length, char? separator = null) + { + return HexStr(buffer.AsSpan(), length, separator); + } + public static string HexStr(this ReadOnlySpan buffer, int length, char? separator = null) { string rv = string.Empty; diff --git a/test/integration/SIPSorcery.IntegrationTests.csproj b/test/integration/SIPSorcery.IntegrationTests.csproj index f07661375..19bdb5210 100755 --- a/test/integration/SIPSorcery.IntegrationTests.csproj +++ b/test/integration/SIPSorcery.IntegrationTests.csproj @@ -19,9 +19,10 @@ - - - + + + + all runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/test/integration/net/DtlsSrtp/DtlsSrtpTransportUnitTest.cs b/test/integration/net/DtlsSrtp/DtlsSrtpTransportUnitTest.cs index c575162a7..176b0df82 100644 --- a/test/integration/net/DtlsSrtp/DtlsSrtpTransportUnitTest.cs +++ b/test/integration/net/DtlsSrtp/DtlsSrtpTransportUnitTest.cs @@ -14,6 +14,7 @@ using System.Threading.Tasks; using Microsoft.Extensions.Logging; +using Org.BouncyCastle.Tls.Crypto.Impl.BC; using Xunit; namespace SIPSorcery.Net.IntegrationTests @@ -37,8 +38,9 @@ public void CreateClientInstanceUnitTest() logger.LogDebug("--> " + System.Reflection.MethodBase.GetCurrentMethod().Name); logger.BeginScope(System.Reflection.MethodBase.GetCurrentMethod().Name); - (var tlsCert, var pvtKey) = DtlsUtils.CreateSelfSignedTlsCert(); - DtlsSrtpTransport dtlsTransport = new DtlsSrtpTransport(new DtlsSrtpClient(tlsCert, pvtKey)); + var crypto = new BcTlsCrypto(); + (var tlsCert, var pvtKey) = DtlsUtils.CreateSelfSignedTlsCert(crypto); + DtlsSrtpTransport dtlsTransport = new DtlsSrtpTransport(new DtlsSrtpClient(crypto, tlsCert, pvtKey)); Assert.NotNull(dtlsTransport); } @@ -52,7 +54,7 @@ public void CreateServerInstanceUnitTest() logger.LogDebug("--> " + System.Reflection.MethodBase.GetCurrentMethod().Name); logger.BeginScope(System.Reflection.MethodBase.GetCurrentMethod().Name); - DtlsSrtpTransport dtlsTransport = new DtlsSrtpTransport(new DtlsSrtpServer()); + DtlsSrtpTransport dtlsTransport = new DtlsSrtpTransport(new DtlsSrtpServer(new BcTlsCrypto())); Assert.NotNull(dtlsTransport); } @@ -67,8 +69,8 @@ public void DoHandshakeUnitTest() logger.LogDebug("--> " + System.Reflection.MethodBase.GetCurrentMethod().Name); logger.BeginScope(System.Reflection.MethodBase.GetCurrentMethod().Name); - var dtlsClient = new DtlsSrtpClient(); - var dtlsServer = new DtlsSrtpServer(); + var dtlsClient = new DtlsSrtpClient(new BcTlsCrypto()); + var dtlsServer = new DtlsSrtpServer(new BcTlsCrypto()); DtlsSrtpTransport dtlsClientTransport = new DtlsSrtpTransport(dtlsClient); dtlsClientTransport.TimeoutMilliseconds = 5000; @@ -117,7 +119,7 @@ public async void DoHandshakeClientTimeoutUnitTest() logger.LogDebug("--> " + System.Reflection.MethodBase.GetCurrentMethod().Name); logger.BeginScope(System.Reflection.MethodBase.GetCurrentMethod().Name); - DtlsSrtpTransport dtlsClientTransport = new DtlsSrtpTransport(new DtlsSrtpClient()); + DtlsSrtpTransport dtlsClientTransport = new DtlsSrtpTransport(new DtlsSrtpClient(new BcTlsCrypto())); dtlsClientTransport.TimeoutMilliseconds = 2000; var result = await Task.Run(() => dtlsClientTransport.DoHandshake(out _)); @@ -134,7 +136,7 @@ public async void DoHandshakeServerTimeoutUnitTest() logger.LogDebug("--> " + System.Reflection.MethodBase.GetCurrentMethod().Name); logger.BeginScope(System.Reflection.MethodBase.GetCurrentMethod().Name); - DtlsSrtpTransport dtlsServerTransport = new DtlsSrtpTransport(new DtlsSrtpServer()); + DtlsSrtpTransport dtlsServerTransport = new DtlsSrtpTransport(new DtlsSrtpServer(new BcTlsCrypto())); dtlsServerTransport.TimeoutMilliseconds = 2000; var result = await Task.Run(() => dtlsServerTransport.DoHandshake(out _)); diff --git a/test/integration/net/DtlsSrtp/DtlsUtilsUnitTest.cs b/test/integration/net/DtlsSrtp/DtlsUtilsUnitTest.cs index b6ff14e3d..92e2090d6 100644 --- a/test/integration/net/DtlsSrtp/DtlsUtilsUnitTest.cs +++ b/test/integration/net/DtlsSrtp/DtlsUtilsUnitTest.cs @@ -14,6 +14,7 @@ using System.Runtime.InteropServices; using System.Security.Cryptography.X509Certificates; using Microsoft.Extensions.Logging; +using Org.BouncyCastle.Tls.Crypto.Impl.BC; using Xunit; namespace SIPSorcery.Net.IntegrationTests @@ -25,10 +26,12 @@ namespace SIPSorcery.Net.IntegrationTests public class DtlsUtilsUnitTest { private Microsoft.Extensions.Logging.ILogger logger = null; + private BcTlsCrypto crypto = null; public DtlsUtilsUnitTest(Xunit.Abstractions.ITestOutputHelper output) { logger = SIPSorcery.UnitTests.TestLogHelper.InitTestLogger(output); + crypto = new BcTlsCrypto(); } /// @@ -40,7 +43,7 @@ public void CreateSelfSignedCertifcateUnitTest() logger.LogDebug("--> " + System.Reflection.MethodBase.GetCurrentMethod().Name); logger.BeginScope(System.Reflection.MethodBase.GetCurrentMethod().Name); - (var tlsCert, var pvtKey) = DtlsUtils.CreateSelfSignedTlsCert(); + (var tlsCert, var pvtKey) = DtlsUtils.CreateSelfSignedTlsCert(crypto); logger.LogDebug(tlsCert.ToString()); @@ -57,7 +60,7 @@ public void GetCertifcateFingerprintUnitTest() logger.LogDebug("--> " + System.Reflection.MethodBase.GetCurrentMethod().Name); logger.BeginScope(System.Reflection.MethodBase.GetCurrentMethod().Name); - (var tlsCert, var pvtKey) = DtlsUtils.CreateSelfSignedTlsCert(); + (var tlsCert, var pvtKey) = DtlsUtils.CreateSelfSignedTlsCert(crypto); Assert.NotNull(tlsCert); var fingerprint = DtlsUtils.Fingerprint(tlsCert); @@ -115,7 +118,7 @@ public void BouncyCertFromCoreFxCert() Assert.NotNull(coreFxCert); Assert.NotNull(coreFxCert.PrivateKey); - string coreFxFingerprint = DtlsUtils.Fingerprint(coreFxCert).ToString(); + string coreFxFingerprint = DtlsUtils.Fingerprint(crypto, coreFxCert).ToString(); logger.LogDebug($"Core FX certificate fingerprint {coreFxFingerprint}."); var bcCert = Org.BouncyCastle.Security.DotNetUtilities.FromX509Certificate(coreFxCert); diff --git a/test/integration/net/ICE/MockTurnServer.cs b/test/integration/net/ICE/MockTurnServer.cs index f56486aa9..e0c796eec 100644 --- a/test/integration/net/ICE/MockTurnServer.cs +++ b/test/integration/net/ICE/MockTurnServer.cs @@ -60,7 +60,7 @@ public MockTurnServer(IPAddress listenAddress, int port) _listener.BeginReceiveFrom(); } - private void OnPacketReceived(UdpReceiver receiver, int localPort, IPEndPoint remoteEndPoint, byte[] packet) + private void OnPacketReceived(UdpReceiver receiver, int localPort, IPEndPoint remoteEndPoint, ReadOnlySpan packet) { STUNMessage stunMessage = STUNMessage.ParseSTUNMessage(packet, packet.Length); @@ -140,10 +140,10 @@ private void OnPacketReceived(UdpReceiver receiver, int localPort, IPEndPoint re /// The port number the packet was received on. /// The end point of the peer sending traffic to the TURN server. /// The byes received from the peer. - private void OnRelayPacketReceived(UdpReceiver receiver, int localPort, IPEndPoint remoteEndPoint, byte[] packet) + private void OnRelayPacketReceived(UdpReceiver receiver, int localPort, IPEndPoint remoteEndPoint, ReadOnlySpan packet) { STUNMessage dataInd = new STUNMessage(STUNMessageTypesEnum.DataIndication); - dataInd.Attributes.Add(new STUNAttribute(STUNAttributeTypesEnum.Data, packet)); + dataInd.Attributes.Add(new STUNAttribute(STUNAttributeTypesEnum.Data, packet.ToArray())); dataInd.AddXORPeerAddressAttribute(remoteEndPoint.Address, remoteEndPoint.Port); _clientSocket.SendTo(dataInd.ToByteBuffer(null, false), _clientEndPoint); diff --git a/test/unit/Initialise.cs b/test/unit/Initialise.cs index f303527a4..a904c9796 100644 --- a/test/unit/Initialise.cs +++ b/test/unit/Initialise.cs @@ -40,6 +40,7 @@ public static Microsoft.Extensions.Logging.ILogger InitTestLogger(Xunit.Abstract .MinimumLevel.Is(Serilog.Events.LogEventLevel.Verbose) .Enrich.WithProperty("ThreadId", System.Threading.Thread.CurrentThread.ManagedThreadId) .WriteTo.TestOutput(output, outputTemplate: template) + .WriteTo.Debug(outputTemplate: template) .WriteTo.Console(outputTemplate: template) .CreateLogger(); SIPSorcery.LogFactory.Set(new SerilogLoggerFactory(serilog)); diff --git a/test/unit/SIPSorcery.UnitTests.csproj b/test/unit/SIPSorcery.UnitTests.csproj index 4b40cfc2e..fa810bc73 100755 --- a/test/unit/SIPSorcery.UnitTests.csproj +++ b/test/unit/SIPSorcery.UnitTests.csproj @@ -19,9 +19,10 @@ - - - + + + + all runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/test/unit/net/SCTP/SctpAssociationUnitTest.cs b/test/unit/net/SCTP/SctpAssociationUnitTest.cs index 689c670ed..6bee215fd 100644 --- a/test/unit/net/SCTP/SctpAssociationUnitTest.cs +++ b/test/unit/net/SCTP/SctpAssociationUnitTest.cs @@ -98,7 +98,7 @@ public void SendDataChunk() string message = "hello world"; var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - bAssoc.OnData += (frame) => tcs.TrySetResult(Encoding.UTF8.GetString(frame.UserData)); + bAssoc.OnData += (frame) => tcs.TrySetResult(Encoding.UTF8.GetString(frame.UserData.ToArray())); aAssoc.SendData(0, 0, Encoding.UTF8.GetBytes(message)); tcs.Task.Wait(3000); @@ -123,7 +123,7 @@ public void SendFragmentedDataChunk() string message = "hello world"; var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - bAssoc.OnData += (frame) => tcs.TrySetResult(Encoding.UTF8.GetString(frame.UserData)); + bAssoc.OnData += (frame) => tcs.TrySetResult(Encoding.UTF8.GetString(frame.UserData.ToArray())); aAssoc.SendData(0, 0, Encoding.UTF8.GetBytes(message)); tcs.Task.Wait(3000); @@ -228,7 +228,9 @@ internal class MockB2BSctpTransport : SctpTransport private bool _exit; - public event Action OnSctpPacket; + public delegate void SctpPacketViewAction(SctpPacketView pkt); + + public event SctpPacketViewAction OnSctpPacket; public event Action OnCookieEcho; public MockB2BSctpTransport(BlockingCollection output, BlockingCollection input) @@ -243,19 +245,19 @@ public void Listen() { if (_input.TryTake(out var buffer, 1000)) { - SctpPacket pkt = SctpPacket.Parse(buffer, 0, buffer.Length); + SctpPacket pkt = SctpPacket.Parse(buffer); // Process packet. if (pkt.Chunks.Any(x => x.KnownType == SctpChunkType.INIT)) { var initAckPacket = base.GetInitAck(pkt, null); var initAckBuffer = initAckPacket.GetBytes(); - Send(null, initAckBuffer, 0, initAckBuffer.Length); + Send(null, initAckBuffer); } else if (pkt.Chunks.Any(x => x.KnownType == SctpChunkType.COOKIE_ECHO)) { var cookieEcho = pkt.Chunks.Single(x => x.KnownType == SctpChunkType.COOKIE_ECHO); - var cookie = base.GetCookie(pkt); + var cookie = base.GetCookie(SctpPacketView.Parse(buffer)); if (cookie.IsEmpty()) { throw new ApplicationException($"MockB2BSctpTransport gave itself an invalid INIT cookie."); @@ -267,15 +269,15 @@ public void Listen() } else { - OnSctpPacket?.Invoke(pkt); + OnSctpPacket?.Invoke(SctpPacketView.Parse(buffer)); } } } } - public override void Send(string associationID, byte[] buffer, int offset, int length) + public override void Send(string associationID, ReadOnlySpan buffer) { - _output.Add(buffer.Skip(offset).Take(length).ToArray()); + _output.Add(buffer.ToArray()); } public void Close() diff --git a/test/unit/net/SCTP/SctpChunkUnitTest.cs b/test/unit/net/SCTP/SctpChunkUnitTest.cs index d9188fe1d..ed78c344e 100644 --- a/test/unit/net/SCTP/SctpChunkUnitTest.cs +++ b/test/unit/net/SCTP/SctpChunkUnitTest.cs @@ -63,7 +63,7 @@ public void ParseSACKChunk() { var sackBuffer = BufferUtils.ParseHexStr("13881388E48092946AB2050003000014D19244F60002000000000001A7498379"); - var sackPkt = SctpPacket.Parse(sackBuffer, 0, sackBuffer.Length); + var sackPkt = SctpPacket.Parse(sackBuffer); Assert.NotNull(sackPkt); Assert.Single(sackPkt.Chunks); diff --git a/test/unit/net/SCTP/SctpDataReceiverUnitTest.cs b/test/unit/net/SCTP/SctpDataReceiverUnitTest.cs index deb35614b..916cc9a4e 100644 --- a/test/unit/net/SCTP/SctpDataReceiverUnitTest.cs +++ b/test/unit/net/SCTP/SctpDataReceiverUnitTest.cs @@ -39,7 +39,7 @@ public void SinglePacketFrame() SctpDataChunk chunk = new SctpDataChunk(false, true, true, 0, 0, 0, 0, new byte[] { 0x00 }); - var sortedFrames = receiver.OnDataChunk(chunk); + var sortedFrames = receiver.OnDataChunk(chunk.View()); Assert.Single(sortedFrames); Assert.Equal("00", sortedFrames.Single().UserData.HexStr()); @@ -59,11 +59,11 @@ public void ThreeFragments() SctpDataChunk chunk2 = new SctpDataChunk(false, false, false, 1, 0, 0, 0, new byte[] { 0x01 }); SctpDataChunk chunk3 = new SctpDataChunk(false, false, true, 2, 0, 0, 0, new byte[] { 0x02 }); - var sortFrames1 = receiver.OnDataChunk(chunk1); + var sortFrames1 = receiver.OnDataChunk(chunk1.View()); Assert.Equal(0U, receiver.CumulativeAckTSN); - var sortFrames2 = receiver.OnDataChunk(chunk2); + var sortFrames2 = receiver.OnDataChunk(chunk2.View()); Assert.Equal(1U, receiver.CumulativeAckTSN); - var sortFrames3 = receiver.OnDataChunk(chunk3); + var sortFrames3 = receiver.OnDataChunk(chunk3.View()); Assert.Equal(2U, receiver.CumulativeAckTSN); Assert.Empty(sortFrames1); @@ -86,11 +86,11 @@ public void ThreeFragmentsOutOfOrder() SctpDataChunk chunk2 = new SctpDataChunk(false, false, false, 1, 0, 0, 0, new byte[] { 0x01 }); SctpDataChunk chunk3 = new SctpDataChunk(false, false, true, 2, 0, 0, 0, new byte[] { 0x02 }); - var sortFrames1 = receiver.OnDataChunk(chunk1); + var sortFrames1 = receiver.OnDataChunk(chunk1.View()); Assert.Equal(0U, receiver.CumulativeAckTSN); - var sortFrames2 = receiver.OnDataChunk(chunk3); + var sortFrames2 = receiver.OnDataChunk(chunk3.View()); Assert.Equal(0U, receiver.CumulativeAckTSN); - var sortFrames3 = receiver.OnDataChunk(chunk2); + var sortFrames3 = receiver.OnDataChunk(chunk2.View()); Assert.Equal(2U, receiver.CumulativeAckTSN); Assert.Empty(sortFrames1); @@ -113,11 +113,11 @@ public void ThreeFragmentsBeginLast() SctpDataChunk chunk2 = new SctpDataChunk(false, false, false, 1, 0, 0, 0, new byte[] { 0x01 }); SctpDataChunk chunk3 = new SctpDataChunk(false, false, true, 2, 0, 0, 0, new byte[] { 0x02 }); - var sortFrames1 = receiver.OnDataChunk(chunk3); + var sortFrames1 = receiver.OnDataChunk(chunk3.View()); Assert.Null(receiver.CumulativeAckTSN); - var sortFrames2 = receiver.OnDataChunk(chunk2); + var sortFrames2 = receiver.OnDataChunk(chunk2.View()); Assert.Null(receiver.CumulativeAckTSN); - var sortFrames3 = receiver.OnDataChunk(chunk1); + var sortFrames3 = receiver.OnDataChunk(chunk1.View()); Assert.Equal(2U, receiver.CumulativeAckTSN); Assert.Empty(sortFrames1); @@ -141,15 +141,15 @@ public void FragmentWithTSNWrap() SctpDataChunk chunk4 = new SctpDataChunk(false, false, false, 0, 0, 0, 0, new byte[] { 0x03 }); SctpDataChunk chunk5 = new SctpDataChunk(false, false, true, 1, 0, 0, 0, new byte[] { 0x04 }); - var sFrames1 = receiver.OnDataChunk(chunk1); + var sFrames1 = receiver.OnDataChunk(chunk1.View()); Assert.Equal(uint.MaxValue - 2, receiver.CumulativeAckTSN); - var sFrames2 = receiver.OnDataChunk(chunk2); + var sFrames2 = receiver.OnDataChunk(chunk2.View()); Assert.Equal(uint.MaxValue - 1, receiver.CumulativeAckTSN); - var sFrames3 = receiver.OnDataChunk(chunk3); + var sFrames3 = receiver.OnDataChunk(chunk3.View()); Assert.Equal(uint.MaxValue, receiver.CumulativeAckTSN); - var sFrames4 = receiver.OnDataChunk(chunk4); + var sFrames4 = receiver.OnDataChunk(chunk4.View()); Assert.Equal(0U, receiver.CumulativeAckTSN); - var sFrames5 = receiver.OnDataChunk(chunk5); + var sFrames5 = receiver.OnDataChunk(chunk5.View()); Assert.Equal(1U, receiver.CumulativeAckTSN); Assert.Empty(sFrames1); @@ -180,19 +180,19 @@ public void FragmentWithTSNWrapAndOutOfOrder() SctpDataChunk chunk6 = new SctpDataChunk(true, true, true, 6, 0, 0, 0, new byte[] { 0x06 }); SctpDataChunk chunk9 = new SctpDataChunk(true, true, true, 9, 0, 0, 0, new byte[] { 0x09 }); - var sframes9 = receiver.OnDataChunk(chunk9); + var sframes9 = receiver.OnDataChunk(chunk9.View()); Assert.Null(receiver.CumulativeAckTSN); - var sframes1 = receiver.OnDataChunk(chunk1); + var sframes1 = receiver.OnDataChunk(chunk1.View()); Assert.Equal(uint.MaxValue - 2, receiver.CumulativeAckTSN); - var sframes2 = receiver.OnDataChunk(chunk2); + var sframes2 = receiver.OnDataChunk(chunk2.View()); Assert.Equal(uint.MaxValue - 1, receiver.CumulativeAckTSN); - var sframes3 = receiver.OnDataChunk(chunk3); + var sframes3 = receiver.OnDataChunk(chunk3.View()); Assert.Equal(uint.MaxValue, receiver.CumulativeAckTSN); - var sframes6 = receiver.OnDataChunk(chunk6); + var sframes6 = receiver.OnDataChunk(chunk6.View()); Assert.Equal(uint.MaxValue, receiver.CumulativeAckTSN); - var sframes4 = receiver.OnDataChunk(chunk4); + var sframes4 = receiver.OnDataChunk(chunk4.View()); Assert.Equal(0U, receiver.CumulativeAckTSN); - var sframes5 = receiver.OnDataChunk(chunk5); + var sframes5 = receiver.OnDataChunk(chunk5.View()); Assert.Equal(1U, receiver.CumulativeAckTSN); Assert.Empty(sframes1); @@ -221,11 +221,11 @@ public void FragmentWithExpectedTSNWrap() SctpDataChunk chunk4 = new SctpDataChunk(false, false, false, 0, 0, 0, 0, new byte[] { 0x03 }); SctpDataChunk chunk5 = new SctpDataChunk(false, false, true, 1, 0, 0, 0, new byte[] { 0x04 }); - var sframes1 = receiver.OnDataChunk(chunk1); - var sframes2 = receiver.OnDataChunk(chunk2); - var sframes3 = receiver.OnDataChunk(chunk3); - var sframes4 = receiver.OnDataChunk(chunk4); - var sframes5 = receiver.OnDataChunk(chunk5); + var sframes1 = receiver.OnDataChunk(chunk1.View()); + var sframes2 = receiver.OnDataChunk(chunk2.View()); + var sframes3 = receiver.OnDataChunk(chunk3.View()); + var sframes4 = receiver.OnDataChunk(chunk4.View()); + var sframes5 = receiver.OnDataChunk(chunk5.View()); Assert.Empty(sframes1); Assert.Empty(sframes2); @@ -290,7 +290,7 @@ public void CheckExpiryWithSinglePacketChunksUnordered() { SctpDataChunk chunk = new SctpDataChunk(true, true, true, tsn++, 0, 0, 0, new byte[] { 0x55 }); - var sortedFrames = receiver.OnDataChunk(chunk); + var sortedFrames = receiver.OnDataChunk(chunk.View()); Assert.Single(sortedFrames); Assert.Equal("55", sortedFrames.Single().UserData.HexStr()); @@ -316,7 +316,7 @@ public void CheckExpiryWithSinglePacketChunksOrdered() { SctpDataChunk chunk = new SctpDataChunk(false, true, true, tsn++, 0, streamSeqnum++, 0, new byte[] { 0x55 }); - var sortedFrames = receiver.OnDataChunk(chunk); + var sortedFrames = receiver.OnDataChunk(chunk.View()); Assert.Single(sortedFrames); Assert.Equal("55", sortedFrames.Single().UserData.HexStr()); @@ -337,9 +337,9 @@ public void ThreeStreamPackets() SctpDataChunk chunk2 = new SctpDataChunk(false, true, true, 1, 0, 1, 0, new byte[] { 0x01 }); SctpDataChunk chunk3 = new SctpDataChunk(false, true, true, 2, 0, 2, 0, new byte[] { 0x02 }); - var sortFrames1 = receiver.OnDataChunk(chunk1); - var sortFrames2 = receiver.OnDataChunk(chunk2); - var sortFrames3 = receiver.OnDataChunk(chunk3); + var sortFrames1 = receiver.OnDataChunk(chunk1.View()); + var sortFrames2 = receiver.OnDataChunk(chunk2.View()); + var sortFrames3 = receiver.OnDataChunk(chunk3.View()); Assert.Single(sortFrames1); Assert.Equal(0, sortFrames1.Single().StreamSeqNum); @@ -367,10 +367,10 @@ public void StreamPacketsReceviedOutOfOrder() SctpDataChunk chunk2 = new SctpDataChunk(false, true, true, 1, 0, 1, 0, new byte[] { 0x01 }); SctpDataChunk chunk3 = new SctpDataChunk(false, true, true, 2, 0, 2, 0, new byte[] { 0x02 }); - var sortFrames0 = receiver.OnDataChunk(chunk0); - var sortFrames1 = receiver.OnDataChunk(chunk3); - var sortFrames2 = receiver.OnDataChunk(chunk2); - var sortFrames3 = receiver.OnDataChunk(chunk1); + var sortFrames0 = receiver.OnDataChunk(chunk0.View()); + var sortFrames1 = receiver.OnDataChunk(chunk3.View()); + var sortFrames2 = receiver.OnDataChunk(chunk2.View()); + var sortFrames3 = receiver.OnDataChunk(chunk1.View()); Assert.Single(sortFrames0); Assert.Empty(sortFrames1); @@ -390,8 +390,8 @@ public void StreamPacketsReceviedOutOfOrder() public void GetSingleGapReport() { SctpDataReceiver receiver = new SctpDataReceiver(0, 0, 25); - receiver.OnDataChunk(new SctpDataChunk(true, true, true, 25, 0, 0, 0, new byte[] { 0x33 })); - receiver.OnDataChunk(new SctpDataChunk(true, true, true, 30, 0, 0, 0, new byte[] { 0x33 })); + receiver.OnDataChunk(new SctpDataChunk(true, true, true, 25, 0, 0, 0, new byte[] { 0x33 }).View()); + receiver.OnDataChunk(new SctpDataChunk(true, true, true, 30, 0, 0, 0, new byte[] { 0x33 }).View()); var gapReports = receiver.GetForwardTSNGaps(); @@ -412,8 +412,8 @@ public void GetSingleGapReportWithWrap() { uint initialTSN = uint.MaxValue - 2; SctpDataReceiver receiver = new SctpDataReceiver(0, 0, initialTSN); - receiver.OnDataChunk(new SctpDataChunk(true, true, true, initialTSN, 0, 0, 0, new byte[] { 0x33 })); - receiver.OnDataChunk(new SctpDataChunk(true, true, true, 2, 0, 0, 0, new byte[] { 0x33 })); + receiver.OnDataChunk(new SctpDataChunk(true, true, true, initialTSN, 0, 0, 0, new byte[] { 0x33 }).View()); + receiver.OnDataChunk(new SctpDataChunk(true, true, true, 2, 0, 0, 0, new byte[] { 0x33 }).View()); var gapReports = receiver.GetForwardTSNGaps(); @@ -432,10 +432,10 @@ public void GetSingleGapReportWithWrap() public void GetTwoGapReports() { SctpDataReceiver receiver = new SctpDataReceiver(0, 0, 15005); - receiver.OnDataChunk(new SctpDataChunk(true, true, true, 15005, 0, 0, 0, new byte[] { 0x33 })); - receiver.OnDataChunk(new SctpDataChunk(true, true, true, 15007, 0, 0, 0, new byte[] { 0x33 })); - receiver.OnDataChunk(new SctpDataChunk(true, true, true, 15008, 0, 0, 0, new byte[] { 0x33 })); - receiver.OnDataChunk(new SctpDataChunk(true, true, true, 15010, 0, 0, 0, new byte[] { 0x33 })); + receiver.OnDataChunk(new SctpDataChunk(true, true, true, 15005, 0, 0, 0, new byte[] { 0x33 }).View()); + receiver.OnDataChunk(new SctpDataChunk(true, true, true, 15007, 0, 0, 0, new byte[] { 0x33 }).View()); + receiver.OnDataChunk(new SctpDataChunk(true, true, true, 15008, 0, 0, 0, new byte[] { 0x33 }).View()); + receiver.OnDataChunk(new SctpDataChunk(true, true, true, 15010, 0, 0, 0, new byte[] { 0x33 }).View()); var gapReports = receiver.GetForwardTSNGaps(); @@ -451,13 +451,13 @@ public void GetTwoGapReports() public void GetThreeGapReports() { SctpDataReceiver receiver = new SctpDataReceiver(0, 0, 3); - receiver.OnDataChunk(new SctpDataChunk(true, true, true, 3, 0, 0, 0, new byte[] { 0x33 })); - receiver.OnDataChunk(new SctpDataChunk(true, true, true, 7, 0, 0, 0, new byte[] { 0x33 })); - receiver.OnDataChunk(new SctpDataChunk(true, true, true, 8, 0, 0, 0, new byte[] { 0x33 })); - receiver.OnDataChunk(new SctpDataChunk(true, true, true, 9, 0, 0, 0, new byte[] { 0x33 })); - receiver.OnDataChunk(new SctpDataChunk(true, true, true, 11, 0, 0, 0, new byte[] { 0x33 })); - receiver.OnDataChunk(new SctpDataChunk(true, true, true, 12, 0, 0, 0, new byte[] { 0x33 })); - receiver.OnDataChunk(new SctpDataChunk(true, true, true, 14, 0, 0, 0, new byte[] { 0x33 })); + receiver.OnDataChunk(new SctpDataChunk(true, true, true, 3, 0, 0, 0, new byte[] { 0x33 }).View()); + receiver.OnDataChunk(new SctpDataChunk(true, true, true, 7, 0, 0, 0, new byte[] { 0x33 }).View()); + receiver.OnDataChunk(new SctpDataChunk(true, true, true, 8, 0, 0, 0, new byte[] { 0x33 }).View()); + receiver.OnDataChunk(new SctpDataChunk(true, true, true, 9, 0, 0, 0, new byte[] { 0x33 }).View()); + receiver.OnDataChunk(new SctpDataChunk(true, true, true, 11, 0, 0, 0, new byte[] { 0x33 }).View()); + receiver.OnDataChunk(new SctpDataChunk(true, true, true, 12, 0, 0, 0, new byte[] { 0x33 }).View()); + receiver.OnDataChunk(new SctpDataChunk(true, true, true, 14, 0, 0, 0, new byte[] { 0x33 }).View()); var gapReports = receiver.GetForwardTSNGaps(); @@ -478,11 +478,11 @@ public void GeGapReportWithDuplicateForwardTSN() SctpDataReceiver receiver = new SctpDataReceiver(0, 0, initialTSN); // Forward TSN. - receiver.OnDataChunk(new SctpDataChunk(true, true, true, initialTSN + 1, 0, 0, 0, new byte[] { 0x33 })); + receiver.OnDataChunk(new SctpDataChunk(true, true, true, initialTSN + 1, 0, 0, 0, new byte[] { 0x33 }).View()); // Initial expected TSN. - receiver.OnDataChunk(new SctpDataChunk(true, true, true, initialTSN, 0, 0, 0, new byte[] { 0x33 })); + receiver.OnDataChunk(new SctpDataChunk(true, true, true, initialTSN, 0, 0, 0, new byte[] { 0x33 }).View()); // Duplicate of first received TSN. - receiver.OnDataChunk(new SctpDataChunk(true, true, true, initialTSN + 1, 0, 0, 0, new byte[] { 0x33 })); + receiver.OnDataChunk(new SctpDataChunk(true, true, true, initialTSN + 1, 0, 0, 0, new byte[] { 0x33 }).View()); var gapReports = receiver.GetForwardTSNGaps(); @@ -501,11 +501,11 @@ public void GetSackForSingleMissingChunk() SctpDataReceiver receiver = new SctpDataReceiver(arwnd, mtu, initialTSN); - receiver.OnDataChunk(new SctpDataChunk(true, true, true, initialTSN, 0, 0, 0, new byte[] { 0x44 })); + receiver.OnDataChunk(new SctpDataChunk(true, true, true, initialTSN, 0, 0, 0, new byte[] { 0x44 }).View()); Assert.Equal(initialTSN, receiver.CumulativeAckTSN); // Simulate a missing chunk by incrementing the TSN by 2. - receiver.OnDataChunk(new SctpDataChunk(true, true, true, initialTSN + 2, 0, 0, 0, new byte[] { 0x44 })); + receiver.OnDataChunk(new SctpDataChunk(true, true, true, initialTSN + 2, 0, 0, 0, new byte[] { 0x44 }).View()); Assert.Equal(initialTSN, receiver.CumulativeAckTSN); var sack = receiver.GetSackChunk(); @@ -527,7 +527,7 @@ public void GetSackForInitialChunkMissing() SctpDataReceiver receiver = new SctpDataReceiver(arwnd, mtu, initialTSN); - receiver.OnDataChunk(new SctpDataChunk(true, true, true, initialTSN + 1, 0, 0, 0, new byte[] { 0x44 })); + receiver.OnDataChunk(new SctpDataChunk(true, true, true, initialTSN + 1, 0, 0, 0, new byte[] { 0x44 }).View()); Assert.Null(receiver.CumulativeAckTSN); Assert.Null(receiver.GetSackChunk()); } @@ -545,12 +545,12 @@ public void InitialChunkOutOfOrder() SctpDataReceiver receiver = new SctpDataReceiver(arwnd, mtu, initialTSN); // Skip initial DATA chunk. - receiver.OnDataChunk(new SctpDataChunk(true, true, true, initialTSN + 1, 0, 0, 0, new byte[] { 0x44 })); + receiver.OnDataChunk(new SctpDataChunk(true, true, true, initialTSN + 1, 0, 0, 0, new byte[] { 0x44 }).View()); Assert.Null(receiver.CumulativeAckTSN); Assert.Null(receiver.GetSackChunk()); // Give the receiver the initial DATA chunk. - receiver.OnDataChunk(new SctpDataChunk(true, true, true, initialTSN, 0, 0, 0, new byte[] { 0x44 })); + receiver.OnDataChunk(new SctpDataChunk(true, true, true, initialTSN, 0, 0, 0, new byte[] { 0x44 }).View()); Assert.Equal(initialTSN + 1, receiver.CumulativeAckTSN); var sack = receiver.GetSackChunk(); @@ -572,16 +572,16 @@ public void InitialChunkTwoChunkDelay() SctpDataReceiver receiver = new SctpDataReceiver(arwnd, mtu, initialTSN); // Skip initial DATA chunk. - receiver.OnDataChunk(new SctpDataChunk(true, true, true, initialTSN + 1, 0, 0, 0, new byte[] { 0x44 })); + receiver.OnDataChunk(new SctpDataChunk(true, true, true, initialTSN + 1, 0, 0, 0, new byte[] { 0x44 }).View()); Assert.Null(receiver.CumulativeAckTSN); Assert.Null(receiver.GetSackChunk()); - receiver.OnDataChunk(new SctpDataChunk(true, true, true, initialTSN + 2, 0, 0, 0, new byte[] { 0x44 })); + receiver.OnDataChunk(new SctpDataChunk(true, true, true, initialTSN + 2, 0, 0, 0, new byte[] { 0x44 }).View()); Assert.Null(receiver.CumulativeAckTSN); Assert.Null(receiver.GetSackChunk()); // Give the receiver the initial DATA chunk. - receiver.OnDataChunk(new SctpDataChunk(true, true, true, initialTSN, 0, 0, 0, new byte[] { 0x44 })); + receiver.OnDataChunk(new SctpDataChunk(true, true, true, initialTSN, 0, 0, 0, new byte[] { 0x44 }).View()); Assert.Equal(initialTSN + 2, receiver.CumulativeAckTSN); var sack = receiver.GetSackChunk(); diff --git a/test/unit/net/SCTP/SctpDataSendRecvUnitTest.cs b/test/unit/net/SCTP/SctpDataSendRecvUnitTest.cs index 5f29a2dd1..c940e004d 100644 --- a/test/unit/net/SCTP/SctpDataSendRecvUnitTest.cs +++ b/test/unit/net/SCTP/SctpDataSendRecvUnitTest.cs @@ -52,8 +52,8 @@ public async Task SACKChunkRetransmit() // sender to the receiver of a remote peer and the return of the SACK. Action doSend = (chunk) => { - receiver.OnDataChunk(chunk); - sender.GotSack(receiver.GetSackChunk()); + receiver.OnDataChunk(chunk.View()); + sender.GotSack(receiver.GetSackChunk().View()); }; Action dontSend = (chunk) => { }; @@ -116,8 +116,8 @@ public async Task InitialDataChunkDropped() } else { - receiver.OnDataChunk(chunk); - sender.GotSack(receiver.GetSackChunk()); + receiver.OnDataChunk(chunk.View()); + sender.GotSack(receiver.GetSackChunk().View()); } }; sender._sendDataChunk = doSend; @@ -162,12 +162,12 @@ public void MediumBufferSend() Action doSend = (chunk) => { logger.LogDebug($"Data chunk {chunk.TSN} provided to receiver."); - var frames = receiver.OnDataChunk(chunk); - sender.GotSack(receiver.GetSackChunk()); + var frames = receiver.OnDataChunk(chunk.View()); + sender.GotSack(receiver.GetSackChunk().View()); if (frames.Count > 0) { - logger.LogDebug($"Receiver got frame of length {frames.First().UserData?.Length}."); + logger.LogDebug($"Receiver got frame of length {frames.First().UserData.Length}."); frame = frames.First(); frameReady.Set(); } @@ -210,12 +210,12 @@ public void MaxBufferSend() Action doSend = (chunk) => { logger.LogDebug($"Data chunk {chunk.TSN} provided to receiver."); - var frames = receiver.OnDataChunk(chunk); - sender.GotSack(receiver.GetSackChunk()); + var frames = receiver.OnDataChunk(chunk.View()); + sender.GotSack(receiver.GetSackChunk().View()); if (frames.Count > 0) { - logger.LogDebug($"Receiver got frame of length {frames.First().UserData?.Length}."); + logger.LogDebug($"Receiver got frame of length {frames.First().UserData.Length}."); frame = frames.First(); frameReady.Set(); } @@ -267,12 +267,12 @@ public void MediumBufferSendWithRandomDrops() else { logger.LogDebug($"Data chunk {chunk.TSN} provided to receiver."); - var frames = receiver.OnDataChunk(chunk); - sender.GotSack(receiver.GetSackChunk()); + var frames = receiver.OnDataChunk(chunk.View()); + sender.GotSack(receiver.GetSackChunk().View()); if (frames.Count > 0) { - logger.LogDebug($"Receiver got frame of length {frames.First().UserData?.Length}."); + logger.LogDebug($"Receiver got frame of length {frames.First().UserData.Length}."); frame = frames.First(); frameReady.Set(); } @@ -325,12 +325,12 @@ public async Task MaxBufferSendWithRandomDrops() else { logger.LogDebug($"Data chunk {chunk.TSN} provided to receiver."); - var frames = receiver.OnDataChunk(chunk); - sender.GotSack(receiver.GetSackChunk()); + var frames = receiver.OnDataChunk(chunk.View()); + sender.GotSack(receiver.GetSackChunk().View()); if (frames.Count > 0) { - logger.LogDebug($"Receiver got frame of length {frames.First().UserData?.Length}."); + logger.LogDebug($"Receiver got frame of length {frames.First().UserData.Length}."); frame = frames.First(); frameReady.Set(); } diff --git a/test/unit/net/SCTP/SctpDataSenderUnitTest.cs b/test/unit/net/SCTP/SctpDataSenderUnitTest.cs index 80eb79dc6..0dd7bb768 100755 --- a/test/unit/net/SCTP/SctpDataSenderUnitTest.cs +++ b/test/unit/net/SCTP/SctpDataSenderUnitTest.cs @@ -52,7 +52,7 @@ public async Task SmallBufferSend() Assert.Single(outStm); byte[] sendBuffer = outStm.Single(); - SctpPacket pkt = SctpPacket.Parse(sendBuffer, 0, sendBuffer.Length); + SctpPacket pkt = SctpPacket.Parse(sendBuffer); Assert.NotNull(pkt); Assert.NotNull(pkt.Chunks.Single() as SctpDataChunk); @@ -108,8 +108,8 @@ public async Task IncreaseCongestionWindowSlowStart() { if (chunk.TSN % 5 == 0) { - receiver.OnDataChunk(chunk); - sender.GotSack(receiver.GetSackChunk()); + receiver.OnDataChunk(chunk.View()); + sender.GotSack(receiver.GetSackChunk().View()); } }; diff --git a/test/unit/net/SCTP/SctpHeaderUnitTest.cs b/test/unit/net/SCTP/SctpHeaderUnitTest.cs index 7efbdf9c7..ffe6b5937 100644 --- a/test/unit/net/SCTP/SctpHeaderUnitTest.cs +++ b/test/unit/net/SCTP/SctpHeaderUnitTest.cs @@ -51,7 +51,7 @@ public void RoundtripSctpHeader() header.WriteToBuffer(buffer, 0); - var rndTripHeader = SctpHeader.Parse(buffer, 0); + var rndTripHeader = SctpHeader.Parse(buffer); Assert.Equal(srcPort, rndTripHeader.SourcePort); Assert.Equal(dstPort, rndTripHeader.DestinationPort); @@ -68,7 +68,7 @@ public void ParseUsrSctpInitHeader() byte[] buffer = { 0xdf, 0x90, 0x00, 0x07, 0x00, 0x00, 0x00, 0x00, 0x6a, 0xb8, 0x0e, 0x99 }; - var sctpHdr = SctpHeader.Parse(buffer, 0); + var sctpHdr = SctpHeader.Parse(buffer); Assert.Equal(57232, sctpHdr.SourcePort); Assert.Equal(7, sctpHdr.DestinationPort); diff --git a/test/unit/net/SCTP/SctpPacketUnitTest.cs b/test/unit/net/SCTP/SctpPacketUnitTest.cs index 2391a679f..7c791a049 100644 --- a/test/unit/net/SCTP/SctpPacketUnitTest.cs +++ b/test/unit/net/SCTP/SctpPacketUnitTest.cs @@ -49,9 +49,9 @@ public void ParseUsrSctpInitPacket() 0x69, 0x81, 0x78, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x05, 0x00, 0x08, 0xc0, 0xa8, 0x0b, 0x32, 0x00, 0x05, 0x00, 0x08, 0xc0, 0xa8, 0x00, 0x32 }; - Assert.True(SctpPacket.IsValid(buffer, 0, buffer.Length, 0U)); + Assert.True(SctpPacket.IsValid(buffer, requiredTag: 0U)); - var sctpPkt = SctpPacket.Parse(buffer, 0, buffer.Length); + var sctpPkt = SctpPacket.Parse(buffer); Assert.NotNull(sctpPkt); Assert.Equal(57232, sctpPkt.Header.SourcePort); @@ -122,9 +122,9 @@ public void ParseUsrSctpInitAckPacket() 0x00, 0x05, 0x00, 0x08, 0xc0, 0xa8, 0x00, 0x32, 0xc5, 0x35, 0x15, 0xc8, 0x35, 0x57, 0x0a, 0xd5, 0x96, 0x29, 0xc8, 0xbf, 0x38, 0x7b, 0xc2, 0x16, 0xe9, 0x4c, 0x81, 0xbe }; - Assert.True(SctpPacket.IsValid(buffer, 0, buffer.Length, 0xe31c5536U)); + Assert.True(SctpPacket.IsValid(buffer, requiredTag: 0xe31c5536U)); - var sctpPkt = SctpPacket.Parse(buffer, 0, buffer.Length); + var sctpPkt = SctpPacket.Parse(buffer); Assert.NotNull(sctpPkt); Assert.Equal(7, sctpPkt.Header.SourcePort); @@ -155,10 +155,10 @@ public void ParseUsrSctpCookieEchoPacket() 0x4b, 0x41, 0x4d, 0x45, 0x2d, 0x42, 0x53, 0x44, 0x20, 0x31, 0x2e, 0x31, 0x00, 0x00, 0x00, 0x00 }; // Checksum does not match because the original cookie was too long and was truncated for testing purposes. - Assert.False(SctpPacket.VerifyChecksum(buffer, 0, buffer.Length)); - Assert.Equal(0xcd6e6150U, SctpPacket.GetVerificationTag(buffer, 0, buffer.Length)); + Assert.False(SctpPacket.VerifyChecksum(buffer)); + Assert.Equal(0xcd6e6150U, SctpPacket.GetVerificationTag(buffer)); - var sctpPkt = SctpPacket.Parse(buffer, 0, buffer.Length); + var sctpPkt = SctpPacket.Parse(buffer); Assert.NotNull(sctpPkt); Assert.Equal(57232, sctpPkt.Header.SourcePort); @@ -184,9 +184,9 @@ public void ParseUsrSctpCookieAckPacket() byte[] buffer = { 0x00, 0x07, 0xdf, 0x90, 0xe3, 0x1c, 0x55, 0x36, 0xb2, 0x04, 0xdf, 0x21, 0x0b, 0x00, 0x00, 0x04 }; - Assert.True(SctpPacket.IsValid(buffer, 0, buffer.Length, 0xe31c5536U)); + Assert.True(SctpPacket.IsValid(buffer, requiredTag: 0xe31c5536U)); - var sctpPkt = SctpPacket.Parse(buffer, 0, buffer.Length); + var sctpPkt = SctpPacket.Parse(buffer); Assert.NotNull(sctpPkt); Assert.Equal(7, sctpPkt.Header.SourcePort); @@ -220,18 +220,18 @@ public void RoundTripInitPacket() 0x69, 0x81, 0x78, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x05, 0x00, 0x08, 0xc0, 0xa8, 0x0b, 0x32, 0x00, 0x05, 0x00, 0x08, 0xc0, 0xa8, 0x00, 0x32 }; - Assert.True(SctpPacket.IsValid(buffer, 0, buffer.Length, 0x0U)); + Assert.True(SctpPacket.IsValid(buffer, requiredTag: 0x0U)); - var initPkt = SctpPacket.Parse(buffer, 0, buffer.Length); + var initPkt = SctpPacket.Parse(buffer); var rndTripBuffer = initPkt.GetBytes(); logger.LogDebug($"Before: {buffer.HexStr()}"); logger.LogDebug($"After : {rndTripBuffer.HexStr()}"); - Assert.True(SctpPacket.IsValid(rndTripBuffer, 0, rndTripBuffer.Length, 0x0U)); + Assert.True(SctpPacket.IsValid(rndTripBuffer, requiredTag: 0x0U)); - var sctpPkt = SctpPacket.Parse(rndTripBuffer, 0, rndTripBuffer.Length); + var sctpPkt = SctpPacket.Parse(rndTripBuffer); Assert.NotNull(sctpPkt); Assert.Equal(57232, sctpPkt.Header.SourcePort); @@ -264,9 +264,9 @@ public void ParseUsrSctpHeartbeatPacket() 0x00, 0x00, 0x00, 0x00, 0x02, 0x10, 0x00, 0x00, 0xc0, 0xa8, 0x00, 0x32, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; - Assert.True(SctpPacket.IsValid(buffer, 0, buffer.Length, 0x054a3af0U)); + Assert.True(SctpPacket.IsValid(buffer, requiredTag: 0x054a3af0U)); - var sctpPkt = SctpPacket.Parse(buffer, 0, buffer.Length); + var sctpPkt = SctpPacket.Parse(buffer); Assert.NotNull(sctpPkt); Assert.Equal(7, sctpPkt.Header.SourcePort); @@ -293,18 +293,18 @@ public void RoundTripHeartbeatPacket() 0x00, 0x00, 0x00, 0x00, 0x02, 0x10, 0x00, 0x00, 0xc0, 0xa8, 0x00, 0x32, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; - Assert.True(SctpPacket.IsValid(buffer, 0, buffer.Length, 0x054a3af0U)); + Assert.True(SctpPacket.IsValid(buffer, requiredTag: 0x054a3af0U)); - var heartbeatPkt = SctpPacket.Parse(buffer, 0, buffer.Length); + var heartbeatPkt = SctpPacket.Parse(buffer); var rndTripBuffer = heartbeatPkt.GetBytes(); logger.LogDebug($"Before: {buffer.HexStr()}"); logger.LogDebug($"After : {rndTripBuffer.HexStr()}"); - Assert.True(SctpPacket.IsValid(rndTripBuffer, 0, rndTripBuffer.Length, 0x054a3af0U)); + Assert.True(SctpPacket.IsValid(rndTripBuffer, requiredTag: 0x054a3af0U)); - var sctpPkt = SctpPacket.Parse(rndTripBuffer, 0, buffer.Length); + var sctpPkt = SctpPacket.Parse(rndTripBuffer); Assert.NotNull(sctpPkt); Assert.Equal(7, sctpPkt.Header.SourcePort); @@ -330,9 +330,9 @@ public void ParseUsrSctpDataPacket() 0x00, 0x07, 0x11, 0x5c, 0x2e, 0x0b, 0x82, 0xc7, 0x5d, 0x2c, 0xeb, 0xa7, 0x00, 0x07, 0x00, 0x13, 0xdf, 0x08, 0xc1, 0xb7, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x68, 0x69, 0x0a, 0x00}; - Assert.True(SctpPacket.IsValid(buffer, 0, buffer.Length, 0x2e0b82c7U)); + Assert.True(SctpPacket.IsValid(buffer, requiredTag: 0x2e0b82c7U)); - var sctpPkt = SctpPacket.Parse(buffer, 0, buffer.Length); + var sctpPkt = SctpPacket.Parse(buffer); Assert.NotNull(sctpPkt); Assert.Equal(7, sctpPkt.Header.SourcePort); @@ -360,18 +360,18 @@ public void RoundTripDataPacket() 0x00, 0x07, 0x11, 0x5c, 0x2e, 0x0b, 0x82, 0xc7, 0x5d, 0x2c, 0xeb, 0xa7, 0x00, 0x07, 0x00, 0x13, 0xdf, 0x08, 0xc1, 0xb7, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x68, 0x69, 0x0a, 0x00}; - Assert.True(SctpPacket.IsValid(buffer, 0, buffer.Length, 0x2e0b82c7U)); + Assert.True(SctpPacket.IsValid(buffer, requiredTag: 0x2e0b82c7U)); - var dataPkt = SctpPacket.Parse(buffer, 0, buffer.Length); + var dataPkt = SctpPacket.Parse(buffer); var rndTripBuffer = dataPkt.GetBytes(); logger.LogDebug($"Before: {buffer.HexStr()}"); logger.LogDebug($"After : {rndTripBuffer.HexStr()}"); - Assert.True(SctpPacket.IsValid(rndTripBuffer, 0, rndTripBuffer.Length, 0x2e0b82c7U)); + Assert.True(SctpPacket.IsValid(rndTripBuffer, requiredTag: 0x2e0b82c7U)); - var sctpPkt = SctpPacket.Parse(rndTripBuffer, 0, buffer.Length); + var sctpPkt = SctpPacket.Parse(rndTripBuffer); Assert.NotNull(sctpPkt); Assert.Equal(7, sctpPkt.Header.SourcePort); @@ -405,9 +405,9 @@ public void ParseUsrSctpAbortPacket() 0x65, 0x71, 0x75, 0x61, 0x6c, 0x20, 0x74, 0x68, 0x61, 0x6e, 0x20, 0x54, 0x53, 0x4e, 0x20, 0x63, 0x36, 0x65, 0x33, 0x61, 0x64, 0x33, 0x63, 0x00}; - Assert.True(SctpPacket.IsValid(buffer, 0, buffer.Length, 0x93c9d98aU)); + Assert.True(SctpPacket.IsValid(buffer, requiredTag: 0x93c9d98aU)); - var sctpPkt = SctpPacket.Parse(buffer, 0, buffer.Length); + var sctpPkt = SctpPacket.Parse(buffer); Assert.NotNull(sctpPkt); Assert.Equal(7, sctpPkt.Header.SourcePort); diff --git a/test/unit/net/SCTP/SctpTransportUnitTest.cs b/test/unit/net/SCTP/SctpTransportUnitTest.cs index ca1d3b3a5..3c2dfa2f5 100644 --- a/test/unit/net/SCTP/SctpTransportUnitTest.cs +++ b/test/unit/net/SCTP/SctpTransportUnitTest.cs @@ -13,6 +13,7 @@ // BSD 3-Clause "New" or "Revised" License, see included LICENSE.md file. //----------------------------------------------------------------------------- +using System; using System.Linq; using System.Text; using Microsoft.Extensions.Logging; @@ -89,7 +90,7 @@ public SctpPacket GetInitAck(SctpPacket initPacket) return base.GetCookieHMAC(buffer); } - public override void Send(string associationID, byte[] buffer, int offset, int length) + public override void Send(string associationID, ReadOnlySpan buffer) { } } } diff --git a/test/unit/sys/InterlockedExUnitTest.cs b/test/unit/sys/InterlockedExUnitTest.cs new file mode 100644 index 000000000..884223861 --- /dev/null +++ b/test/unit/sys/InterlockedExUnitTest.cs @@ -0,0 +1,47 @@ +//----------------------------------------------------------------------------- +// Filename: TypeExtensionsUnitTest.cs +// +// Description: Unit tests for methods in the TypeExtensions class. +// +// Author(s): +// Aaron Clauson (aaron@sipsorcery.com) +// +// History: +// ?? Aaron Clauson Created. +// +// License: +// BSD 3-Clause "New" or "Revised" License, see included LICENSE.md file. +//----------------------------------------------------------------------------- + +using System; +using System.Net; +using System.Text; +using Microsoft.Extensions.Logging; +using Xunit; + +namespace SIPSorcery.Sys.UnitTests +{ + [Trait("Category", "unit")] + public class InterlockedExUnitTest + { + private Microsoft.Extensions.Logging.ILogger logger = null; + + public InterlockedExUnitTest(Xunit.Abstractions.ITestOutputHelper output) + { + logger = SIPSorcery.UnitTests.TestLogHelper.InitTestLogger(output); + } + + [Fact] + public void CompareExchangeU32() + { + logger.LogDebug("--> " + System.Reflection.MethodBase.GetCurrentMethod().Name); + logger.BeginScope(System.Reflection.MethodBase.GetCurrentMethod().Name); + + uint value = 10; + uint was = InterlockedEx.CompareExchange(ref value, value: 20, comparand: 10); + + Assert.Equal(10u, was); + Assert.Equal(20u, value); + } + } +}