Skip to content

Commit

Permalink
Use Pipe to reduce the memory allocation of MqttPacketInspector.
Browse files Browse the repository at this point in the history
  • Loading branch information
xljiulang committed Dec 7, 2024
1 parent f2a81ea commit 759a816
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 44 deletions.
5 changes: 3 additions & 2 deletions Samples/Diagnostics/PackageInspection_Samples.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
// ReSharper disable InconsistentNaming

using MQTTnet.Diagnostics.PacketInspection;
using System.Buffers;

namespace MQTTnet.Samples.Diagnostics;

Expand Down Expand Up @@ -43,11 +44,11 @@ static Task OnInspectPacket(InspectMqttPacketEventArgs eventArgs)
{
if (eventArgs.Direction == MqttPacketFlowDirection.Inbound)
{
Console.WriteLine($"IN: {Convert.ToBase64String(eventArgs.Buffer)}");
Console.WriteLine($"IN: {Convert.ToBase64String(eventArgs.Buffer.ToArray())}");
}
else
{
Console.WriteLine($"OUT: {Convert.ToBase64String(eventArgs.Buffer)}");
Console.WriteLine($"OUT: {Convert.ToBase64String(eventArgs.Buffer.ToArray())}");
}

return Task.CompletedTask;
Expand Down
3 changes: 2 additions & 1 deletion Source/MQTTnet.Tests/Diagnostics/PacketInspection_Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Buffers;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -34,7 +35,7 @@ public async Task Inspect_Client_Packets()

mqttClient.InspectPacketAsync += eventArgs =>
{
packets.Add(eventArgs.Direction + ":" + Convert.ToBase64String(eventArgs.Buffer));
packets.Add(eventArgs.Direction + ":" + Convert.ToBase64String(eventArgs.Buffer.ToArray()));
return CompletedTask.Instance;
};

Expand Down
2 changes: 1 addition & 1 deletion Source/MQTTnet.Tests/Server/General.cs
Original file line number Diff line number Diff line change
Expand Up @@ -981,7 +981,7 @@ public async Task Disconnect_Client_with_Reason()
{
if (e.Buffer.Length > 0)
{
if (e.Buffer[0] == (byte)MqttControlPacketType.Disconnect << 4)
if (e.Buffer.FirstSpan[0] == (byte)MqttControlPacketType.Disconnect << 4)
{
disconnectPacketReceived = true;
}
Expand Down
133 changes: 96 additions & 37 deletions Source/MQTTnet/Adapter/MqttPacketInspector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Buffers;
using System.IO;
using System.Threading.Tasks;
using MQTTnet.Diagnostics.Logger;
using MQTTnet.Diagnostics.PacketInspection;
using MQTTnet.Formatter;
using MQTTnet.Internal;
using System;
using System.Buffers;
using System.IO.Pipelines;
using System.Threading;
using System.Threading.Tasks;

namespace MQTTnet.Adapter;

Expand All @@ -18,7 +19,9 @@ public sealed class MqttPacketInspector
readonly AsyncEvent<InspectMqttPacketEventArgs> _asyncEvent;
readonly MqttNetSourceLogger _logger;

MemoryStream _receivedPacketBuffer;
readonly Pipe _pipeIn = new();
readonly Pipe _pipeOut = new();
ReceiveState _receiveState = ReceiveState.Disable;

public MqttPacketInspector(AsyncEvent<InspectMqttPacketEventArgs> asyncEvent, IMqttNetLogger logger)
{
Expand All @@ -29,79 +32,115 @@ public MqttPacketInspector(AsyncEvent<InspectMqttPacketEventArgs> asyncEvent, IM
_logger = logger.WithSource(nameof(MqttPacketInspector));
}

public void BeginReceivePacket()
public async Task BeginSendPacket(MqttPacketBuffer buffer)
{
if (!_asyncEvent.HasHandlers)
{
return;
}

if (_receivedPacketBuffer == null)
// Create a copy of the actual packet so that the inspector gets no access
// to the internal buffers. This is waste of memory but this feature is only
// intended for debugging etc. so that this is OK.
var writer = _pipeOut.Writer;
await writer.WriteAsync(buffer.Packet).ConfigureAwait(false);
foreach (var memory in buffer.Payload)
{
_receivedPacketBuffer = new MemoryStream();
await writer.WriteAsync(memory).ConfigureAwait(false);
}

_receivedPacketBuffer?.SetLength(0);
await writer.CompleteAsync().ConfigureAwait(false);
await InspectPacketAsync(_pipeOut.Reader, MqttPacketFlowDirection.Outbound).ConfigureAwait(false);

// reset pipe
await _pipeOut.Reader.CompleteAsync().ConfigureAwait(false);
_pipeOut.Reset();
}

public Task BeginSendPacket(MqttPacketBuffer buffer)
public void BeginReceivePacket()
{
if (!_asyncEvent.HasHandlers)
if (_asyncEvent.HasHandlers)
{
return CompletedTask.Instance;
}
// This shouldn't happen, but we need to be able to accommodate the unexpected.
if (_receiveState == ReceiveState.Fill)
{
_pipeIn.Writer.Complete();
_pipeIn.Reader.Complete();
_pipeIn.Reset();

// Create a copy of the actual packet so that the inspector gets no access
// to the internal buffers. This is waste of memory but this feature is only
// intended for debugging etc. so that this is OK.
var bufferCopy = buffer.ToArray();
_logger.Warning("An EndReceivePacket() operation was unexpectedly lost.");
}

return InspectPacket(bufferCopy, MqttPacketFlowDirection.Outbound);
_receiveState = ReceiveState.Begin;
}
else
{
_receiveState = ReceiveState.Disable;
}
}

public Task EndReceivePacket()
public void FillReceiveBuffer(ReadOnlySpan<byte> buffer)
{
if (!_asyncEvent.HasHandlers)
if (_receiveState == ReceiveState.Disable)
{
return CompletedTask.Instance;
return;
}

var buffer = _receivedPacketBuffer.ToArray();
_receivedPacketBuffer.SetLength(0);
if (_receiveState == ReceiveState.End)
{
throw new InvalidOperationException("FillReceiveBuffer is not allowed in End state.");
}

return InspectPacket(buffer, MqttPacketFlowDirection.Inbound);
_pipeIn.Writer.Write(buffer);
_receiveState = ReceiveState.Fill;
}

public void FillReceiveBuffer(ReadOnlySpan<byte> buffer)
public void FillReceiveBuffer(ReadOnlySequence<byte> buffer)
{
if (!_asyncEvent.HasHandlers)
if (_receiveState == ReceiveState.Disable)
{
return;
}

_receivedPacketBuffer?.Write(buffer);
if (_receiveState == ReceiveState.End)
{
throw new InvalidOperationException("FillReceiveBuffer is not allowed in End state.");
}

var writer = _pipeIn.Writer;
foreach (var memory in buffer)
{
writer.Write(memory.Span);
}

_receiveState = ReceiveState.Fill;
}

public void FillReceiveBuffer(ReadOnlySequence<byte> buffer)

public async Task EndReceivePacket()
{
if (!_asyncEvent.HasHandlers)
if (_receiveState == ReceiveState.Disable || _receiveState == ReceiveState.End)
{
return;
}

if (_receivedPacketBuffer != null)
{
foreach (var memory in buffer)
{
_receivedPacketBuffer.Write(memory.Span);
}
}
await _pipeIn.Writer.FlushAsync().ConfigureAwait(false);
await _pipeIn.Writer.CompleteAsync().ConfigureAwait(false);
await InspectPacketAsync(_pipeIn.Reader, MqttPacketFlowDirection.Inbound).ConfigureAwait(false);

// reset pipe
await _pipeIn.Reader.CompleteAsync().ConfigureAwait(false);
_pipeIn.Reset();

_receiveState = ReceiveState.End;
}

async Task InspectPacket(byte[] buffer, MqttPacketFlowDirection direction)

async Task InspectPacketAsync(PipeReader pipeReader, MqttPacketFlowDirection direction)
{
try
{
var buffer = await ReadBufferAsync(pipeReader, default).ConfigureAwait(false);
var eventArgs = new InspectMqttPacketEventArgs(direction, buffer);
await _asyncEvent.InvokeAsync(eventArgs).ConfigureAwait(false);
}
Expand All @@ -110,4 +149,24 @@ async Task InspectPacket(byte[] buffer, MqttPacketFlowDirection direction)
_logger.Error(exception, "Error while inspecting packet.");
}
}

static async ValueTask<ReadOnlySequence<byte>> ReadBufferAsync(PipeReader pipeReader, CancellationToken cancellationToken)
{
var readResult = await pipeReader.ReadAsync(cancellationToken).ConfigureAwait(false);
while (!readResult.IsCompleted)
{
pipeReader.AdvanceTo(readResult.Buffer.Start, readResult.Buffer.Start);
readResult = await pipeReader.ReadAsync(cancellationToken).ConfigureAwait(false);
}

return readResult.Buffer;
}

private enum ReceiveState
{
Disable,
Begin,
Fill,
End,
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Buffers;

namespace MQTTnet.Diagnostics.PacketInspection
{
public sealed class InspectMqttPacketEventArgs : EventArgs
{
public InspectMqttPacketEventArgs(MqttPacketFlowDirection direction, byte[] buffer)
public InspectMqttPacketEventArgs(MqttPacketFlowDirection direction, ReadOnlySequence<byte> buffer)
{
Direction = direction;
Buffer = buffer ?? throw new ArgumentNullException(nameof(buffer));
Buffer = buffer;
}

public byte[] Buffer { get; }
public ReadOnlySequence<byte> Buffer { get; }

public MqttPacketFlowDirection Direction { get; }
}
Expand Down

0 comments on commit 759a816

Please sign in to comment.