Skip to content

Commit

Permalink
Add support for task cancellation
Browse files Browse the repository at this point in the history
  • Loading branch information
groue committed Oct 1, 2022
1 parent b6ec764 commit 6565552
Show file tree
Hide file tree
Showing 5 changed files with 300 additions and 37 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>IDEDidComputeMac32BitWarning</key>
<true/>
</dict>
</plist>
92 changes: 92 additions & 0 deletions .swiftpm/xcode/xcshareddata/xcschemes/Semaphore.xcscheme
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
<?xml version="1.0" encoding="UTF-8"?>
<Scheme
LastUpgradeVersion = "1400"
version = "1.3">
<BuildAction
parallelizeBuildables = "YES"
buildImplicitDependencies = "YES">
<BuildActionEntries>
<BuildActionEntry
buildForTesting = "YES"
buildForRunning = "YES"
buildForProfiling = "YES"
buildForArchiving = "YES"
buildForAnalyzing = "YES">
<BuildableReference
BuildableIdentifier = "primary"
BlueprintIdentifier = "Semaphore"
BuildableName = "Semaphore"
BlueprintName = "Semaphore"
ReferencedContainer = "container:">
</BuildableReference>
</BuildActionEntry>
<BuildActionEntry
buildForTesting = "YES"
buildForRunning = "YES"
buildForProfiling = "NO"
buildForArchiving = "NO"
buildForAnalyzing = "YES">
<BuildableReference
BuildableIdentifier = "primary"
BlueprintIdentifier = "SemaphoreTests"
BuildableName = "SemaphoreTests"
BlueprintName = "SemaphoreTests"
ReferencedContainer = "container:">
</BuildableReference>
</BuildActionEntry>
</BuildActionEntries>
</BuildAction>
<TestAction
buildConfiguration = "Debug"
selectedDebuggerIdentifier = "Xcode.DebuggerFoundation.Debugger.LLDB"
selectedLauncherIdentifier = "Xcode.DebuggerFoundation.Launcher.LLDB"
shouldUseLaunchSchemeArgsEnv = "YES"
codeCoverageEnabled = "YES">
<Testables>
<TestableReference
skipped = "NO">
<BuildableReference
BuildableIdentifier = "primary"
BlueprintIdentifier = "SemaphoreTests"
BuildableName = "SemaphoreTests"
BlueprintName = "SemaphoreTests"
ReferencedContainer = "container:">
</BuildableReference>
</TestableReference>
</Testables>
</TestAction>
<LaunchAction
buildConfiguration = "Debug"
selectedDebuggerIdentifier = "Xcode.DebuggerFoundation.Debugger.LLDB"
selectedLauncherIdentifier = "Xcode.DebuggerFoundation.Launcher.LLDB"
launchStyle = "0"
useCustomWorkingDirectory = "NO"
ignoresPersistentStateOnLaunch = "NO"
debugDocumentVersioning = "YES"
debugServiceExtension = "internal"
allowLocationSimulation = "YES">
</LaunchAction>
<ProfileAction
buildConfiguration = "Release"
shouldUseLaunchSchemeArgsEnv = "YES"
savedToolIdentifier = ""
useCustomWorkingDirectory = "NO"
debugDocumentVersioning = "YES">
<MacroExpansion>
<BuildableReference
BuildableIdentifier = "primary"
BlueprintIdentifier = "Semaphore"
BuildableName = "Semaphore"
BlueprintName = "Semaphore"
ReferencedContainer = "container:">
</BuildableReference>
</MacroExpansion>
</ProfileAction>
<AnalyzeAction
buildConfiguration = "Debug">
</AnalyzeAction>
<ArchiveAction
buildConfiguration = "Release"
revealArchiveInOrganizer = "YES">
</ArchiveAction>
</Scheme>
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,24 @@ let semaphore = Semaphore(value: 0)

Task {
// Suspends the task until a signal occurs.
await semaphore.wait()
try await semaphore.wait()
await doSomething()
}

// Resumes the suspended task.
semaphore.signal()
```

The `wait()` method throws `CancellationError` if the task is cancelled while waiting for a signal.

Semaphores also provide a way to restrict the access to a limited resource. The sample code below makes sure that `downloadAndSave()` waits until the previous call has completed:

```swift
let semaphore = Semaphore(value: 1)

