Skip to content

Commit

Permalink
Merge pull request #298 from timcassell/fix-parallel-asynclocal
Browse files Browse the repository at this point in the history
Fix `AsyncLocal` in `Promise.Parallel*` body
  • Loading branch information
timcassell authored Nov 6, 2023
2 parents 5b5a292 + e923270 commit ca1e19f
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 15 deletions.
42 changes: 38 additions & 4 deletions Package/Core/Promises/Internal/ParallelInternal.cs
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ internal sealed partial class PromiseParallelForEach<TEnumerator, TParallelBody,
// Use the CancelationRef directly instead of CancelationSource struct to save memory.
private CancelationRef _cancelationRef;
private SynchronizationContext _synchronizationContext;
private ExecutionContext _executionContext;
private int _remainingAvailableWorkers;
private int _waitCounter;
private List<Exception> _exceptions;
Expand Down Expand Up @@ -257,6 +258,10 @@ internal static PromiseParallelForEach<TEnumerator, TParallelBody, TSource> GetO
promise._completionState = Promise.State.Resolved;
promise._cancelationRef = CancelationRef.GetOrCreate();
cancelationToken.TryRegister(promise, out promise._externalCancelationRegistration);
if (Promise.Config.AsyncFlowExecutionContextEnabled)
{
promise._executionContext = ExecutionContext.Capture();
}
return promise;
}

Expand All @@ -265,6 +270,7 @@ internal override void MaybeDispose()
Dispose();
_body = default(TParallelBody);
_synchronizationContext = null;
_executionContext = null;
ObjectPool.MaybeRepool(this);
}

Expand All @@ -283,12 +289,40 @@ internal void MaybeLaunchWorker(bool launchWorker)
InterlockedAddWithUnsignedOverflowCheck(ref _waitCounter, 1);

ScheduleContextCallback(_synchronizationContext, this,
obj => obj.UnsafeAs<PromiseParallelForEach<TEnumerator, TParallelBody, TSource>>().ExecuteWorker(true),
obj => obj.UnsafeAs<PromiseParallelForEach<TEnumerator, TParallelBody, TSource>>().ExecuteWorker(true)
obj => obj.UnsafeAs<PromiseParallelForEach<TEnumerator, TParallelBody, TSource>>().ExecuteWorkerAndLaunchNext(),
obj => obj.UnsafeAs<PromiseParallelForEach<TEnumerator, TParallelBody, TSource>>().ExecuteWorkerAndLaunchNext()
);
}
}

private void ExecuteWorkerAndLaunchNext()
{
if (_executionContext == null)
{
ExecuteWorker(true);
}
else
{
// .Net Framework doesn't allow us to re-use a captured context, so we have to copy it for each invocation.
// .Net Core's implementation of CreateCopy returns itself, so this is always as efficient as it can be.
ExecutionContext.Run(_executionContext.CreateCopy(), obj => obj.UnsafeAs<PromiseParallelForEach<TEnumerator, TParallelBody, TSource>>().ExecuteWorker(true), this);
}
}

private void ExecuteWorkerWithoutLaunchNext()
{
if (_executionContext == null)
{
ExecuteWorker(false);
}
else
{
// .Net Framework doesn't allow us to re-use a captured context, so we have to copy it for each invocation.
// .Net Core's implementation of CreateCopy returns itself, so this is always as efficient as it can be.
ExecutionContext.Run(_executionContext.CreateCopy(), obj => obj.UnsafeAs<PromiseParallelForEach<TEnumerator, TParallelBody, TSource>>().ExecuteWorker(false), this);
}
}

