Skip to content

Commit

Permalink
Server-side adaptation of AllowPacketFragmentation options.
Browse files Browse the repository at this point in the history
  • Loading branch information
xljiulang committed Dec 7, 2024
1 parent 6171c81 commit 9a7a8bd
Show file tree
Hide file tree
Showing 16 changed files with 172 additions and 81 deletions.
7 changes: 5 additions & 2 deletions Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

using Microsoft.AspNetCore.Connections;
using Microsoft.Extensions.DependencyInjection;
using MQTTnet.Adapter;
using MQTTnet.Server;
using System;

namespace MQTTnet.AspNetCore
{
Expand All @@ -15,15 +17,16 @@ public static class ConnectionBuilderExtensions
/// </summary>
/// <param name="builder"></param>
/// <param name="protocols"></param>
/// <param name="allowPacketFragmentationSelector"></param>
/// <returns></returns>
public static IConnectionBuilder UseMqtt(this IConnectionBuilder builder, MqttProtocols protocols = MqttProtocols.MqttAndWebSocket)
public static IConnectionBuilder UseMqtt(this IConnectionBuilder builder, MqttProtocols protocols = MqttProtocols.MqttAndWebSocket, Func<IMqttChannelAdapter, bool>? allowPacketFragmentationSelector = null)
{
// check services.AddMqttServer()
builder.ApplicationServices.GetRequiredService<MqttServer>();
builder.ApplicationServices.GetRequiredService<MqttConnectionHandler>().UseFlag = true;

var middleware = builder.ApplicationServices.GetRequiredService<MqttConnectionMiddleware>();
return builder.Use(next => context => middleware.InvokeAsync(next, context, protocols));
return builder.Use(next => context => middleware.InvokeAsync(next, context, protocols, allowPacketFragmentationSelector));
}
}
}
14 changes: 14 additions & 0 deletions Source/MQTTnet.AspnetCore/Features/PacketFragmentationFeature.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using MQTTnet.Adapter;
using System;

namespace MQTTnet.AspNetCore
{
sealed class PacketFragmentationFeature(Func<IMqttChannelAdapter, bool> allowPacketFragmentationSelector)
{
public Func<IMqttChannelAdapter, bool> AllowPacketFragmentationSelector { get; } = allowPacketFragmentationSelector;
}
}
28 changes: 28 additions & 0 deletions Source/MQTTnet.AspnetCore/Features/TlsConnectionFeature.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.AspNetCore.Http.Features;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;

namespace MQTTnet.AspNetCore
{
sealed class TlsConnectionFeature : ITlsConnectionFeature
{
public static readonly TlsConnectionFeature WithoutClientCertificate = new(null);

public X509Certificate2? ClientCertificate { get; set; }

public Task<X509Certificate2?> GetClientCertificateAsync(CancellationToken cancellationToken)
{
return Task.FromResult(ClientCertificate);
}

public TlsConnectionFeature(X509Certificate? clientCertificate)
{
ClientCertificate = clientCertificate as X509Certificate2;
}
}
}
14 changes: 14 additions & 0 deletions Source/MQTTnet.AspnetCore/Features/WebSocketConnectionFeature.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

