Skip to content

Commit

Permalink
Recycleable memory manager
Browse files Browse the repository at this point in the history
  • Loading branch information
metalgearsloth committed Nov 19, 2023
1 parent 96217c6 commit 5373d8d
Show file tree
Hide file tree
Showing 14 changed files with 83 additions and 45 deletions.
41 changes: 41 additions & 0 deletions Robust.Shared/GameObjects/RobustMemoryManager.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
using System.IO;
using Microsoft.IO;
using Robust.Shared.IoC;
using Robust.Shared.Log;
using Robust.Shared.Utility;
using SixLabors.ImageSharp.Memory;

namespace Robust.Shared.GameObjects;

/// <summary>
/// Generic memory manager for engine use.
/// </summary>
internal sealed class RobustMemoryManager
{
// Let's be real this is a bandaid for pooling bullshit at an engine level and I don't know what
// good memory management looks like for PVS or the RobustSerializer.

private static readonly RecyclableMemoryStreamManager MemStreamManager = new()
{
ThrowExceptionOnToArray = true,
};

public RobustMemoryManager()
{
MemStreamManager.StreamDoubleDisposed += (sender, args) =>
throw new InvalidMemoryOperationException("Found double disposed stream.");

MemStreamManager.StreamFinalized += (sender, args) =>
throw new InvalidMemoryOperationException("Stream finalized but not disposed indicating a leak");

MemStreamManager.StreamOverCapacity += (sender, args) =>
throw new InvalidMemoryOperationException("Stream over memory capacity");
}

public static MemoryStream GetMemoryStream()
{
var stream = MemStreamManager.GetStream();
DebugTools.Assert(stream.Position == 0);
return stream;
}
}
4 changes: 3 additions & 1 deletion Robust.Shared/Network/Messages/MsgConCmdAck.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System.IO;
using Lidgren.Network;
using Robust.Shared.GameObjects;
using Robust.Shared.Serialization;
using Robust.Shared.Utility;

Expand All @@ -17,7 +18,8 @@ public sealed class MsgConCmdAck : NetMessage
public override void ReadFromBuffer(NetIncomingMessage buffer, IRobustSerializer serializer)
{
int length = buffer.ReadVariableInt32();
using var stream = buffer.ReadAlignedMemory(length);
using var stream = RobustMemoryManager.GetMemoryStream();
buffer.ReadAlignedMemory(stream, length);
Text = serializer.Deserialize<FormattedMessage>(stream);
}