private void ExecuteWorker(bool launchNext)
{
var currentContext = ts_currentContext;
Expand Down Expand Up @@ -362,8 +396,8 @@ internal override void Handle(PromiseRefBase handler, object rejectContainer, Pr
{
// Schedule the worker body to run again on the context, but without launching another worker.
ScheduleContextCallback(_synchronizationContext, this,
obj => obj.UnsafeAs<PromiseParallelForEach<TEnumerator, TParallelBody, TSource>>().ExecuteWorker(false),
obj => obj.UnsafeAs<PromiseParallelForEach<TEnumerator, TParallelBody, TSource>>().ExecuteWorker(false)
obj => obj.UnsafeAs<PromiseParallelForEach<TEnumerator, TParallelBody, TSource>>().ExecuteWorkerWithoutLaunchNext(),
obj => obj.UnsafeAs<PromiseParallelForEach<TEnumerator, TParallelBody, TSource>>().ExecuteWorkerWithoutLaunchNext()
);
}
else if (state == Promise.State.Canceled)
Expand Down
72 changes: 61 additions & 11 deletions Package/Tests/CoreTests/APIs/ParallelForTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -211,32 +211,35 @@ public void AllItemsEnumeratedOnce_WithCaptureValue_Sync(
}
}

private IEnumerable<int> IterateAndAssertContext(SynchronizationContext context)
private IEnumerable<int> IterateAndAssertContext(SynchronizationType expectedContext, Thread mainThread)
{
Assert.AreEqual(context, SynchronizationContext.Current);
TestHelper.AssertCallbackContext(expectedContext, expectedContext, mainThread);
for (int i = 1; i <= 100; i++)
{
yield return i;
Assert.AreEqual(context, SynchronizationContext.Current);
TestHelper.AssertCallbackContext(expectedContext, expectedContext, mainThread);
}
}

[Test]
public void SynchronizationContext_AllCodeExecutedOnCorrectContext_Sync(
[Values] bool foregroundContext)
[Values(SynchronizationType.Foreground, SynchronizationType.Background)] SynchronizationType syncContext)
{
SynchronizationContext context = foregroundContext ?
TestHelper._foregroundContext :
(SynchronizationContext) TestHelper._backgroundContext;
var mainThread = Thread.CurrentThread;
SynchronizationContext context = syncContext == SynchronizationType.Foreground
? TestHelper._foregroundContext
: (SynchronizationContext) TestHelper._backgroundContext;

var otherContext = new PromiseSynchronizationContext();
var otherContext = syncContext == SynchronizationType.Foreground
? (SynchronizationContext) TestHelper._backgroundContext
: TestHelper._foregroundContext;

var cq = new Queue<int>();
bool isComplete = false;

Promise.ParallelForEach(IterateAndAssertContext(context), (item, cancelationToken) =>
Promise.ParallelForEach(IterateAndAssertContext(syncContext, mainThread), (item, cancelationToken) =>
{
Assert.AreEqual(context, SynchronizationContext.Current);
TestHelper.AssertCallbackContext(syncContext, syncContext, mainThread);
return Promise.SwitchToContext(context)
.Then(() =>
{
Expand All @@ -257,7 +260,6 @@ public void SynchronizationContext_AllCodeExecutedOnCorrectContext_Sync(
if (!SpinWait.SpinUntil(() =>
{
TestHelper.ExecuteForegroundCallbacks();
otherContext.Execute();
return isComplete;
}, TimeSpan.FromSeconds(1)))
{
Expand Down Expand Up @@ -529,6 +531,54 @@ public void ParallelFor_AllIndicesEnumeratedOnce_WithCaptureValue_Sync(
Assert.True(set.Contains(i));
}
}

#if CSHARP_7_3_OR_NEWER
[Test]
public void ParallelFor_ExecutionContextFlowsToWorkerBodies(
[Values] bool foregroundContext)
{
Promise.Config.AsyncFlowExecutionContextEnabled = true;
var context = foregroundContext
? (SynchronizationContext) TestHelper._foregroundContext
: TestHelper._backgroundContext;

var al = new AsyncLocal<int>();
al.Value = 42;
Promise.ParallelFor(0, 100, async (item, cancellationToken) =>
{
await Promise.SwitchToForegroundAwait(forceAsync: true);
Assert.AreEqual(42, al.Value);
})
.WaitWithTimeoutWhileExecutingForegroundContext(TimeSpan.FromSeconds(Environment.ProcessorCount));
}

private static IEnumerable<int> Iterate100()
{
for (int i = 0; i < 100; i++)
{
yield return i;
}
}

[Test]
public void ParallelForEach_ExecutionContextFlowsToWorkerBodies(
[Values] bool foregroundContext)
{
Promise.Config.AsyncFlowExecutionContextEnabled = true;
var context = foregroundContext
? (SynchronizationContext) TestHelper._foregroundContext
: TestHelper._backgroundContext;

var al = new AsyncLocal<int>();
al.Value = 42;
Promise.ParallelForEach(Iterate100(), async (item, cancellationToken) =>
{
await Promise.SwitchToForegroundAwait(forceAsync: true);
Assert.AreEqual(42, al.Value);
})
.WaitWithTimeoutWhileExecutingForegroundContext(TimeSpan.FromSeconds(Environment.ProcessorCount));
}
#endif // CSHARP_7_3_OR_NEWER
}
#endif // !UNITY_WEBGL
}

0 comments on commit ca1e19f

Please sign in to comment.