From 5373d8ddc3e8a72346fd7ddc9ca24dc85b54723f Mon Sep 17 00:00:00 2001 From: metalgearsloth Date: Mon, 20 Nov 2023 02:46:34 +1100 Subject: [PATCH] Recycleable memory manager --- .../GameObjects/RobustMemoryManager.cs | 41 +++++++++++++++++++ .../Network/Messages/MsgConCmdAck.cs | 4 +- Robust.Shared/Network/Messages/MsgEntity.cs | 5 ++- .../Network/Messages/MsgScriptResponse.cs | 4 +- Robust.Shared/Network/Messages/MsgState.cs | 40 ++++-------------- .../Messages/MsgViewVariablesModifyRemote.cs | 7 +++- .../Messages/MsgViewVariablesRemoteData.cs | 4 +- .../Messages/MsgViewVariablesReqData.cs | 4 +- .../Messages/MsgViewVariablesReqSession.cs | 4 +- Robust.Shared/Network/NetMessageExt.cs | 8 ++-- Robust.Shared/Robust.Shared.csproj | 1 + Robust.Shared/SharedIoC.cs | 1 + .../ViewVariables/MsgViewVariablesPath.cs | 4 +- Robust.UnitTesting/Shared/Timing/TimerTest.cs | 1 - 14 files changed, 83 insertions(+), 45 deletions(-) create mode 100644 Robust.Shared/GameObjects/RobustMemoryManager.cs diff --git a/Robust.Shared/GameObjects/RobustMemoryManager.cs b/Robust.Shared/GameObjects/RobustMemoryManager.cs new file mode 100644 index 00000000000..4eb3106258e --- /dev/null +++ b/Robust.Shared/GameObjects/RobustMemoryManager.cs @@ -0,0 +1,41 @@ +using System.IO; +using Microsoft.IO; +using Robust.Shared.IoC; +using Robust.Shared.Log; +using Robust.Shared.Utility; +using SixLabors.ImageSharp.Memory; + +namespace Robust.Shared.GameObjects; + +/// +/// Generic memory manager for engine use. +/// +internal sealed class RobustMemoryManager +{ + // Let's be real this is a bandaid for pooling bullshit at an engine level and I don't know what + // good memory management looks like for PVS or the RobustSerializer. + + private static readonly RecyclableMemoryStreamManager MemStreamManager = new() + { + ThrowExceptionOnToArray = true, + }; + + public RobustMemoryManager() + { + MemStreamManager.StreamDoubleDisposed += (sender, args) => + throw new InvalidMemoryOperationException("Found double disposed stream."); + + MemStreamManager.StreamFinalized += (sender, args) => + throw new InvalidMemoryOperationException("Stream finalized but not disposed indicating a leak"); + + MemStreamManager.StreamOverCapacity += (sender, args) => + throw new InvalidMemoryOperationException("Stream over memory capacity"); + } + + public static MemoryStream GetMemoryStream() + { + var stream = MemStreamManager.GetStream(); + DebugTools.Assert(stream.Position == 0); + return stream; + } +} diff --git a/Robust.Shared/Network/Messages/MsgConCmdAck.cs b/Robust.Shared/Network/Messages/MsgConCmdAck.cs index 8933142a0db..73e6f7363f6 100644 --- a/Robust.Shared/Network/Messages/MsgConCmdAck.cs +++ b/Robust.Shared/Network/Messages/MsgConCmdAck.cs @@ -1,5 +1,6 @@ using System.IO; using Lidgren.Network; +using Robust.Shared.GameObjects; using Robust.Shared.Serialization; using Robust.Shared.Utility; @@ -17,7 +18,8 @@ public sealed class MsgConCmdAck : NetMessage public override void ReadFromBuffer(NetIncomingMessage buffer, IRobustSerializer serializer) { int length = buffer.ReadVariableInt32(); - using var stream = buffer.ReadAlignedMemory(length); + using var stream = RobustMemoryManager.GetMemoryStream(); + buffer.ReadAlignedMemory(stream, length); Text = serializer.Deserialize(stream); } diff --git a/Robust.Shared/Network/Messages/MsgEntity.cs b/Robust.Shared/Network/Messages/MsgEntity.cs index 3f59fe4f85d..62cb1cbc0fb 100644 --- a/Robust.Shared/Network/Messages/MsgEntity.cs +++ b/Robust.Shared/Network/Messages/MsgEntity.cs @@ -33,8 +33,9 @@ public override void ReadFromBuffer(NetIncomingMessage buffer, IRobustSerializer { case EntityMessageType.SystemMessage: { - int length = buffer.ReadVariableInt32(); - using var stream = buffer.ReadAlignedMemory(length); + var length = buffer.ReadVariableInt32(); + using var stream = RobustMemoryManager.GetMemoryStream(); + buffer.ReadAlignedMemory(stream, length); SystemMessage = serializer.Deserialize(stream); } break; diff --git a/Robust.Shared/Network/Messages/MsgScriptResponse.cs b/Robust.Shared/Network/Messages/MsgScriptResponse.cs index 745e52552fb..40f647544e9 100644 --- a/Robust.Shared/Network/Messages/MsgScriptResponse.cs +++ b/Robust.Shared/Network/Messages/MsgScriptResponse.cs @@ -1,5 +1,6 @@ using System.IO; using Lidgren.Network; +using Robust.Shared.GameObjects; using Robust.Shared.IoC; using Robust.Shared.Serialization; using Robust.Shared.Utility; @@ -28,7 +29,8 @@ public override void ReadFromBuffer(NetIncomingMessage buffer, IRobustSerializer { buffer.ReadPadBits(); var length = buffer.ReadVariableInt32(); - using var stream = buffer.ReadAlignedMemory(length); + using var stream = RobustMemoryManager.GetMemoryStream(); + buffer.ReadAlignedMemory(stream, length); serializer.DeserializeDirect(stream, out Echo); serializer.DeserializeDirect(stream, out Response); } diff --git a/Robust.Shared/Network/Messages/MsgState.cs b/Robust.Shared/Network/Messages/MsgState.cs index e1a0b394725..314e99cd0e2 100644 --- a/Robust.Shared/Network/Messages/MsgState.cs +++ b/Robust.Shared/Network/Messages/MsgState.cs @@ -2,7 +2,7 @@ using System.Buffers; using System.IO; using Lidgren.Network; -using Microsoft.Extensions.ObjectPool; +using Robust.Shared.GameObjects; using Robust.Shared.GameStates; using Robust.Shared.Serialization; using Robust.Shared.Utility; @@ -27,57 +27,36 @@ public sealed class MsgState : NetMessage internal bool _hasWritten; - private static readonly ObjectPool StreamPool = - new DefaultObjectPool(new MemoryStreamPolicy()); - - private sealed class MemoryStreamPolicy : IPooledObjectPolicy - { - public MemoryStream Create() - { - return new MemoryStream(); - } - - public bool Return(MemoryStream obj) - { - obj.Position = 0; - return true; - } - } - public override void ReadFromBuffer(NetIncomingMessage buffer, IRobustSerializer serializer) { MsgSize = buffer.LengthBytes; var uncompressedLength = buffer.ReadVariableInt32(); var compressedLength = buffer.ReadVariableInt32(); - MemoryStream finalStream; + using var finalStream = RobustMemoryManager.GetMemoryStream(); // State is compressed. if (compressedLength > 0) { - var stream = buffer.ReadAlignedMemory(compressedLength); + var stream = RobustMemoryManager.GetMemoryStream(); + buffer.ReadAlignedMemory(stream, compressedLength); using var decompressStream = new ZStdDecompressStream(stream); - var decompressed = StreamPool.Get(); - decompressed.SetLength(uncompressedLength); - decompressStream.CopyTo(decompressed, uncompressedLength); - decompressed.Position = 0; - finalStream = decompressed; + finalStream.SetLength(uncompressedLength); + decompressStream.CopyTo(finalStream, uncompressedLength); + finalStream.Position = 0; } // State is uncompressed. else { - var stream = buffer.ReadAlignedMemory(uncompressedLength); - finalStream = stream; + buffer.ReadAlignedMemory(finalStream, uncompressedLength); } serializer.DeserializeDirect(finalStream, out State); - StreamPool.Return(finalStream); - State.PayloadSize = uncompressedLength; } public override void WriteToBuffer(NetOutgoingMessage buffer, IRobustSerializer serializer) { - var stateStream = StreamPool.Get(); + using var stateStream = RobustMemoryManager.GetMemoryStream(); serializer.SerializeDirect(stateStream, State); buffer.WriteVariableInt32((int)stateStream.Length); @@ -108,7 +87,6 @@ public override void WriteToBuffer(NetOutgoingMessage buffer, IRobustSerializer buffer.Write(stateStream.AsSpan()); } - StreamPool.Return(stateStream); _hasWritten = true; MsgSize = buffer.LengthBytes; } diff --git a/Robust.Shared/Network/Messages/MsgViewVariablesModifyRemote.cs b/Robust.Shared/Network/Messages/MsgViewVariablesModifyRemote.cs index 645f46f2a90..d38536952fb 100644 --- a/Robust.Shared/Network/Messages/MsgViewVariablesModifyRemote.cs +++ b/Robust.Shared/Network/Messages/MsgViewVariablesModifyRemote.cs @@ -1,5 +1,6 @@ using System.IO; using Lidgren.Network; +using Robust.Shared.GameObjects; using Robust.Shared.IoC; using Robust.Shared.Serialization; using Robust.Shared.Utility; @@ -42,12 +43,14 @@ public override void ReadFromBuffer(NetIncomingMessage buffer, IRobustSerializer SessionId = buffer.ReadUInt32(); { var length = buffer.ReadInt32(); - using var stream = buffer.ReadAlignedMemory(length); + using var stream = RobustMemoryManager.GetMemoryStream(); + buffer.ReadAlignedMemory(stream, length); PropertyIndex = serializer.Deserialize(stream); } { var length = buffer.ReadInt32(); - using var stream = buffer.ReadAlignedMemory(length); + using var stream = RobustMemoryManager.GetMemoryStream(); + buffer.ReadAlignedMemory(stream, length); Value = serializer.Deserialize(stream); } ReinterpretValue = buffer.ReadBoolean(); diff --git a/Robust.Shared/Network/Messages/MsgViewVariablesRemoteData.cs b/Robust.Shared/Network/Messages/MsgViewVariablesRemoteData.cs index 1b017558bd9..18097bcadb4 100644 --- a/Robust.Shared/Network/Messages/MsgViewVariablesRemoteData.cs +++ b/Robust.Shared/Network/Messages/MsgViewVariablesRemoteData.cs @@ -1,5 +1,6 @@ using System.IO; using Lidgren.Network; +using Robust.Shared.GameObjects; using Robust.Shared.IoC; using Robust.Shared.Serialization; using Robust.Shared.Utility; @@ -31,7 +32,8 @@ public override void ReadFromBuffer(NetIncomingMessage buffer, IRobustSerializer { RequestId = buffer.ReadUInt32(); var length = buffer.ReadInt32(); - using var stream = buffer.ReadAlignedMemory(length); + using var stream = RobustMemoryManager.GetMemoryStream(); + buffer.ReadAlignedMemory(stream, length); Blob = serializer.Deserialize(stream); } diff --git a/Robust.Shared/Network/Messages/MsgViewVariablesReqData.cs b/Robust.Shared/Network/Messages/MsgViewVariablesReqData.cs index 0072d8ec48f..75f066de646 100644 --- a/Robust.Shared/Network/Messages/MsgViewVariablesReqData.cs +++ b/Robust.Shared/Network/Messages/MsgViewVariablesReqData.cs @@ -1,5 +1,6 @@ using System.IO; using Lidgren.Network; +using Robust.Shared.GameObjects; using Robust.Shared.IoC; using Robust.Shared.Serialization; using Robust.Shared.Utility; @@ -37,7 +38,8 @@ public override void ReadFromBuffer(NetIncomingMessage buffer, IRobustSerializer RequestId = buffer.ReadUInt32(); SessionId = buffer.ReadUInt32(); var length = buffer.ReadInt32(); - using var stream = buffer.ReadAlignedMemory(length); + using var stream = RobustMemoryManager.GetMemoryStream(); + buffer.ReadAlignedMemory(stream, length); RequestMeta = serializer.Deserialize(stream); } diff --git a/Robust.Shared/Network/Messages/MsgViewVariablesReqSession.cs b/Robust.Shared/Network/Messages/MsgViewVariablesReqSession.cs index 9ec811a7773..c3c25a6c898 100644 --- a/Robust.Shared/Network/Messages/MsgViewVariablesReqSession.cs +++ b/Robust.Shared/Network/Messages/MsgViewVariablesReqSession.cs @@ -1,5 +1,6 @@ using System.IO; using Lidgren.Network; +using Robust.Shared.GameObjects; using Robust.Shared.IoC; using Robust.Shared.Serialization; using Robust.Shared.Utility; @@ -32,7 +33,8 @@ public override void ReadFromBuffer(NetIncomingMessage buffer, IRobustSerializer { RequestId = buffer.ReadUInt32(); var length = buffer.ReadInt32(); - using var stream = buffer.ReadAlignedMemory(length); + using var stream = RobustMemoryManager.GetMemoryStream(); + buffer.ReadAlignedMemory(stream, length); Selector = serializer.Deserialize(stream); } diff --git a/Robust.Shared/Network/NetMessageExt.cs b/Robust.Shared/Network/NetMessageExt.cs index 4fda623ad5f..d2025939630 100644 --- a/Robust.Shared/Network/NetMessageExt.cs +++ b/Robust.Shared/Network/NetMessageExt.cs @@ -6,6 +6,7 @@ using Robust.Shared.Map; using Robust.Shared.Maths; using Robust.Shared.Timing; +using Robust.Shared.Utility; namespace Robust.Shared.Network { @@ -96,16 +97,17 @@ public static void Write(this NetOutgoingMessage message, Color color) /// /// Thrown if the current read position of the message is not byte-aligned. /// - public static MemoryStream ReadAlignedMemory(this NetIncomingMessage message, int length) + public static void ReadAlignedMemory(this NetIncomingMessage message, MemoryStream memStream, int length) { if ((message.Position & 7) != 0) { throw new ArgumentException("Read position in message must be byte-aligned", nameof(message)); } - var stream = new MemoryStream(message.Data, message.PositionInBytes, length, false); + DebugTools.Assert(memStream.Position == 0); + memStream.Write(message.Data, message.PositionInBytes, length); + memStream.Position = 0; message.Position += length * 8; - return stream; } public static TimeSpan ReadTimeSpan(this NetIncomingMessage message) diff --git a/Robust.Shared/Robust.Shared.csproj b/Robust.Shared/Robust.Shared.csproj index 6900ddda297..d77073bfe4e 100644 --- a/Robust.Shared/Robust.Shared.csproj +++ b/Robust.Shared/Robust.Shared.csproj @@ -9,6 +9,7 @@ + diff --git a/Robust.Shared/SharedIoC.cs b/Robust.Shared/SharedIoC.cs index 93449d5d187..352a499da96 100644 --- a/Robust.Shared/SharedIoC.cs +++ b/Robust.Shared/SharedIoC.cs @@ -49,6 +49,7 @@ public static void RegisterIoC(IDependencyCollection deps) deps.Register(); deps.Register(); deps.Register(); + deps.Register(); } } } diff --git a/Robust.Shared/ViewVariables/MsgViewVariablesPath.cs b/Robust.Shared/ViewVariables/MsgViewVariablesPath.cs index aa497da0779..89f3dbd6752 100644 --- a/Robust.Shared/ViewVariables/MsgViewVariablesPath.cs +++ b/Robust.Shared/ViewVariables/MsgViewVariablesPath.cs @@ -1,6 +1,7 @@ using System; using System.IO; using Lidgren.Network; +using Robust.Shared.GameObjects; using Robust.Shared.IoC; using Robust.Shared.Network; using Robust.Shared.Serialization; @@ -156,7 +157,8 @@ public override void ReadFromBuffer(NetIncomingMessage buffer, IRobustSerializer { base.ReadFromBuffer(buffer, serializer); var length = buffer.ReadInt32(); - using var stream = buffer.ReadAlignedMemory(length); + using var stream = RobustMemoryManager.GetMemoryStream(); + buffer.ReadAlignedMemory(stream, length); Options = serializer.Deserialize(stream); } diff --git a/Robust.UnitTesting/Shared/Timing/TimerTest.cs b/Robust.UnitTesting/Shared/Timing/TimerTest.cs index 12a28f4a3db..17e21838073 100644 --- a/Robust.UnitTesting/Shared/Timing/TimerTest.cs +++ b/Robust.UnitTesting/Shared/Timing/TimerTest.cs @@ -71,7 +71,6 @@ async void Run() public void TestCancellation() { var timerManager = IoCManager.Resolve(); - var taskManager = IoCManager.Resolve(); var cts = new CancellationTokenSource(); var ran = false;