namespace MQTTnet.AspNetCore
{
sealed class WebSocketConnectionFeature(string path)
{
/// <summary>
/// The path of WebSocket request.
/// </summary>
public string Path { get; } = path;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public IMqttChannelAdapter CreateClientAdapter(MqttClientOptions options, MqttPa
ArgumentNullException.ThrowIfNull(nameof(options));
var bufferWriter = new MqttBufferWriter(options.WriterBufferSize, options.WriterBufferSizeMax);
var formatter = new MqttPacketFormatterAdapter(options.ProtocolVersion, bufferWriter);
return new MqttClientChannelAdapter(formatter, options.ChannelOptions, packetInspector, options.AllowPacketFragmentation);
return new MqttClientChannelAdapter(formatter, options.ChannelOptions, options.AllowPacketFragmentation, packetInspector);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.AspNetCore.Connections.Features;
using Microsoft.AspNetCore.Http.Features;
using System;
using System.Net;
Expand Down Expand Up @@ -72,7 +71,11 @@ public static async Task<ClientConnectionContext> CreateAsync(MqttClientTcpOptio
var networkStream = new NetworkStream(socket, ownsSocket: true);
if (options.TlsOptions?.UseTls != true)
{
return new ClientConnectionContext(networkStream);
return new ClientConnectionContext(networkStream)
{
LocalEndPoint = socket.LocalEndPoint,
RemoteEndPoint = socket.RemoteEndPoint,
};
}

var targetHost = options.TlsOptions.TargetHost;
Expand Down Expand Up @@ -143,7 +146,6 @@ public static async Task<ClientConnectionContext> CreateAsync(MqttClientTcpOptio
RemoteEndPoint = socket.RemoteEndPoint,
};

connection.Features.Set<IConnectionSocketFeature>(new ConnectionSocketFeature(socket));
connection.Features.Set<ITlsConnectionFeature>(new TlsConnectionFeature(sslStream.LocalCertificate));
return connection;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ public static async Task<ClientConnectionContext> CreateAsync(MqttClientWebSocke
RemoteEndPoint = new DnsEndPoint(uri.Host, uri.Port),
};

connection.Features.Set(new WebSocketConnectionFeature(uri.AbsolutePath));
if (uri.Scheme == Uri.UriSchemeWss)
{
connection.Features.Set<ITlsConnectionFeature>(TlsConnectionFeature.Default);
connection.Features.Set<ITlsConnectionFeature>(TlsConnectionFeature.WithoutClientCertificate);
}
return connection;
}
Expand Down
25 changes: 0 additions & 25 deletions Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,11 @@
// See the LICENSE file in the project root for more information.

using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Connections.Features;
using Microsoft.AspNetCore.Http.Features;
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Net.Sockets;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;

Expand Down Expand Up @@ -67,27 +64,5 @@ private class StreamTransport(Stream stream) : IDuplexPipe

public PipeWriter Output { get; } = PipeWriter.Create(stream, new StreamPipeWriterOptions(leaveOpen: true));
}

private class TlsConnectionFeature : ITlsConnectionFeature
{
public static readonly TlsConnectionFeature Default = new(null);

public X509Certificate2? ClientCertificate { get; set; }

public Task<X509Certificate2?> GetClientCertificateAsync(CancellationToken cancellationToken)
{
return Task.FromResult(ClientCertificate);
}

public TlsConnectionFeature(X509Certificate? clientCertificate)
{
ClientCertificate = clientCertificate as X509Certificate2;
}
}

private class ConnectionSocketFeature(Socket socket) : IConnectionSocketFeature
{
public Socket Socket { get; } = socket;
}
}
}
63 changes: 29 additions & 34 deletions Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// See the LICENSE file in the project root for more information.

using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Http.Connections.Features;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using MQTTnet.Adapter;
using MQTTnet.Exceptions;
Expand All @@ -29,7 +29,7 @@ class MqttChannel : IDisposable
readonly PipeReader _input;
readonly PipeWriter _output;
readonly MqttPacketInspector? _packetInspector;
readonly bool _allowPacketFragmentation;
bool _allowPacketFragmentation = false;

public MqttPacketFormatterAdapter PacketFormatterAdapter { get; }

Expand All @@ -43,69 +43,64 @@ class MqttChannel : IDisposable

public bool IsSecureConnection { get; }

public bool IsWebSocketConnection { get; }