// There is at most one task that is downloading and saving at any given time
func downloadAndSave() async throws {
await semaphore.wait()
try await semaphore.wait()
let value = try await downloadValue()
try await save(value)
semaphore.signal()
Expand Down
109 changes: 84 additions & 25 deletions Sources/Semaphore/Semaphore.swift
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,27 @@ import Foundation
///
/// - ``wait()``
/// - ``run(_:)``
public class Semaphore {
public final class Semaphore {
/// The semaphore value.
private var value: Int

/// An array of continuations that release waiting tasks.
private var continuations: [UnsafeContinuation<Void, Never>] = []
private class Suspension {
enum State {
case pending
case suspended(UnsafeContinuation<Void, Error>)
case cancelled
}
var state = State.pending
}

private var suspensions: [Suspension] = []

/// This lock would be required even if ``Semaphore`` were made an actor,
/// because `withUnsafeContinuation` suspends before it runs its closure
/// argument. Also, by making ``Semaphore`` a plain class, we can expose a
/// non-async ``signal()`` method.
private let lock = NSLock()
/// non-async ``signal()`` method. The lock is recursive in order to handle
/// cancellation (see the implementation of ``wait()``).
private let lock = NSRecursiveLock()

/// Creates a semaphore.
///
Expand All @@ -64,28 +73,72 @@ public class Semaphore {
}

deinit {
precondition(continuations.isEmpty, "Semaphore is deallocated while some task(s) are suspended waiting for a signal.")
precondition(suspensions.isEmpty, "Semaphore is deallocated while some task(s) are suspended waiting for a signal.")
}

/// Waits for, or decrements, a semaphore.
///
/// Decrement the counting semaphore. If the resulting value is less than
/// zero, this function suspends the current task until a signal occurs.
/// Otherwise, no suspension happens.
public func wait() async {
/// zero, this function suspends the current task until a signal occurs,
/// without blocking the underlying thread. Otherwise, no suspension happens.
///
/// - Throws: If the task is canceled before a signal occurs, this function
/// throws `CancellationError`.
public func wait() async throws {
lock.lock()

value -= 1
if value < 0 {
await withUnsafeContinuation { continuation in
// The first task to wait will be the first task woken by `signal`.
// This is not intended to be a strong fifo guarantee, but just
// an attempt at some fairness.
continuations.insert(continuation, at: 0)
lock.unlock()
}
} else {
if value >= 0 {
lock.unlock()
return
}

// Get ready for being suspended waiting for a continuation, or for
// early cancellation.
let suspension = Suspension()

try await withTaskCancellationHandler {
try await withUnsafeThrowingContinuation { (continuation: UnsafeContinuation<Void, Error>) in
if case .cancelled = suspension.state {
// Current task was already cancelled when withTaskCancellationHandler
// was invoked.
lock.unlock()
continuation.resume(throwing: CancellationError())
} else {
// Current task was not cancelled: register the continuation
// that `signal` will resume.
//
// The first suspended task will be the first task resumed by `signal`.
// This is not intended to be a strong fifo guarantee, but just
// an attempt at some fairness.
suspension.state = .suspended(continuation)
suspensions.insert(suspension, at: 0)
lock.unlock()
}
}
} onCancel: {
// withTaskCancellationHandler may immediately call this block (if
// the current task is cancelled), or call it later (if the task is
// cancelled later). In the first case, we're still holding the lock,
// waiting for the continuation. In the second case, we do not hold
// the lock. This is the reason why we use a recursive lock.
lock.lock()
defer { lock.unlock() }

// We're no longer waiting for a signal
value += 1
if let index = suspensions.firstIndex(where: { $0 === suspension }) {
suspensions.remove(at: index)
}

if case let .suspended(continuation) = suspension.state {
// Task is cancelled while suspended: resume with a CancellationError.
continuation.resume(throwing: CancellationError())
} else {
// Current task is cancelled
// Next step: withUnsafeThrowingContinuation right above
suspension.state = .cancelled
}
}
}

Expand All @@ -94,15 +147,15 @@ public class Semaphore {
/// Increment the counting semaphore. If the previous value was less than
/// zero, this function resumes a task currently suspended in ``wait()``.
///
/// - returns This function returns true if a task is resumed. Otherwise,
/// false is returned.
/// - returns This function returns true if a suspended task is resumed.
/// Otherwise, false is returned.
@discardableResult
public func signal() -> Bool {
lock.lock()
defer { lock.unlock() }

value += 1
if let continuation = continuations.popLast() {
if case let .suspended(continuation) = suspensions.popLast()?.state {
continuation.resume()
return true
}
Expand All @@ -114,16 +167,22 @@ public class Semaphore {
/// The two sample codes below are equivalent:
///
/// ```swift
/// let value = await semaphore.run {
/// let value = try await semaphore.run {
/// await getValue()
/// }
///
/// await semaphore.wait()
/// try await semaphore.wait()
/// let value = await getValue()
/// semaphore.signal()
/// ```
public func run<T>(_ execute: @escaping () async throws -> T) async rethrows -> T {
await wait()
///
/// - Parameter execute: The closure to execute between `wait()`
/// and `signal()`.
/// - Throws: If the task is canceled before a signal occurs, this function
/// throws `CancellationError`. Otherwise, it throws the error thrown by
/// the `execute` closure.
public func run<T>(_ execute: @escaping () async throws -> T) async throws -> T {
try await wait()
defer { signal() }
return try await execute()
}
Expand Down
Loading

0 comments on commit 6565552

Please sign in to comment.