Expand Down
5 changes: 3 additions & 2 deletions Robust.Shared/Network/Messages/MsgEntity.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ public override void ReadFromBuffer(NetIncomingMessage buffer, IRobustSerializer
{
case EntityMessageType.SystemMessage:
{
int length = buffer.ReadVariableInt32();
using var stream = buffer.ReadAlignedMemory(length);
var length = buffer.ReadVariableInt32();
using var stream = RobustMemoryManager.GetMemoryStream();
buffer.ReadAlignedMemory(stream, length);
SystemMessage = serializer.Deserialize<EntityEventArgs>(stream);
}
break;
Expand Down
4 changes: 3 additions & 1 deletion Robust.Shared/Network/Messages/MsgScriptResponse.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System.IO;
using Lidgren.Network;
using Robust.Shared.GameObjects;
using Robust.Shared.IoC;
using Robust.Shared.Serialization;
using Robust.Shared.Utility;
Expand Down Expand Up @@ -28,7 +29,8 @@ public override void ReadFromBuffer(NetIncomingMessage buffer, IRobustSerializer
{
buffer.ReadPadBits();
var length = buffer.ReadVariableInt32();
using var stream = buffer.ReadAlignedMemory(length);
using var stream = RobustMemoryManager.GetMemoryStream();
buffer.ReadAlignedMemory(stream, length);
serializer.DeserializeDirect(stream, out Echo);
serializer.DeserializeDirect(stream, out Response);
}
Expand Down
40 changes: 9 additions & 31 deletions Robust.Shared/Network/Messages/MsgState.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
using System.Buffers;
using System.IO;
using Lidgren.Network;
using Microsoft.Extensions.ObjectPool;
using Robust.Shared.GameObjects;
using Robust.Shared.GameStates;
using Robust.Shared.Serialization;
using Robust.Shared.Utility;
Expand All @@ -27,57 +27,36 @@ public sealed class MsgState : NetMessage

internal bool _hasWritten;

private static readonly ObjectPool<MemoryStream> StreamPool =
new DefaultObjectPool<MemoryStream>(new MemoryStreamPolicy());

private sealed class MemoryStreamPolicy : IPooledObjectPolicy<MemoryStream>
{
public MemoryStream Create()
{
return new MemoryStream();
}

public bool Return(MemoryStream obj)
{
obj.Position = 0;
return true;
}
}

public override void ReadFromBuffer(NetIncomingMessage buffer, IRobustSerializer serializer)
{
MsgSize = buffer.LengthBytes;
var uncompressedLength = buffer.ReadVariableInt32();
var compressedLength = buffer.ReadVariableInt32();
MemoryStream finalStream;
using var finalStream = RobustMemoryManager.GetMemoryStream();

// State is compressed.
if (compressedLength > 0)
{
var stream = buffer.ReadAlignedMemory(compressedLength);
var stream = RobustMemoryManager.GetMemoryStream();
buffer.ReadAlignedMemory(stream, compressedLength);
using var decompressStream = new ZStdDecompressStream(stream);
var decompressed = StreamPool.Get();
decompressed.SetLength(uncompressedLength);
decompressStream.CopyTo(decompressed, uncompressedLength);
decompressed.Position = 0;
finalStream = decompressed;
finalStream.SetLength(uncompressedLength);
decompressStream.CopyTo(finalStream, uncompressedLength);
finalStream.Position = 0;
}
// State is uncompressed.
else
{
var stream = buffer.ReadAlignedMemory(uncompressedLength);
finalStream = stream;
buffer.ReadAlignedMemory(finalStream, uncompressedLength);
}

serializer.DeserializeDirect(finalStream, out State);
StreamPool.Return(finalStream);

State.PayloadSize = uncompressedLength;
}

public override void WriteToBuffer(NetOutgoingMessage buffer, IRobustSerializer serializer)
{
var stateStream = StreamPool.Get();
using var stateStream = RobustMemoryManager.GetMemoryStream();
serializer.SerializeDirect(stateStream, State);
buffer.WriteVariableInt32((int)stateStream.Length);

Expand Down Expand Up @@ -108,7 +87,6 @@ public override void WriteToBuffer(NetOutgoingMessage buffer, IRobustSerializer
buffer.Write(stateStream.AsSpan());
}

StreamPool.Return(stateStream);
_hasWritten = true;
MsgSize = buffer.LengthBytes;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System.IO;
using Lidgren.Network;
using Robust.Shared.GameObjects;
using Robust.Shared.IoC;
using Robust.Shared.Serialization;
using Robust.Shared.Utility;
Expand Down Expand Up @@ -42,12 +43,14 @@ public override void ReadFromBuffer(NetIncomingMessage buffer, IRobustSerializer
SessionId = buffer.ReadUInt32();
{
var length = buffer.ReadInt32();
using var stream = buffer.ReadAlignedMemory(length);
using var stream = RobustMemoryManager.GetMemoryStream();
buffer.ReadAlignedMemory(stream, length);
PropertyIndex = serializer.Deserialize<object[]>(stream);
}
{
var length = buffer.ReadInt32();
using var stream = buffer.ReadAlignedMemory(length);
using var stream = RobustMemoryManager.GetMemoryStream();
buffer.ReadAlignedMemory(stream, length);
Value = serializer.Deserialize(stream);
}
ReinterpretValue = buffer.ReadBoolean();
Expand Down
4 changes: 3 additions & 1 deletion Robust.Shared/Network/Messages/MsgViewVariablesRemoteData.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System.IO;
using Lidgren.Network;
using Robust.Shared.GameObjects;
using Robust.Shared.IoC;
using Robust.Shared.Serialization;
using Robust.Shared.Utility;
Expand Down Expand Up @@ -31,7 +32,8 @@ public override void ReadFromBuffer(NetIncomingMessage buffer, IRobustSerializer
{
RequestId = buffer.ReadUInt32();
var length = buffer.ReadInt32();
using var stream = buffer.ReadAlignedMemory(length);
using var stream = RobustMemoryManager.GetMemoryStream();
buffer.ReadAlignedMemory(stream, length);
Blob = serializer.Deserialize<ViewVariablesBlob>(stream);
}

Expand Down
4 changes: 3 additions & 1 deletion Robust.Shared/Network/Messages/MsgViewVariablesReqData.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System.IO;
using Lidgren.Network;
using Robust.Shared.GameObjects;
using Robust.Shared.IoC;
using Robust.Shared.Serialization;
using Robust.Shared.Utility;
Expand Down Expand Up @@ -37,7 +38,8 @@ public override void ReadFromBuffer(NetIncomingMessage buffer, IRobustSerializer
RequestId = buffer.ReadUInt32();
SessionId = buffer.ReadUInt32();
var length = buffer.ReadInt32();
using var stream = buffer.ReadAlignedMemory(length);
using var stream = RobustMemoryManager.GetMemoryStream();
buffer.ReadAlignedMemory(stream, length);
RequestMeta = serializer.Deserialize<ViewVariablesRequest>(stream);
}

Expand Down
4 changes: 3 additions & 1 deletion Robust.Shared/Network/Messages/MsgViewVariablesReqSession.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System.IO;
using Lidgren.Network;
using Robust.Shared.GameObjects;
using Robust.Shared.IoC;
using Robust.Shared.Serialization;
using Robust.Shared.Utility;
Expand Down Expand Up @@ -32,7 +33,8 @@ public override void ReadFromBuffer(NetIncomingMessage buffer, IRobustSerializer
{
RequestId = buffer.ReadUInt32();
var length = buffer.ReadInt32();
using var stream = buffer.ReadAlignedMemory(length);
using var stream = RobustMemoryManager.GetMemoryStream();
buffer.ReadAlignedMemory(stream, length);
Selector = serializer.Deserialize<ViewVariablesObjectSelector>(stream);
}

Expand Down
8 changes: 5 additions & 3 deletions Robust.Shared/Network/NetMessageExt.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using Robust.Shared.Map;
using Robust.Shared.Maths;
using Robust.Shared.Timing;
using Robust.Shared.Utility;

namespace Robust.Shared.Network
{
Expand Down Expand Up @@ -96,16 +97,17 @@ public static void Write(this NetOutgoingMessage message, Color color)
/// <exception cref="ArgumentException">
/// Thrown if the current read position of the message is not byte-aligned.
/// </exception>
public static MemoryStream ReadAlignedMemory(this NetIncomingMessage message, int length)
public static void ReadAlignedMemory(this NetIncomingMessage message, MemoryStream memStream, int length)
{
if ((message.Position & 7) != 0)
{
throw new ArgumentException("Read position in message must be byte-aligned", nameof(message));
}

var stream = new MemoryStream(message.Data, message.PositionInBytes, length, false);
DebugTools.Assert(memStream.Position == 0);
memStream.Write(message.Data, message.PositionInBytes, length);
memStream.Position = 0;
message.Position += length * 8;
return stream;
}

public static TimeSpan ReadTimeSpan(this NetIncomingMessage message)
Expand Down
1 change: 1 addition & 0 deletions Robust.Shared/Robust.Shared.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
<PackageReference Include="JetBrains.Annotations" Version="2021.3.0" PrivateAssets="All" />
<PackageReference Include="Microsoft.Extensions.ObjectPool" Version="6.0.2" />
<PackageReference Include="Microsoft.ILVerification" Version="6.0.0" PrivateAssets="compile" />
<PackageReference Include="Microsoft.IO.RecyclableMemoryStream" Version="2.3.2" />
<PackageReference Include="Nett" Version="0.15.0" PrivateAssets="compile" />
<PackageReference Include="Pidgin" Version="2.5.0" />
<PackageReference Include="prometheus-net" Version="4.1.1" />
Expand Down
1 change: 1 addition & 0 deletions Robust.Shared/SharedIoC.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ public static void RegisterIoC(IDependencyCollection deps)
deps.Register<IParallelManagerInternal, ParallelManager>();
deps.Register<ToolshedManager>();
deps.Register<HttpClientHolder>();
deps.Register<RobustMemoryManager>();
}
}
}
4 changes: 3 additions & 1 deletion Robust.Shared/ViewVariables/MsgViewVariablesPath.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.IO;
using Lidgren.Network;
using Robust.Shared.GameObjects;
using Robust.Shared.IoC;
using Robust.Shared.Network;
using Robust.Shared.Serialization;
Expand Down Expand Up @@ -156,7 +157,8 @@ public override void ReadFromBuffer(NetIncomingMessage buffer, IRobustSerializer
{
base.ReadFromBuffer(buffer, serializer);
var length = buffer.ReadInt32();
using var stream = buffer.ReadAlignedMemory(length);
using var stream = RobustMemoryManager.GetMemoryStream();
buffer.ReadAlignedMemory(stream, length);
Options = serializer.Deserialize<VVListPathOptions>(stream);
}

Expand Down
1 change: 0 additions & 1 deletion Robust.UnitTesting/Shared/Timing/TimerTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ async void Run()
public void TestCancellation()
{
var timerManager = IoCManager.Resolve<ITimerManager>();
var taskManager = IoCManager.Resolve<ITaskManager>();

var cts = new CancellationTokenSource();
var ran = false;
Expand Down

0 comments on commit 5373d8d

Please sign in to comment.