public MqttChannel(
MqttPacketFormatterAdapter packetFormatterAdapter,
ConnectionContext connection,
MqttPacketInspector? packetInspector = null,
bool? allowPacketFragmentation = null)
HttpContext? httpContext,
MqttPacketInspector? packetInspector)
{
PacketFormatterAdapter = packetFormatterAdapter;
_packetInspector = packetInspector;

var httpContextFeature = connection.Features.Get<IHttpContextFeature>();
var tlsConnectionFeature = connection.Features.Get<ITlsConnectionFeature>();
RemoteEndPoint = GetRemoteEndPoint(httpContextFeature, connection.RemoteEndPoint);
IsSecureConnection = IsTlsConnection(httpContextFeature, tlsConnectionFeature);
ClientCertificate = GetClientCertificate(httpContextFeature, tlsConnectionFeature);
RemoteEndPoint = GetRemoteEndPoint(connection.RemoteEndPoint, httpContext);
ClientCertificate = GetClientCertificate(tlsConnectionFeature, httpContext);
IsSecureConnection = IsTlsConnection(tlsConnectionFeature, httpContext);
IsWebSocketConnection = connection.Features.Get<WebSocketConnectionFeature>() != null;

_packetInspector = packetInspector;
_input = connection.Transport.Input;
_output = connection.Transport.Output;

_allowPacketFragmentation = allowPacketFragmentation == null
? AllowPacketFragmentation(httpContextFeature)
: allowPacketFragmentation.Value;
}

private static bool AllowPacketFragmentation(IHttpContextFeature? _httpContextFeature)
private static EndPoint? GetRemoteEndPoint(EndPoint? remoteEndPoint, HttpContext? httpContext)
{
var serverModeWebSocket = _httpContextFeature != null &&
_httpContextFeature.HttpContext != null &&
_httpContextFeature.HttpContext.WebSockets.IsWebSocketRequest;

return !serverModeWebSocket;
}

if (remoteEndPoint != null)
{
return remoteEndPoint;
}

private static EndPoint? GetRemoteEndPoint(IHttpContextFeature? _httpContextFeature, EndPoint? remoteEndPoint)
{
if (_httpContextFeature != null && _httpContextFeature.HttpContext != null)
if (httpContext != null)
{
var httpConnection = _httpContextFeature.HttpContext.Connection;
var httpConnection = httpContext.Connection;
var remoteAddress = httpConnection.RemoteIpAddress;
if (remoteAddress != null)
{
return new IPEndPoint(remoteAddress, httpConnection.RemotePort);
}
}

return remoteEndPoint;
return null;
}

private static bool IsTlsConnection(IHttpContextFeature? _httpContextFeature, ITlsConnectionFeature? tlsConnectionFeature)
private static bool IsTlsConnection(ITlsConnectionFeature? tlsConnectionFeature, HttpContext? httpContext)
{
return _httpContextFeature != null && _httpContextFeature.HttpContext != null
? _httpContextFeature.HttpContext.Request.IsHttps
: tlsConnectionFeature != null;
return tlsConnectionFeature != null || (httpContext != null && httpContext.Request.IsHttps);
}

private static X509Certificate2? GetClientCertificate(IHttpContextFeature? _httpContextFeature, ITlsConnectionFeature? tlsConnectionFeature)
private static X509Certificate2? GetClientCertificate(ITlsConnectionFeature? tlsConnectionFeature, HttpContext? httpContext)
{
return _httpContextFeature != null && _httpContextFeature.HttpContext != null
? _httpContextFeature.HttpContext.Connection.ClientCertificate
: tlsConnectionFeature?.ClientCertificate;
return tlsConnectionFeature != null
? tlsConnectionFeature.ClientCertificate
: httpContext?.Connection.ClientCertificate;
}

public void SetAllowPacketFragmentation(bool value)
{
_allowPacketFragmentation = value;
}

