From 0ea970101c6379418165945d9026a35eeb5f17f3 Mon Sep 17 00:00:00 2001 From: SALTWOOD <105980161+SALTWOOD@users.noreply.github.com> Date: Tue, 13 Aug 2024 16:53:02 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20RsaStream=20=E5=AF=B9=E6=96=B9=E8=BA=AB?= =?UTF-8?q?=E4=BB=BD=E9=AA=8C=E8=AF=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Network/RsaStream.cs | 52 ++++++++++++++++++++++++++++++++------------ 1 file changed, 38 insertions(+), 14 deletions(-) diff --git a/Network/RsaStream.cs b/Network/RsaStream.cs index f0a4984..56ae043 100644 --- a/Network/RsaStream.cs +++ b/Network/RsaStream.cs @@ -16,8 +16,18 @@ namespace TeraIO.Network; public class RsaStream : Stream, IDisposable { protected readonly Stream _stream; - protected RSA _rsaPrivate; - protected RSA _rsaPublic; + private RSA _privateKey; + public RSA PrivateKey + { + get => this._privateKey; + set + { + this._privateKey = value; + this.PublicKey = RSA.Create(); + this.PublicKey.ImportParameters(this.PrivateKey.ExportParameters(false)); + } + } + public RSA PublicKey { get; protected set; } protected RSAParameters _publicKey; protected RSA? _remotePublicKey; protected ushort _protocolVersion = 1; @@ -26,6 +36,8 @@ public class RsaStream : Stream, IDisposable private RsaStreamStatus _status; private byte[] pendingBytes = Array.Empty(); + public int RsaKeyLength { get; protected set; } + public RsaStreamStatus Status { get @@ -44,45 +56,53 @@ public RsaStreamStatus Status } } - public RsaStream(Stream stream) + public RsaStream(Stream stream) : this(stream, 4096) { } + +#pragma warning disable CS8618 + public RsaStream(Stream stream, int rsaKeyLength) { _stream = stream ?? throw new ArgumentNullException(nameof(stream)); if (!_stream.CanRead || !_stream.CanWrite) { throw new InvalidOperationException("Stream must be readable and writable."); } - _rsaPrivate = RSA.Create(); - _rsaPublic = RSA.Create(); - _publicKey = _rsaPrivate.ExportParameters(false); - _rsaPublic.ImportParameters(_publicKey); + this.RsaKeyLength = rsaKeyLength; + this.PrivateKey = RSA.Create(rsaKeyLength); } +#pragma warning restore CS8618 - public void Handshake() + public void Handshake(string? signature = null) { this.Status = RsaStreamStatus.Handshaking; // Send public key - byte[] publicKeyBytes = _rsaPublic.ExportRSAPublicKey(); + byte[] publicKeyBytes = PublicKey.ExportRSAPublicKey(); _stream.Write(publicKeyBytes, 0, publicKeyBytes.Length); _stream.Flush(); // Receive remote public key - byte[] remotePublicKeyBytes = new byte[4096]; // Adjust size as needed + byte[] remotePublicKeyBytes = new byte[this.RsaKeyLength]; // Adjust size as needed int bytesRead = _stream.Read(remotePublicKeyBytes, 0, remotePublicKeyBytes.Length); if (bytesRead == 0) throw new IOException("Failed to read remote public key."); _remotePublicKey = RSA.Create(); _remotePublicKey.ImportRSAPublicKey(remotePublicKeyBytes.AsSpan(0, bytesRead), out _); + if (signature != null && GetRsaFingerprint(_remotePublicKey) != signature) + { + this.Status = RsaStreamStatus.AuthenticationFailed; + throw new Exception($"Remote signature doen't match: expected \"{signature}\", but \"{GetRsaFingerprint(_remotePublicKey)}\" found."); + } + // Test encryption/decryption byte[] helloBytes = Encoding.UTF8.GetBytes("RSA HELLO"); byte[] encryptedHello = _remotePublicKey.Encrypt(helloBytes, RSAEncryptionPadding.OaepSHA256); _stream.Write(encryptedHello, 0, encryptedHello.Length); _stream.Flush(); - byte[] responseBytes = new byte[4096]; // Adjust size as needed + byte[] responseBytes = new byte[this.RsaKeyLength]; // Adjust size as needed int responseLength = _stream.Read(responseBytes, 0, responseBytes.Length); if (responseLength == 0) throw new IOException("Failed to read response."); - byte[] decryptedResponse = _rsaPrivate.Decrypt(responseBytes.AsSpan(0, responseLength), RSAEncryptionPadding.OaepSHA256); + byte[] decryptedResponse = PrivateKey.Decrypt(responseBytes.AsSpan(0, responseLength), RSAEncryptionPadding.OaepSHA256); string decryptedResponseStr = Encoding.UTF8.GetString(decryptedResponse); if (decryptedResponseStr != "RSA HELLO") throw new InvalidOperationException("Handshake failed."); @@ -158,7 +178,7 @@ public override int Read(byte[] buffer, int offset, int count) using (var cryptoStream = new MemoryStream()) { int totalBytesRead = 0; - int encryptedBlockSize = _rsaPrivate.KeySize / 8; // Each encrypted block's size should be equal to the RSA key size in bytes + int encryptedBlockSize = PrivateKey.KeySize / 8; // Each encrypted block's size should be equal to the RSA key size in bytes byte[] encryptedBuffer = new byte[encryptedBlockSize]; int bytesRead; @@ -177,7 +197,7 @@ public override int Read(byte[] buffer, int offset, int count) if (bytesRead != encryptedBlockSize) throw new CryptographicException("The length of the data to decrypt is not valid for the size of this key."); - decryptedBlock = _rsaPrivate.Decrypt(encryptedBuffer, RSAEncryptionPadding.OaepSHA256); + decryptedBlock = PrivateKey.Decrypt(encryptedBuffer, RSAEncryptionPadding.OaepSHA256); if (decryptedBlock.Length > 0) { @@ -240,4 +260,8 @@ protected override void Dispose(bool disposing) } } #endregion + + #region Static Methods + public static string GetRsaFingerprint(RSA rsa) => $"SHA256:{Convert.ToHexString(SHA256.HashData(rsa.ExportRSAPublicKey())).ToLower()}"; + #endregion } \ No newline at end of file