Skip to content

Commit

Permalink
Protocol: allow (early) dispose of sent handles. (#290)
Browse files Browse the repository at this point in the history
  • Loading branch information
tmds authored Jun 19, 2024
1 parent b447942 commit 2966198
Show file tree
Hide file tree
Showing 8 changed files with 232 additions and 56 deletions.
34 changes: 28 additions & 6 deletions src/Tmds.DBus.Protocol/Connection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ public async Task CallMethodAsync(MessageBuffer message)
DBusConnection connection;
try
{
RefHandles(message);
connection = await ConnectCoreAsync().ConfigureAwait(false);
}
catch
Expand All @@ -222,6 +223,7 @@ public async Task<T> CallMethodAsync<T>(MessageBuffer message, MessageValueReade
DBusConnection connection;
try
{
RefHandles(message);
connection = await ConnectCoreAsync().ConfigureAwait(false);
}
catch
Expand All @@ -232,6 +234,15 @@ public async Task<T> CallMethodAsync<T>(MessageBuffer message, MessageValueReade
return await connection.CallMethodAsync(message, reader, readerState).ConfigureAwait(false);
}

private void RefHandles(MessageBuffer message)
{
// Take a reference on any handles we might be sending.
// This ensures the handles are valid or that we throw an exception at this point.
// It also enables a user to to dispose the handles as soon as Connection method returns
// (without having to await it).
message.RefHandles();
}

[Obsolete("Use an overload that accepts ObserverFlags.")]
public ValueTask<IDisposable> AddMatchAsync<T>(MatchRule rule, MessageValueReader<T> reader, Action<Exception?, T, object?, object?> handler, object? readerState = null, object? handlerState = null, bool emitOnCapturedContext = true, bool subscribe = true)
=> AddMatchAsync(rule, reader, handler, readerState, handlerState, emitOnCapturedContext, ObserverFlags.EmitOnDispose | (!subscribe ? ObserverFlags.NoSubscribe : default));
Expand Down Expand Up @@ -285,14 +296,25 @@ private static Connection CreateConnection(ref Connection? field, string? addres

public bool TrySendMessage(MessageBuffer message)
{
DBusConnection? connection = GetConnection(ifConnected: true);
if (connection is null)
bool messageSent = false;
try
{
message.ReturnToPool();
return false;
DBusConnection? connection = GetConnection(ifConnected: true);
if (connection is not null)
{
RefHandles(message);
connection.SendMessage(message);
messageSent = true;
}
return messageSent;
}
finally
{
if (!messageSent)
{
message.ReturnToPool();
}
}
connection.SendMessage(message);
return true;
}

public Task<Exception?> DisconnectedAsync()
Expand Down
3 changes: 2 additions & 1 deletion src/Tmds.DBus.Protocol/Message.cs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ internal void ReturnToPool()
{
_data.Reset();
ClearHeaders();
_handles?.DisposeHandles();
_handles?.Dispose();
_handles = null;
_pool.Return(this);
}

Expand Down
4 changes: 3 additions & 1 deletion src/Tmds.DBus.Protocol/MessageBuffer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,16 @@ internal void Init(uint serial, MessageFlags flags, UnixFdCollection? handles)
Handles = handles;
}

internal void RefHandles() => Handles?.RefHandles();

// Users should create a message using a MessageWriter
// and then hand it to the Connection class which is responsible for calling this method.
// A library user is never considered the owner of this message and therefore
// we don't provide a public method for a user to Dispose/ReturnToPool.
internal void ReturnToPool()
{
_data.Reset();
Handles?.DisposeHandles();
Handles?.Dispose();
Handles = null;
_messagePool.Return(this);
}
Expand Down
2 changes: 1 addition & 1 deletion src/Tmds.DBus.Protocol/MessageStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ private async void ReadMessagesIntoSocket()
var message = await _messageReader.ReadAsync().ConfigureAwait(false);
try
{
IReadOnlyList<SafeHandle>? handles = _supportsFdPassing ? message.Handles : null;
UnixFdCollection? handles = _supportsFdPassing ? message.Handles : null;
var buffer = message.AsReadOnlySequence();
if (buffer.IsSingleSegment)
{
Expand Down
44 changes: 20 additions & 24 deletions src/Tmds.DBus.Protocol/SocketExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ private async static ValueTask<int> ReceiveWithHandlesAsync(this Socket socket,
}
}

public static ValueTask SendAsync(this Socket socket, ReadOnlyMemory<byte> buffer, IReadOnlyList<SafeHandle>? handles)
public static ValueTask SendAsync(this Socket socket, ReadOnlyMemory<byte> buffer, UnixFdCollection? handles)
{
if (handles is null || handles.Count == 0)
{
Expand All @@ -65,7 +65,7 @@ private static async ValueTask SendAsync(this Socket socket, ReadOnlyMemory<byte
}
}

private static ValueTask SendAsyncWithHandlesAsync(this Socket socket, ReadOnlyMemory<byte> buffer, IReadOnlyList<SafeHandle> handles)
private static ValueTask SendAsyncWithHandlesAsync(this Socket socket, ReadOnlyMemory<byte> buffer, UnixFdCollection handles)
{
socket.Blocking = false;
do
Expand All @@ -92,7 +92,7 @@ private static ValueTask SendAsyncWithHandlesAsync(this Socket socket, ReadOnlyM
} while (true);
}

private static unsafe int sendmsg(Socket socket, ReadOnlyMemory<byte> buffer, IReadOnlyList<SafeHandle> handles)
private static unsafe int sendmsg(Socket socket, ReadOnlyMemory<byte> buffer, UnixFdCollection handles)
{
fixed (byte* ptr = buffer.Span)
{
Expand All @@ -113,32 +113,26 @@ private static unsafe int sendmsg(Socket socket, ReadOnlyMemory<byte> buffer, IR
fdm.hdr.cmsg_type = SCM_RIGHTS;

SafeHandle handle = socket.GetSafeHandle();
int handleRefsAdded = 0;
bool refAdded = false;
try
lock (handles.SyncObject)
{
handle.DangerousAddRef(ref refAdded);
for (int i = 0, j = 0; i < handles.Count; i++)
bool refAdded = false;
try
{
bool added = false;
SafeHandle h = handles[i];
h.DangerousAddRef(ref added);
handleRefsAdded++;
fdm.fds[j++] = h.DangerousGetHandle().ToInt32();
}
handle.DangerousAddRef(ref refAdded);
for (int i = 0, j = 0; i < handles.Count; i++)
{
fdm.fds[j++] = handles.DangerousGetHandle(i);
}

return (int)sendmsg(handle.DangerousGetHandle().ToInt32(), new IntPtr(&msg), 0);
}
finally
{
for (int i = 0; i < handleRefsAdded; i++)
return (int)sendmsg(handle.DangerousGetHandle().ToInt32(), new IntPtr(&msg), 0);
}
finally
{
SafeHandle h = handles[i];
h.DangerousRelease();
if (refAdded)
{
handle.DangerousRelease();
}
}

if (refAdded)
handle.DangerousRelease();
}
}
}
Expand Down Expand Up @@ -183,7 +177,9 @@ private static unsafe int recvmsg(Socket socket, Memory<byte> buffer, UnixFdColl
finally
{
if (refAdded)
{
handle.DangerousRelease();
}
}
}
}
Expand Down
98 changes: 75 additions & 23 deletions src/Tmds.DBus.Protocol/UnixFdCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ sealed class UnixFdCollection : IReadOnlyList<SafeHandle>, IDisposable
// We don't need to lock it while adding handles, or reading them to send them.
private readonly object _gate;
private bool _disposed;
private bool _handlesReffed;

internal object SyncObject => _gate;

internal bool IsRawHandleCollection => _rawHandles is not null;

Expand All @@ -36,6 +39,7 @@ internal int AddHandle(IntPtr handle)

internal void AddHandle(SafeHandle handle)
{
Debug.Assert(!_handlesReffed);
if (handle is null)
{
throw new ArgumentNullException(nameof(handle));
Expand Down Expand Up @@ -149,69 +153,117 @@ IEnumerator IEnumerable.GetEnumerator()
throw new NotSupportedException();
}

public void DisposeHandles(int count = -1)
public void Dispose()
{
if (count != 0)
lock (_gate)
{
DisposeHandles(true, count);
if (_disposed)
{
return;
}
_disposed = true;

DisposeHandles(disposing: true);
}

GC.SuppressFinalize(this);
}

public void Dispose()
internal void RefHandles()
{
lock (_gate)
{
if (_disposed)
{
return;
ThrowDisposed();
}
_disposed = true;
DisposeHandles(true);

int handleRefsAdded = 0;
try
{
for (int i = 0; i < Count; i++)
{
bool added = false;
SafeHandle h = this[i];
h.DangerousAddRef(ref added);
handleRefsAdded++;
}

_handlesReffed = true;
}
catch
{
for (int i = 0; i < handleRefsAdded; i++)
{
SafeHandle h = this[i];
h.DangerousRelease();
}

throw;
}
}
}

internal int DangerousGetHandle(int i)
{
Debug.Assert(Monitor.IsEntered(_gate));
if (!_handlesReffed)
{
throw new InvalidOperationException("Trying to send an unreffed handle.");
}
int fd = this[i].DangerousGetHandle().ToInt32();
return fd;
}

~UnixFdCollection()
{
DisposeHandles(false);
}

private void DisposeHandles(bool disposing, int count = -1)
private void DisposeHandles(bool disposing)
{
if (count == -1)
bool handlesReffed = _handlesReffed;
if (handlesReffed)
{
count = Count;
_handlesReffed = false;

for (int i = 0; i < Count; i++)
{
SafeHandle h = this[i];
h.DangerousRelease();
}
}

if (disposing)
if (disposing || handlesReffed)
{
// dispose managed state
if (_handles is not null)
{
for (int i = 0; i < count; i++)
for (int i = 0; i < Count; i++)
{
var handle = _handles[i];
if (handle.Handle is not null)
{
handle.Handle.Dispose();
}
}
_handles.RemoveRange(0, count);
_handles.Clear();
}
}
else

// free unmanaged resources
if (_rawHandles is not null)
{
if (_rawHandles is not null)
for (int i = 0; i < Count; i++)
{
for (int i = 0; i < count; i++)
{
var handle = _rawHandles[i];
var handle = _rawHandles[i];

if (handle.RawHandle != InvalidRawHandle)
{
close(handle.RawHandle.ToInt32());
}
if (handle.RawHandle != InvalidRawHandle)
{
close(handle.RawHandle.ToInt32());
}
_rawHandles.RemoveRange(0, count);
}
_rawHandles.Clear();
}
}

Expand Down
3 changes: 3 additions & 0 deletions test/Tmds.DBus.Protocol.Tests/ConnectionTests.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Reflection.Metadata;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Xml;
using Microsoft.Win32.SafeHandles;
using Xunit;

namespace Tmds.DBus.Protocol.Tests
Expand Down
Loading

0 comments on commit 2966198

Please sign in to comment.