public async Task DisconnectAsync()
{
Expand Down
13 changes: 8 additions & 5 deletions Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,19 @@ sealed class MqttClientChannelAdapter : IMqttChannelAdapter, IAsyncDisposable
private MqttChannel? _channel;
private readonly MqttPacketFormatterAdapter _packetFormatterAdapter;
private readonly IMqttClientChannelOptions _channelOptions;
private readonly bool _allowPacketFragmentation;
private readonly MqttPacketInspector? _packetInspector;
private readonly bool? _allowPacketFragmentation;

public MqttClientChannelAdapter(
MqttPacketFormatterAdapter packetFormatterAdapter,
IMqttClientChannelOptions channelOptions,
MqttPacketInspector? packetInspector,
bool? allowPacketFragmentation)
bool allowPacketFragmentation,
MqttPacketInspector? packetInspector)
{
_packetFormatterAdapter = packetFormatterAdapter;
_channelOptions = channelOptions;
_packetInspector = packetInspector;
_allowPacketFragmentation = allowPacketFragmentation;
_packetInspector = packetInspector;
}

public MqttPacketFormatterAdapter PacketFormatterAdapter => GetChannel().PacketFormatterAdapter;
Expand All @@ -49,6 +49,8 @@ public MqttClientChannelAdapter(

public bool IsSecureConnection => GetChannel().IsSecureConnection;

public bool IsWebSocketConnection => GetChannel().IsSecureConnection;


public async Task ConnectAsync(CancellationToken cancellationToken)
{
Expand All @@ -60,7 +62,8 @@ public async Task ConnectAsync(CancellationToken cancellationToken)
MqttClientWebSocketOptions webSocketOptions => await ClientConnectionContext.CreateAsync(webSocketOptions, cancellationToken).ConfigureAwait(false),
_ => throw new NotSupportedException(),
};
_channel = new MqttChannel(_packetFormatterAdapter, _connection, _packetInspector, _allowPacketFragmentation);
_channel = new MqttChannel(_packetFormatterAdapter, _connection, httpContext: null, _packetInspector);
_channel.SetAllowPacketFragmentation(_allowPacketFragmentation);
}
catch (Exception ex)
{
Expand Down
12 changes: 10 additions & 2 deletions Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Connections.Features;
using Microsoft.AspNetCore.Http.Connections;
using MQTTnet.Adapter;
using MQTTnet.Diagnostics.Logger;
using MQTTnet.Formatter;
Expand Down Expand Up @@ -51,12 +52,19 @@ public override async Task OnConnectedAsync(ConnectionContext connection)
transferFormatFeature.ActiveFormat = TransferFormat.Binary;
}

var bufferWriter = _bufferWriterPool.Rent();
// WebSocketConnectionFeature will be accessed in MqttChannel
var httpContext = connection.GetHttpContext();
if (httpContext != null && httpContext.WebSockets.IsWebSocketRequest)
{
var path = httpContext.Request.Path;
connection.Features.Set(new WebSocketConnectionFeature(path));
}

var bufferWriter = _bufferWriterPool.Rent();
try
{
var formatter = new MqttPacketFormatterAdapter(bufferWriter);
using var adapter = new MqttServerChannelAdapter(formatter, connection);
using var adapter = new MqttServerChannelAdapter(formatter, connection, httpContext);
await clientHandler(adapter).ConfigureAwait(false);
}
finally
Expand Down
12 changes: 11 additions & 1 deletion Source/MQTTnet.AspnetCore/Internal/MqttConnectionMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using Microsoft.AspNetCore.Connections;
using MQTTnet.Adapter;
using System;
using System.Buffers;
using System.Threading.Tasks;
Expand All @@ -23,8 +24,17 @@ public MqttConnectionMiddleware(MqttConnectionHandler connectionHandler)
_connectionHandler = connectionHandler;
}

public async Task InvokeAsync(ConnectionDelegate next, ConnectionContext connection, MqttProtocols protocols)
public async Task InvokeAsync(
ConnectionDelegate next,
ConnectionContext connection,
MqttProtocols protocols,
Func<IMqttChannelAdapter, bool>? allowPacketFragmentationSelector)
{
if (allowPacketFragmentationSelector != null)
{
connection.Features.Set(new PacketFragmentationFeature(allowPacketFragmentationSelector));
}

if (protocols == MqttProtocols.MqttAndWebSocket)
{
var input = connection.Transport.Input;
Expand Down
Loading

0 comments on commit 9a7a8bd

Please sign in to comment.