Skip to content

Commit

Permalink
FIX: DefaultMachineId doesn't throw; MQTT: buffer release, better str…
Browse files Browse the repository at this point in the history
…ing decoding (#175)

Modifications
- algorithm of determining best machine's MAC modified to not throw
- added convenient IReferenceCounted.SafeRelease extension method
- added release for buffers in MQTT codec in case processing fails to ensure release (or proper data flow) of the buffer in all cases
- fixed the way string is decoded in MQTT codec
  • Loading branch information
nayato authored Nov 25, 2016
1 parent 4488596 commit 17a69a0
Show file tree
Hide file tree
Showing 11 changed files with 415 additions and 263 deletions.
2 changes: 1 addition & 1 deletion src/DotNetty.Codecs.Mqtt/MqttDecoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ static string DecodeString(IByteBuffer buffer, ref int remainingLength, int minB

DecreaseRemainingLength(ref remainingLength, size);

string value = Encoding.UTF8.GetString(buffer.Array, buffer.ArrayOffset + buffer.ReaderIndex, size);
string value = buffer.ToString(buffer.ReaderIndex, size, Encoding.UTF8);
// todo: enforce string definition by MQTT spec
buffer.SetReaderIndex(buffer.ReaderIndex + size);
return value;
Expand Down
201 changes: 133 additions & 68 deletions src/DotNetty.Codecs.Mqtt/MqttEncoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ namespace DotNetty.Codecs.Mqtt
using DotNetty.Buffers;
using DotNetty.Codecs.Mqtt.Packets;
using DotNetty.Common;
using DotNetty.Common.Utilities;
using DotNetty.Transport.Channels;

public sealed class MqttEncoder : MessageToMessageEncoder<Packet>
Expand Down Expand Up @@ -124,44 +125,55 @@ static void EncodeConnectMessage(IByteBufferAllocator bufferAllocator, ConnectPa
int variableHeaderBufferSize = StringSizeLength + protocolNameBytes.Length + 4;
int variablePartSize = variableHeaderBufferSize + payloadBufferSize;
int fixedHeaderBufferSize = 1 + MaxVariableLength;
IByteBuffer buf = bufferAllocator.Buffer(fixedHeaderBufferSize + variablePartSize);
buf.WriteByte(CalculateFirstByteOfFixedHeader(packet));
WriteVariableLengthInt(buf, variablePartSize);
IByteBuffer buf = null;
try
{
buf = bufferAllocator.Buffer(fixedHeaderBufferSize + variablePartSize);
buf.WriteByte(CalculateFirstByteOfFixedHeader(packet));
WriteVariableLengthInt(buf, variablePartSize);

buf.WriteShort(protocolNameBytes.Length);
buf.WriteBytes(protocolNameBytes);
buf.WriteShort(protocolNameBytes.Length);
buf.WriteBytes(protocolNameBytes);

buf.WriteByte(Util.ProtocolLevel);
buf.WriteByte(CalculateConnectFlagsByte(packet));
buf.WriteShort(packet.KeepAliveInSeconds);
buf.WriteByte(Util.ProtocolLevel);
buf.WriteByte(CalculateConnectFlagsByte(packet));
buf.WriteShort(packet.KeepAliveInSeconds);

// Payload
buf.WriteShort(clientIdBytes.Length);
buf.WriteBytes(clientIdBytes, 0, clientIdBytes.Length);
if (packet.HasWill)
{
buf.WriteShort(willTopicBytes.Length);
buf.WriteBytes(willTopicBytes, 0, willTopicBytes.Length);
buf.WriteShort(willMessage.ReadableBytes);
if (willMessage.IsReadable())
// Payload
buf.WriteShort(clientIdBytes.Length);
buf.WriteBytes(clientIdBytes, 0, clientIdBytes.Length);
if (packet.HasWill)
{
buf.WriteBytes(willMessage);
buf.WriteShort(willTopicBytes.Length);
buf.WriteBytes(willTopicBytes, 0, willTopicBytes.Length);
buf.WriteShort(willMessage.ReadableBytes);
if (willMessage.IsReadable())
{
buf.WriteBytes(willMessage);
}
willMessage.Release();
willMessage = null;
}
willMessage.Release();
}
if (packet.HasUsername)
{
buf.WriteShort(userNameBytes.Length);
buf.WriteBytes(userNameBytes, 0, userNameBytes.Length);

if (packet.HasPassword)
if (packet.HasUsername)
{
buf.WriteShort(passwordBytes.Length);
buf.WriteBytes(passwordBytes, 0, passwordBytes.Length);
buf.WriteShort(userNameBytes.Length);
buf.WriteBytes(userNameBytes, 0, userNameBytes.Length);

if (packet.HasPassword)
{
buf.WriteShort(passwordBytes.Length);
buf.WriteBytes(passwordBytes, 0, passwordBytes.Length);
}
}
}

output.Add(buf);
output.Add(buf);
buf = null;
}
finally
{
buf?.SafeRelease();
willMessage?.SafeRelease();
}
}

static int CalculateConnectFlagsByte(ConnectPacket packet)
Expand Down Expand Up @@ -193,20 +205,30 @@ static int CalculateConnectFlagsByte(ConnectPacket packet)

static void EncodeConnAckMessage(IByteBufferAllocator bufferAllocator, ConnAckPacket message, List<object> output)
{
IByteBuffer buffer = bufferAllocator.Buffer(4);
buffer.WriteByte(CalculateFirstByteOfFixedHeader(message));
buffer.WriteByte(2); // remaining length
if (message.SessionPresent)
IByteBuffer buffer = null;
try
{
buffer.WriteByte(1); // 7 reserved 0-bits and SP = 1
buffer = bufferAllocator.Buffer(4);
buffer.WriteByte(CalculateFirstByteOfFixedHeader(message));
buffer.WriteByte(2); // remaining length
if (message.SessionPresent)
{
buffer.WriteByte(1); // 7 reserved 0-bits and SP = 1
}
else
{
buffer.WriteByte(0); // 7 reserved 0-bits and SP = 0
}
buffer.WriteByte((byte)message.ReturnCode);


output.Add(buffer);
buffer = null;
}
else
finally
{
buffer.WriteByte(0); // 7 reserved 0-bits and SP = 0
buffer?.SafeRelease();
}
buffer.WriteByte((byte)message.ReturnCode);

output.Add(buffer);
}

static void EncodePublishMessage(IByteBufferAllocator bufferAllocator, PublishPacket packet, List<object> output)
Expand All @@ -223,17 +245,26 @@ static void EncodePublishMessage(IByteBufferAllocator bufferAllocator, PublishPa
int variablePartSize = variableHeaderBufferSize + payloadBufferSize;
int fixedHeaderBufferSize = 1 + MaxVariableLength;

IByteBuffer buf = bufferAllocator.Buffer(fixedHeaderBufferSize + variablePartSize);
buf.WriteByte(CalculateFirstByteOfFixedHeader(packet));
WriteVariableLengthInt(buf, variablePartSize);
buf.WriteShort(topicNameBytes.Length);
buf.WriteBytes(topicNameBytes);
if (packet.QualityOfService > QualityOfService.AtMostOnce)
IByteBuffer buf = null;
try
{
buf.WriteShort(packet.PacketId);
}
buf = bufferAllocator.Buffer(fixedHeaderBufferSize + variablePartSize);
buf.WriteByte(CalculateFirstByteOfFixedHeader(packet));
WriteVariableLengthInt(buf, variablePartSize);
buf.WriteShort(topicNameBytes.Length);
buf.WriteBytes(topicNameBytes);
if (packet.QualityOfService > QualityOfService.AtMostOnce)
{
buf.WriteShort(packet.PacketId);
}

output.Add(buf);
output.Add(buf);
buf = null;
}
finally
{
buf?.SafeRelease();
}

if (payload.IsReadable())
{
Expand All @@ -247,12 +278,21 @@ static void EncodePacketWithIdOnly(IByteBufferAllocator bufferAllocator, PacketW

const int VariableHeaderBufferSize = PacketIdLength; // variable part only has a packet id
int fixedHeaderBufferSize = 1 + MaxVariableLength;
IByteBuffer buffer = bufferAllocator.Buffer(fixedHeaderBufferSize + VariableHeaderBufferSize);
buffer.WriteByte(CalculateFirstByteOfFixedHeader(packet));
WriteVariableLengthInt(buffer, VariableHeaderBufferSize);
buffer.WriteShort(msgId);
IByteBuffer buffer = null;
try
{
buffer = bufferAllocator.Buffer(fixedHeaderBufferSize + VariableHeaderBufferSize);
buffer.WriteByte(CalculateFirstByteOfFixedHeader(packet));
WriteVariableLengthInt(buffer, VariableHeaderBufferSize);
buffer.WriteShort(msgId);

output.Add(buffer);
output.Add(buffer);
buffer = null;
}
finally
{
buffer?.SafeRelease();
}
}

static void EncodeSubscribeMessage(IByteBufferAllocator bufferAllocator, SubscribePacket packet, List<object> output)
Expand All @@ -262,6 +302,7 @@ static void EncodeSubscribeMessage(IByteBufferAllocator bufferAllocator, Subscri

ThreadLocalObjectList encodedTopicFilters = ThreadLocalObjectList.NewInstance();

IByteBuffer buf = null;
try
{
foreach (SubscriptionRequest topic in packet.Requests)
Expand All @@ -274,7 +315,7 @@ static void EncodeSubscribeMessage(IByteBufferAllocator bufferAllocator, Subscri
int variablePartSize = VariableHeaderSize + payloadBufferSize;
int fixedHeaderBufferSize = 1 + MaxVariableLength;

IByteBuffer buf = bufferAllocator.Buffer(fixedHeaderBufferSize + variablePartSize);
buf = bufferAllocator.Buffer(fixedHeaderBufferSize + variablePartSize);
buf.WriteByte(CalculateFirstByteOfFixedHeader(packet));
WriteVariableLengthInt(buf, variablePartSize);

Expand All @@ -291,9 +332,11 @@ static void EncodeSubscribeMessage(IByteBufferAllocator bufferAllocator, Subscri
}

output.Add(buf);
buf = null;
}
finally
{
buf?.SafeRelease();
encodedTopicFilters.Return();
}
}
Expand All @@ -303,16 +346,26 @@ static void EncodeSubAckMessage(IByteBufferAllocator bufferAllocator, SubAckPack
int payloadBufferSize = message.ReturnCodes.Count;
int variablePartSize = PacketIdLength + payloadBufferSize;
int fixedHeaderBufferSize = 1 + MaxVariableLength;
IByteBuffer buf = bufferAllocator.Buffer(fixedHeaderBufferSize + variablePartSize);
buf.WriteByte(CalculateFirstByteOfFixedHeader(message));
WriteVariableLengthInt(buf, variablePartSize);
buf.WriteShort(message.PacketId);
foreach (QualityOfService qos in message.ReturnCodes)
IByteBuffer buf = null;
try
{
buf.WriteByte((byte)qos);
}
buf = bufferAllocator.Buffer(fixedHeaderBufferSize + variablePartSize);
buf.WriteByte(CalculateFirstByteOfFixedHeader(message));
WriteVariableLengthInt(buf, variablePartSize);
buf.WriteShort(message.PacketId);
foreach (QualityOfService qos in message.ReturnCodes)
{
buf.WriteByte((byte)qos);
}

output.Add(buf);
buf = null;

output.Add(buf);
}
finally
{
buf?.SafeRelease();
}
}

static void EncodeUnsubscribeMessage(IByteBufferAllocator bufferAllocator, UnsubscribePacket packet, List<object> output)
Expand All @@ -322,6 +375,7 @@ static void EncodeUnsubscribeMessage(IByteBufferAllocator bufferAllocator, Unsub

ThreadLocalObjectList encodedTopicFilters = ThreadLocalObjectList.NewInstance();

IByteBuffer buf = null;
try
{
foreach (string topic in packet.TopicFilters)
Expand All @@ -334,7 +388,7 @@ static void EncodeUnsubscribeMessage(IByteBufferAllocator bufferAllocator, Unsub
int variablePartSize = VariableHeaderSize + payloadBufferSize;
int fixedHeaderBufferSize = 1 + MaxVariableLength;

IByteBuffer buf = bufferAllocator.Buffer(fixedHeaderBufferSize + variablePartSize);
buf = bufferAllocator.Buffer(fixedHeaderBufferSize + variablePartSize);
buf.WriteByte(CalculateFirstByteOfFixedHeader(packet));
WriteVariableLengthInt(buf, variablePartSize);

Expand All @@ -350,20 +404,31 @@ static void EncodeUnsubscribeMessage(IByteBufferAllocator bufferAllocator, Unsub
}

output.Add(buf);
buf = null;
}
finally
{
buf?.SafeRelease();
encodedTopicFilters.Return();
}
}

static void EncodePacketWithFixedHeaderOnly(IByteBufferAllocator bufferAllocator, Packet packet, List<object> output)
{
IByteBuffer buffer = bufferAllocator.Buffer(2);
buffer.WriteByte(CalculateFirstByteOfFixedHeader(packet));
buffer.WriteByte(0);
IByteBuffer buffer = null;
try
{
buffer = bufferAllocator.Buffer(2);
buffer.WriteByte(CalculateFirstByteOfFixedHeader(packet));
buffer.WriteByte(0);

output.Add(buffer);
output.Add(buffer);
buffer = null;
}
finally
{
buffer?.SafeRelease();
}
}

static int CalculateFirstByteOfFixedHeader(Packet packet)
Expand Down
18 changes: 9 additions & 9 deletions src/DotNetty.Codecs.Redis/RedisBulkStringAggregator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ public sealed class RedisBulkStringAggregator : MessageToMessageDecoder<IRedisMe

public RedisBulkStringAggregator()
: this(RedisConstants.MaximumMessageLength)
{ }
{
}

RedisBulkStringAggregator(int maximumContentLength)
{
Expand All @@ -36,10 +37,7 @@ public RedisBulkStringAggregator()

public int MaximumCumulationBufferComponents
{
get
{
return this.maximumCumulationBufferComponents;
}
get { return this.maximumCumulationBufferComponents; }
set
{
Contract.Requires(value >= 2);
Expand All @@ -66,8 +64,8 @@ public override bool AcceptInboundMessage(object message)
var redisMessage = (IRedisMessage)message;

return (IsContentMessage(redisMessage)
|| IsStartMessage(redisMessage))
&& !IsAggregated(redisMessage);
|| IsStartMessage(redisMessage))
&& !IsAggregated(redisMessage);
}

protected override void Decode(IChannelHandlerContext context, IRedisMessage message, List<object> output)
Expand Down Expand Up @@ -138,6 +136,7 @@ protected override void Decode(IChannelHandlerContext context, IRedisMessage mes
throw new MessageAggregationException($"Unexpected message {message}");
}
}

static void AppendPartialContent(CompositeByteBuffer content, IByteBuffer partialContent)
{
Contract.Requires(content != null);
Expand All @@ -154,6 +153,7 @@ static void AppendPartialContent(CompositeByteBuffer content, IByteBuffer partia
// Note that WriterIndex must be manually increased
content.SetWriterIndex(content.WriterIndex + buffer.ReadableBytes);
}

void InvokeHandleOversizedMessage(IChannelHandlerContext context, BulkStringHeaderRedisMessage startMessage)
{
Contract.Requires(context != null);
Expand All @@ -177,7 +177,7 @@ static bool IsStartMessage(IRedisMessage message)
{
Contract.Requires(message != null);

return message is BulkStringHeaderRedisMessage
return message is BulkStringHeaderRedisMessage
&& !IsAggregated(message);
}

Expand Down Expand Up @@ -216,4 +216,4 @@ static bool IsAggregated(IRedisMessage message)
return message is IFullBulkStringRedisMessage;
}
}
}
}
Loading

0 comments on commit 17a69a0

Please sign in to comment.