diff --git a/Source/MQTTnet.AspnetCore/Features/PacketFragmentationFeature.cs b/Source/MQTTnet.AspnetCore/Features/PacketFragmentationFeature.cs index d8e9f7458..eb05414a4 100644 --- a/Source/MQTTnet.AspnetCore/Features/PacketFragmentationFeature.cs +++ b/Source/MQTTnet.AspnetCore/Features/PacketFragmentationFeature.cs @@ -20,12 +20,9 @@ public static bool CanAllowPacketFragmentation(IMqttChannelAdapter channelAdapte //} // In the AspNetCore environment, we need to exclude WebSocket before AllowPacketFragmentation. - if (channelAdapter is MqttServerChannelAdapter serverChannelAdapter) + if (channelAdapter.IsWebSocketConnection() == true) { - if (serverChannelAdapter.IsWebSocketConnection) - { - return false; - } + return false; } return endpointOptions == null || endpointOptions.AllowPacketFragmentation; diff --git a/Source/MQTTnet.AspnetCore/Internal/IAspNetCoreMqttChannelAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/IAspNetCoreMqttChannelAdapter.cs new file mode 100644 index 000000000..323f29999 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/IAspNetCoreMqttChannelAdapter.cs @@ -0,0 +1,16 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using MQTTnet.Adapter; + +namespace MQTTnet.AspNetCore +{ + interface IAspNetCoreMqttChannelAdapter : IMqttChannelAdapter + { + HttpContext? HttpContext { get; } + IFeatureCollection? Features { get; } + } +} diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs index 832df3f0b..1fec39feb 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs @@ -3,6 +3,8 @@ // See the LICENSE file in the project root for more information. using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; using MQTTnet.Adapter; using MQTTnet.Formatter; using MQTTnet.Packets; @@ -15,7 +17,7 @@ namespace MQTTnet.AspNetCore; -sealed class MqttClientChannelAdapter : IMqttChannelAdapter, IAsyncDisposable +sealed class MqttClientChannelAdapter : IAspNetCoreMqttChannelAdapter, IAsyncDisposable { private bool _disposed = false; private ConnectionContext? _connection; @@ -25,6 +27,9 @@ sealed class MqttClientChannelAdapter : IMqttChannelAdapter, IAsyncDisposable private readonly bool _allowPacketFragmentation; private readonly MqttPacketInspector? _packetInspector; + public HttpContext? HttpContext => null; + public IFeatureCollection? Features => _connection?.Features; + public MqttClientChannelAdapter( MqttPacketFormatterAdapter packetFormatterAdapter, IMqttClientChannelOptions channelOptions, diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs index 657a8e237..3703931b1 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs @@ -4,18 +4,24 @@ using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Http; -using MQTTnet.Adapter; +using Microsoft.AspNetCore.Http.Features; using MQTTnet.Formatter; using System.Threading; using System.Threading.Tasks; namespace MQTTnet.AspNetCore; -sealed class MqttServerChannelAdapter : MqttChannel, IMqttChannelAdapter +sealed class MqttServerChannelAdapter : MqttChannel, IAspNetCoreMqttChannelAdapter { + public HttpContext? HttpContext { get; } + public IFeatureCollection? Features { get; } + public MqttServerChannelAdapter(MqttPacketFormatterAdapter packetFormatterAdapter, ConnectionContext connection, HttpContext? httpContext) : base(packetFormatterAdapter, connection, httpContext, packetInspector: null) { + HttpContext = httpContext; + Features = connection.Features; + SetAllowPacketFragmentation(connection, httpContext); } diff --git a/Source/MQTTnet.AspnetCore/MqttChannelAdapterExtensions.cs b/Source/MQTTnet.AspnetCore/MqttChannelAdapterExtensions.cs new file mode 100644 index 000000000..17e834c77 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/MqttChannelAdapterExtensions.cs @@ -0,0 +1,48 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.AspNetCore.Http; +using MQTTnet.Adapter; +using System; + +namespace MQTTnet.AspNetCore +{ + public static class MqttChannelAdapterExtensions + { + public static bool? IsWebSocketConnection(this IMqttChannelAdapter channelAdapter) + { + ArgumentNullException.ThrowIfNull(channelAdapter); + return channelAdapter is IAspNetCoreMqttChannelAdapter adapter + ? adapter.Features != null && adapter.Features.Get() != null + : null; + } + + /// + /// Retrieves the requested feature from the feature collection of channelAdapter. + /// + /// + /// + /// + public static TFeature? GetFeature(this IMqttChannelAdapter channelAdapter) + { + ArgumentNullException.ThrowIfNull(channelAdapter); + return channelAdapter is IAspNetCoreMqttChannelAdapter adapter && adapter.Features != null + ? adapter.Features.Get() + : default; + } + + /// + /// When the channelAdapter is a WebSocket connection, it can get an associated . + /// + /// + /// + public static HttpContext? GetHttpContext(this IMqttChannelAdapter channelAdapter) + { + ArgumentNullException.ThrowIfNull(channelAdapter); + return channelAdapter is IAspNetCoreMqttChannelAdapter adapter + ? adapter.HttpContext + : null; + } + } +}