Skip to content

Commit

Permalink
feat: RsaStream 对方身份验证
Browse files Browse the repository at this point in the history
  • Loading branch information
SALTWOOD committed Aug 13, 2024
1 parent a884f80 commit 0ea9701
Showing 1 changed file with 38 additions and 14 deletions.
52 changes: 38 additions & 14 deletions Network/RsaStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,18 @@ namespace TeraIO.Network;
public class RsaStream : Stream, IDisposable
{
protected readonly Stream _stream;
protected RSA _rsaPrivate;
protected RSA _rsaPublic;
private RSA _privateKey;
public RSA PrivateKey
{
get => this._privateKey;
set
{
this._privateKey = value;
this.PublicKey = RSA.Create();
this.PublicKey.ImportParameters(this.PrivateKey.ExportParameters(false));
}
}
public RSA PublicKey { get; protected set; }
protected RSAParameters _publicKey;
protected RSA? _remotePublicKey;
protected ushort _protocolVersion = 1;
Expand All @@ -26,6 +36,8 @@ public class RsaStream : Stream, IDisposable
private RsaStreamStatus _status;
private byte[] pendingBytes = Array.Empty<byte>();

public int RsaKeyLength { get; protected set; }

public RsaStreamStatus Status
{
get
Expand All @@ -44,45 +56,53 @@ public RsaStreamStatus Status
}
}

public RsaStream(Stream stream)
public RsaStream(Stream stream) : this(stream, 4096) { }

#pragma warning disable CS8618
public RsaStream(Stream stream, int rsaKeyLength)
{
_stream = stream ?? throw new ArgumentNullException(nameof(stream));
if (!_stream.CanRead || !_stream.CanWrite)
{
throw new InvalidOperationException("Stream must be readable and writable.");
}
_rsaPrivate = RSA.Create();
_rsaPublic = RSA.Create();
_publicKey = _rsaPrivate.ExportParameters(false);
_rsaPublic.ImportParameters(_publicKey);
this.RsaKeyLength = rsaKeyLength;
this.PrivateKey = RSA.Create(rsaKeyLength);
}
#pragma warning restore CS8618

public void Handshake()
public void Handshake(string? signature = null)
{
this.Status = RsaStreamStatus.Handshaking;
// Send public key
byte[] publicKeyBytes = _rsaPublic.ExportRSAPublicKey();
byte[] publicKeyBytes = PublicKey.ExportRSAPublicKey();
_stream.Write(publicKeyBytes, 0, publicKeyBytes.Length);
_stream.Flush();

// Receive remote public key
byte[] remotePublicKeyBytes = new byte[4096]; // Adjust size as needed
byte[] remotePublicKeyBytes = new byte[this.RsaKeyLength]; // Adjust size as needed
int bytesRead = _stream.Read(remotePublicKeyBytes, 0, remotePublicKeyBytes.Length);
if (bytesRead == 0) throw new IOException("Failed to read remote public key.");

_remotePublicKey = RSA.Create();
_remotePublicKey.ImportRSAPublicKey(remotePublicKeyBytes.AsSpan(0, bytesRead), out _);

if (signature != null && GetRsaFingerprint(_remotePublicKey) != signature)
{
this.Status = RsaStreamStatus.AuthenticationFailed;
throw new Exception($"Remote signature doen't match: expected \"{signature}\", but \"{GetRsaFingerprint(_remotePublicKey)}\" found.");
}

// Test encryption/decryption
byte[] helloBytes = Encoding.UTF8.GetBytes("RSA HELLO");
byte[] encryptedHello = _remotePublicKey.Encrypt(helloBytes, RSAEncryptionPadding.OaepSHA256);
_stream.Write(encryptedHello, 0, encryptedHello.Length);
_stream.Flush();

byte[] responseBytes = new byte[4096]; // Adjust size as needed
byte[] responseBytes = new byte[this.RsaKeyLength]; // Adjust size as needed
int responseLength = _stream.Read(responseBytes, 0, responseBytes.Length);
if (responseLength == 0) throw new IOException("Failed to read response.");
byte[] decryptedResponse = _rsaPrivate.Decrypt(responseBytes.AsSpan(0, responseLength), RSAEncryptionPadding.OaepSHA256);
byte[] decryptedResponse = PrivateKey.Decrypt(responseBytes.AsSpan(0, responseLength), RSAEncryptionPadding.OaepSHA256);
string decryptedResponseStr = Encoding.UTF8.GetString(decryptedResponse);
if (decryptedResponseStr != "RSA HELLO") throw new InvalidOperationException("Handshake failed.");

Expand Down Expand Up @@ -158,7 +178,7 @@ public override int Read(byte[] buffer, int offset, int count)
using (var cryptoStream = new MemoryStream())
{
int totalBytesRead = 0;
int encryptedBlockSize = _rsaPrivate.KeySize / 8; // Each encrypted block's size should be equal to the RSA key size in bytes
int encryptedBlockSize = PrivateKey.KeySize / 8; // Each encrypted block's size should be equal to the RSA key size in bytes

byte[] encryptedBuffer = new byte[encryptedBlockSize];
int bytesRead;
Expand All @@ -177,7 +197,7 @@ public override int Read(byte[] buffer, int offset, int count)
if (bytesRead != encryptedBlockSize)
throw new CryptographicException("The length of the data to decrypt is not valid for the size of this key.");

decryptedBlock = _rsaPrivate.Decrypt(encryptedBuffer, RSAEncryptionPadding.OaepSHA256);
decryptedBlock = PrivateKey.Decrypt(encryptedBuffer, RSAEncryptionPadding.OaepSHA256);

if (decryptedBlock.Length > 0)
{
Expand Down Expand Up @@ -240,4 +260,8 @@ protected override void Dispose(bool disposing)
}
}
#endregion

#region Static Methods
public static string GetRsaFingerprint(RSA rsa) => $"SHA256:{Convert.ToHexString(SHA256.HashData(rsa.ExportRSAPublicKey())).ToLower()}";
#endregion
}

0 comments on commit 0ea9701

Please sign in to comment.