diff --git a/.swiftpm/xcode/package.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist b/.swiftpm/xcode/package.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist
new file mode 100644
index 0000000..18d9810
--- /dev/null
+++ b/.swiftpm/xcode/package.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist
@@ -0,0 +1,8 @@
+
+
+
+
+ IDEDidComputeMac32BitWarning
+
+
+
diff --git a/.swiftpm/xcode/xcshareddata/xcschemes/Semaphore.xcscheme b/.swiftpm/xcode/xcshareddata/xcschemes/Semaphore.xcscheme
new file mode 100644
index 0000000..99f6062
--- /dev/null
+++ b/.swiftpm/xcode/xcshareddata/xcschemes/Semaphore.xcscheme
@@ -0,0 +1,92 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/README.md b/README.md
index 42cb43e..7d5124e 100644
--- a/README.md
+++ b/README.md
@@ -11,7 +11,7 @@ let semaphore = Semaphore(value: 0)
Task {
// Suspends the task until a signal occurs.
- await semaphore.wait()
+ try await semaphore.wait()
await doSomething()
}
@@ -19,6 +19,8 @@ Task {
semaphore.signal()
```
+The `wait()` method throws `CancellationError` if the task is cancelled while waiting for a signal.
+
Semaphores also provide a way to restrict the access to a limited resource. The sample code below makes sure that `downloadAndSave()` waits until the previous call has completed:
```swift
@@ -26,7 +28,7 @@ let semaphore = Semaphore(value: 1)
// There is at most one task that is downloading and saving at any given time
func downloadAndSave() async throws {
- await semaphore.wait()
+ try await semaphore.wait()
let value = try await downloadValue()
try await save(value)
semaphore.signal()
diff --git a/Sources/Semaphore/Semaphore.swift b/Sources/Semaphore/Semaphore.swift
index 543d40a..0ec30ac 100644
--- a/Sources/Semaphore/Semaphore.swift
+++ b/Sources/Semaphore/Semaphore.swift
@@ -41,18 +41,27 @@ import Foundation
///
/// - ``wait()``
/// - ``run(_:)``
-public class Semaphore {
+public final class Semaphore {
/// The semaphore value.
private var value: Int
- /// An array of continuations that release waiting tasks.
- private var continuations: [UnsafeContinuation] = []
+ private class Suspension {
+ enum State {
+ case pending
+ case suspended(UnsafeContinuation)
+ case cancelled
+ }
+ var state = State.pending
+ }
+
+ private var suspensions: [Suspension] = []
/// This lock would be required even if ``Semaphore`` were made an actor,
/// because `withUnsafeContinuation` suspends before it runs its closure
/// argument. Also, by making ``Semaphore`` a plain class, we can expose a
- /// non-async ``signal()`` method.
- private let lock = NSLock()
+ /// non-async ``signal()`` method. The lock is recursive in order to handle
+ /// cancellation (see the implementation of ``wait()``).
+ private let lock = NSRecursiveLock()
/// Creates a semaphore.
///
@@ -64,28 +73,72 @@ public class Semaphore {
}
deinit {
- precondition(continuations.isEmpty, "Semaphore is deallocated while some task(s) are suspended waiting for a signal.")
+ precondition(suspensions.isEmpty, "Semaphore is deallocated while some task(s) are suspended waiting for a signal.")
}
/// Waits for, or decrements, a semaphore.
///
/// Decrement the counting semaphore. If the resulting value is less than
- /// zero, this function suspends the current task until a signal occurs.
- /// Otherwise, no suspension happens.
- public func wait() async {
+ /// zero, this function suspends the current task until a signal occurs,
+ /// without blocking the underlying thread. Otherwise, no suspension happens.
+ ///
+ /// - Throws: If the task is canceled before a signal occurs, this function
+ /// throws `CancellationError`.
+ public func wait() async throws {
lock.lock()
value -= 1
- if value < 0 {
- await withUnsafeContinuation { continuation in
- // The first task to wait will be the first task woken by `signal`.
- // This is not intended to be a strong fifo guarantee, but just
- // an attempt at some fairness.
- continuations.insert(continuation, at: 0)
- lock.unlock()
- }
- } else {
+ if value >= 0 {
lock.unlock()
+ return
+ }
+
+ // Get ready for being suspended waiting for a continuation, or for
+ // early cancellation.
+ let suspension = Suspension()
+
+ try await withTaskCancellationHandler {
+ try await withUnsafeThrowingContinuation { (continuation: UnsafeContinuation) in
+ if case .cancelled = suspension.state {
+ // Current task was already cancelled when withTaskCancellationHandler
+ // was invoked.
+ lock.unlock()
+ continuation.resume(throwing: CancellationError())
+ } else {
+ // Current task was not cancelled: register the continuation
+ // that `signal` will resume.
+ //
+ // The first suspended task will be the first task resumed by `signal`.
+ // This is not intended to be a strong fifo guarantee, but just
+ // an attempt at some fairness.
+ suspension.state = .suspended(continuation)
+ suspensions.insert(suspension, at: 0)
+ lock.unlock()
+ }
+ }
+ } onCancel: {
+ // withTaskCancellationHandler may immediately call this block (if
+ // the current task is cancelled), or call it later (if the task is
+ // cancelled later). In the first case, we're still holding the lock,
+ // waiting for the continuation. In the second case, we do not hold
+ // the lock. This is the reason why we use a recursive lock.
+ lock.lock()
+ defer { lock.unlock() }
+
+ // We're no longer waiting for a signal
+ value += 1
+ if let index = suspensions.firstIndex(where: { $0 === suspension }) {
+ suspensions.remove(at: index)
+ }
+
+ if case let .suspended(continuation) = suspension.state {
+ // Task is cancelled while suspended: resume with a CancellationError.
+ continuation.resume(throwing: CancellationError())
+ } else {
+ // Current task is cancelled
+ // Next step: withUnsafeThrowingContinuation right above
+ suspension.state = .cancelled
+ }
}
}
@@ -94,15 +147,15 @@ public class Semaphore {
/// Increment the counting semaphore. If the previous value was less than
/// zero, this function resumes a task currently suspended in ``wait()``.
///
- /// - returns This function returns true if a task is resumed. Otherwise,
- /// false is returned.
+ /// - returns This function returns true if a suspended task is resumed.
+ /// Otherwise, false is returned.
@discardableResult
public func signal() -> Bool {
lock.lock()
defer { lock.unlock() }
value += 1
- if let continuation = continuations.popLast() {
+ if case let .suspended(continuation) = suspensions.popLast()?.state {
continuation.resume()
return true
}
@@ -114,16 +167,22 @@ public class Semaphore {
/// The two sample codes below are equivalent:
///
/// ```swift
- /// let value = await semaphore.run {
+ /// let value = try await semaphore.run {
/// await getValue()
/// }
///
- /// await semaphore.wait()
+ /// try await semaphore.wait()
/// let value = await getValue()
/// semaphore.signal()
/// ```
- public func run(_ execute: @escaping () async throws -> T) async rethrows -> T {
- await wait()
+ ///
+ /// - Parameter execute: The closure to execute between `wait()`
+ /// and `signal()`.
+ /// - Throws: If the task is canceled before a signal occurs, this function
+ /// throws `CancellationError`. Otherwise, it throws the error thrown by
+ /// the `execute` closure.
+ public func run(_ execute: @escaping () async throws -> T) async throws -> T {
+ try await wait()
defer { signal() }
return try await execute()
}
diff --git a/Tests/SemaphoreTests/SemaphoreTests.swift b/Tests/SemaphoreTests/SemaphoreTests.swift
index 756d582..19856f4 100644
--- a/Tests/SemaphoreTests/SemaphoreTests.swift
+++ b/Tests/SemaphoreTests/SemaphoreTests.swift
@@ -61,7 +61,7 @@ final class SemaphoreTests: XCTestCase {
do {
// Given a task suspended on the semaphore
let sem = Semaphore(value: 0)
- Task { await sem.wait() }
+ Task { try await sem.wait() }
try await Task.sleep(nanoseconds: delay)
// First signal resumes the suspended task
@@ -76,17 +76,18 @@ final class SemaphoreTests: XCTestCase {
do {
// Given a zero semaphore
let sem = DispatchSemaphore(value: 0)
+
+ // When a thread waits for this semaphore,
let ex1 = expectation(description: "wait")
ex1.isInverted = true
let ex2 = expectation(description: "woken")
-
- // When a thread waits for this semaphore,
- // Then the thread is initially blocked.
Thread {
sem.wait()
ex1.fulfill()
ex2.fulfill()
}.start()
+
+ // Then the thread is initially blocked.
wait(for: [ex1], timeout: 0.5)
// When a signal occurs, then the waiting thread is woken.
@@ -98,17 +99,18 @@ final class SemaphoreTests: XCTestCase {
do {
// Given a zero semaphore
let sem = Semaphore(value: 0)
+
+ // When a task waits for this semaphore,
let ex1 = expectation(description: "wait")
ex1.isInverted = true
let ex2 = expectation(description: "woken")
-
- // When a task waits for this semaphore,
- // Then the task is initially suspended.
Task {
- await sem.wait()
+ try await sem.wait()
ex1.fulfill()
ex2.fulfill()
}
+
+ // Then the task is initially suspended.
wait(for: [ex1], timeout: 0.5)
// When a signal occurs, then the suspended task is resumed.
@@ -117,6 +119,106 @@ final class SemaphoreTests: XCTestCase {
}
}
+ func test_cancellation_while_suspended_throws_CancellationError() async throws {
+ let sem = Semaphore(value: 0)
+ let ex = expectation(description: "cancellation")
+ let task = Task {
+ do {
+ try await sem.wait()
+ XCTFail("Expected CancellationError")
+ } catch is CancellationError {
+ } catch {
+ XCTFail("Unexpected error")
+ }
+ ex.fulfill()
+ }
+ try await Task.sleep(nanoseconds: 100_000_000)
+ task.cancel()
+ wait(for: [ex], timeout: 1)
+ }
+
+ func test_cancellation_before_suspension_throws_CancellationError() async throws {
+ let sem = Semaphore(value: 0)
+ let ex = expectation(description: "cancellation")
+ let task = Task {
+ // Uncancellable delay
+ await withUnsafeContinuation { continuation in
+ DispatchQueue.main.asyncAfter(deadline: .now() + 0.1) {
+ continuation.resume()
+ }
+ }
+ do {
+ try await sem.wait()
+ XCTFail("Expected CancellationError")
+ } catch is CancellationError {
+ } catch {
+ XCTFail("Unexpected error")
+ }
+ ex.fulfill()
+ }
+ task.cancel()
+ wait(for: [ex], timeout: 5)
+ }
+
+ func test_that_cancellation_while_suspended_increments_the_semaphore() async throws {
+ // Given a task cancelled while suspended on a semaphore,
+ let sem = Semaphore(value: 0)
+ let task = Task {
+ try await sem.wait()
+ }
+ try await Task.sleep(nanoseconds: 100_000_000)
+ task.cancel()
+
+ // When a task waits for this semaphore,
+ let ex1 = expectation(description: "wait")
+ ex1.isInverted = true
+ let ex2 = expectation(description: "woken")
+ Task {
+ try await sem.wait()
+ ex1.fulfill()
+ ex2.fulfill()
+ }
+
+ // Then the task is initially suspended.
+ wait(for: [ex1], timeout: 0.5)
+
+ // When a signal occurs, then the suspended task is resumed.
+ sem.signal()
+ wait(for: [ex2], timeout: 0.5)
+ }
+
+ func test_that_cancellation_before_suspension_increments_the_semaphore() async throws {
+ // Given a task cancelled before it waits on a semaphore,
+ let sem = Semaphore(value: 0)
+ let task = Task {
+ // Uncancellable delay
+ await withUnsafeContinuation { continuation in
+ DispatchQueue.main.asyncAfter(deadline: .now() + 0.1) {
+ continuation.resume()
+ }
+ }
+ try await sem.wait()
+ }
+ task.cancel()
+
+ // When a task waits for this semaphore,
+ let ex1 = expectation(description: "wait")
+ ex1.isInverted = true
+ let ex2 = expectation(description: "woken")
+ Task {
+ try await sem.wait()
+ ex1.fulfill()
+ ex2.fulfill()
+ }
+
+ // Then the task is initially suspended.
+ wait(for: [ex1], timeout: 0.5)
+
+ // When a signal occurs, then the suspended task is resumed.
+ sem.signal()
+ wait(for: [ex2], timeout: 0.5)
+ }
+
func test_semaphore_as_a_resource_limiter() async {
/// An actor that counts the maximum number of concurrent executions of
/// the `run()` method.
@@ -137,10 +239,10 @@ final class SemaphoreTests: XCTestCase {
let sem = Semaphore(value: count)
// Spawn many concurrent tasks
- await withTaskGroup(of: Void.self) { group in
+ await withThrowingTaskGroup(of: Void.self) { group in
for _ in 0..<(maxCount * 2) {
group.addTask {
- await sem.wait()
+ try await sem.wait()
await runner.run()
sem.signal()
}