From 6565552b586a09a83523efd323383dbca53403ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gwendal=20Roue=CC=81?= Date: Sat, 1 Oct 2022 17:15:28 +0200 Subject: [PATCH] Add support for task cancellation --- .../xcshareddata/IDEWorkspaceChecks.plist | 8 ++ .../xcshareddata/xcschemes/Semaphore.xcscheme | 92 +++++++++++++ README.md | 6 +- Sources/Semaphore/Semaphore.swift | 109 ++++++++++++---- Tests/SemaphoreTests/SemaphoreTests.swift | 122 ++++++++++++++++-- 5 files changed, 300 insertions(+), 37 deletions(-) create mode 100644 .swiftpm/xcode/package.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist create mode 100644 .swiftpm/xcode/xcshareddata/xcschemes/Semaphore.xcscheme diff --git a/.swiftpm/xcode/package.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist b/.swiftpm/xcode/package.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist new file mode 100644 index 0000000..18d9810 --- /dev/null +++ b/.swiftpm/xcode/package.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist @@ -0,0 +1,8 @@ + + + + + IDEDidComputeMac32BitWarning + + + diff --git a/.swiftpm/xcode/xcshareddata/xcschemes/Semaphore.xcscheme b/.swiftpm/xcode/xcshareddata/xcschemes/Semaphore.xcscheme new file mode 100644 index 0000000..99f6062 --- /dev/null +++ b/.swiftpm/xcode/xcshareddata/xcschemes/Semaphore.xcscheme @@ -0,0 +1,92 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/README.md b/README.md index 42cb43e..7d5124e 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ let semaphore = Semaphore(value: 0) Task { // Suspends the task until a signal occurs. - await semaphore.wait() + try await semaphore.wait() await doSomething() } @@ -19,6 +19,8 @@ Task { semaphore.signal() ``` +The `wait()` method throws `CancellationError` if the task is cancelled while waiting for a signal. + Semaphores also provide a way to restrict the access to a limited resource. The sample code below makes sure that `downloadAndSave()` waits until the previous call has completed: ```swift @@ -26,7 +28,7 @@ let semaphore = Semaphore(value: 1) // There is at most one task that is downloading and saving at any given time func downloadAndSave() async throws { - await semaphore.wait() + try await semaphore.wait() let value = try await downloadValue() try await save(value) semaphore.signal() diff --git a/Sources/Semaphore/Semaphore.swift b/Sources/Semaphore/Semaphore.swift index 543d40a..0ec30ac 100644 --- a/Sources/Semaphore/Semaphore.swift +++ b/Sources/Semaphore/Semaphore.swift @@ -41,18 +41,27 @@ import Foundation /// /// - ``wait()`` /// - ``run(_:)`` -public class Semaphore { +public final class Semaphore { /// The semaphore value. private var value: Int - /// An array of continuations that release waiting tasks. - private var continuations: [UnsafeContinuation] = [] + private class Suspension { + enum State { + case pending + case suspended(UnsafeContinuation) + case cancelled + } + var state = State.pending + } + + private var suspensions: [Suspension] = [] /// This lock would be required even if ``Semaphore`` were made an actor, /// because `withUnsafeContinuation` suspends before it runs its closure /// argument. Also, by making ``Semaphore`` a plain class, we can expose a - /// non-async ``signal()`` method. - private let lock = NSLock() + /// non-async ``signal()`` method. The lock is recursive in order to handle + /// cancellation (see the implementation of ``wait()``). + private let lock = NSRecursiveLock() /// Creates a semaphore. /// @@ -64,28 +73,72 @@ public class Semaphore { } deinit { - precondition(continuations.isEmpty, "Semaphore is deallocated while some task(s) are suspended waiting for a signal.") + precondition(suspensions.isEmpty, "Semaphore is deallocated while some task(s) are suspended waiting for a signal.") } /// Waits for, or decrements, a semaphore. /// /// Decrement the counting semaphore. If the resulting value is less than - /// zero, this function suspends the current task until a signal occurs. - /// Otherwise, no suspension happens. - public func wait() async { + /// zero, this function suspends the current task until a signal occurs, + /// without blocking the underlying thread. Otherwise, no suspension happens. + /// + /// - Throws: If the task is canceled before a signal occurs, this function + /// throws `CancellationError`. + public func wait() async throws { lock.lock() value -= 1 - if value < 0 { - await withUnsafeContinuation { continuation in - // The first task to wait will be the first task woken by `signal`. - // This is not intended to be a strong fifo guarantee, but just - // an attempt at some fairness. - continuations.insert(continuation, at: 0) - lock.unlock() - } - } else { + if value >= 0 { lock.unlock() + return + } + + // Get ready for being suspended waiting for a continuation, or for + // early cancellation. + let suspension = Suspension() + + try await withTaskCancellationHandler { + try await withUnsafeThrowingContinuation { (continuation: UnsafeContinuation) in + if case .cancelled = suspension.state { + // Current task was already cancelled when withTaskCancellationHandler + // was invoked. + lock.unlock() + continuation.resume(throwing: CancellationError()) + } else { + // Current task was not cancelled: register the continuation + // that `signal` will resume. + // + // The first suspended task will be the first task resumed by `signal`. + // This is not intended to be a strong fifo guarantee, but just + // an attempt at some fairness. + suspension.state = .suspended(continuation) + suspensions.insert(suspension, at: 0) + lock.unlock() + } + } + } onCancel: { + // withTaskCancellationHandler may immediately call this block (if + // the current task is cancelled), or call it later (if the task is + // cancelled later). In the first case, we're still holding the lock, + // waiting for the continuation. In the second case, we do not hold + // the lock. This is the reason why we use a recursive lock. + lock.lock() + defer { lock.unlock() } + + // We're no longer waiting for a signal + value += 1 + if let index = suspensions.firstIndex(where: { $0 === suspension }) { + suspensions.remove(at: index) + } + + if case let .suspended(continuation) = suspension.state { + // Task is cancelled while suspended: resume with a CancellationError. + continuation.resume(throwing: CancellationError()) + } else { + // Current task is cancelled + // Next step: withUnsafeThrowingContinuation right above + suspension.state = .cancelled + } } } @@ -94,15 +147,15 @@ public class Semaphore { /// Increment the counting semaphore. If the previous value was less than /// zero, this function resumes a task currently suspended in ``wait()``. /// - /// - returns This function returns true if a task is resumed. Otherwise, - /// false is returned. + /// - returns This function returns true if a suspended task is resumed. + /// Otherwise, false is returned. @discardableResult public func signal() -> Bool { lock.lock() defer { lock.unlock() } value += 1 - if let continuation = continuations.popLast() { + if case let .suspended(continuation) = suspensions.popLast()?.state { continuation.resume() return true } @@ -114,16 +167,22 @@ public class Semaphore { /// The two sample codes below are equivalent: /// /// ```swift - /// let value = await semaphore.run { + /// let value = try await semaphore.run { /// await getValue() /// } /// - /// await semaphore.wait() + /// try await semaphore.wait() /// let value = await getValue() /// semaphore.signal() /// ``` - public func run(_ execute: @escaping () async throws -> T) async rethrows -> T { - await wait() + /// + /// - Parameter execute: The closure to execute between `wait()` + /// and `signal()`. + /// - Throws: If the task is canceled before a signal occurs, this function + /// throws `CancellationError`. Otherwise, it throws the error thrown by + /// the `execute` closure. + public func run(_ execute: @escaping () async throws -> T) async throws -> T { + try await wait() defer { signal() } return try await execute() } diff --git a/Tests/SemaphoreTests/SemaphoreTests.swift b/Tests/SemaphoreTests/SemaphoreTests.swift index 756d582..19856f4 100644 --- a/Tests/SemaphoreTests/SemaphoreTests.swift +++ b/Tests/SemaphoreTests/SemaphoreTests.swift @@ -61,7 +61,7 @@ final class SemaphoreTests: XCTestCase { do { // Given a task suspended on the semaphore let sem = Semaphore(value: 0) - Task { await sem.wait() } + Task { try await sem.wait() } try await Task.sleep(nanoseconds: delay) // First signal resumes the suspended task @@ -76,17 +76,18 @@ final class SemaphoreTests: XCTestCase { do { // Given a zero semaphore let sem = DispatchSemaphore(value: 0) + + // When a thread waits for this semaphore, let ex1 = expectation(description: "wait") ex1.isInverted = true let ex2 = expectation(description: "woken") - - // When a thread waits for this semaphore, - // Then the thread is initially blocked. Thread { sem.wait() ex1.fulfill() ex2.fulfill() }.start() + + // Then the thread is initially blocked. wait(for: [ex1], timeout: 0.5) // When a signal occurs, then the waiting thread is woken. @@ -98,17 +99,18 @@ final class SemaphoreTests: XCTestCase { do { // Given a zero semaphore let sem = Semaphore(value: 0) + + // When a task waits for this semaphore, let ex1 = expectation(description: "wait") ex1.isInverted = true let ex2 = expectation(description: "woken") - - // When a task waits for this semaphore, - // Then the task is initially suspended. Task { - await sem.wait() + try await sem.wait() ex1.fulfill() ex2.fulfill() } + + // Then the task is initially suspended. wait(for: [ex1], timeout: 0.5) // When a signal occurs, then the suspended task is resumed. @@ -117,6 +119,106 @@ final class SemaphoreTests: XCTestCase { } } + func test_cancellation_while_suspended_throws_CancellationError() async throws { + let sem = Semaphore(value: 0) + let ex = expectation(description: "cancellation") + let task = Task { + do { + try await sem.wait() + XCTFail("Expected CancellationError") + } catch is CancellationError { + } catch { + XCTFail("Unexpected error") + } + ex.fulfill() + } + try await Task.sleep(nanoseconds: 100_000_000) + task.cancel() + wait(for: [ex], timeout: 1) + } + + func test_cancellation_before_suspension_throws_CancellationError() async throws { + let sem = Semaphore(value: 0) + let ex = expectation(description: "cancellation") + let task = Task { + // Uncancellable delay + await withUnsafeContinuation { continuation in + DispatchQueue.main.asyncAfter(deadline: .now() + 0.1) { + continuation.resume() + } + } + do { + try await sem.wait() + XCTFail("Expected CancellationError") + } catch is CancellationError { + } catch { + XCTFail("Unexpected error") + } + ex.fulfill() + } + task.cancel() + wait(for: [ex], timeout: 5) + } + + func test_that_cancellation_while_suspended_increments_the_semaphore() async throws { + // Given a task cancelled while suspended on a semaphore, + let sem = Semaphore(value: 0) + let task = Task { + try await sem.wait() + } + try await Task.sleep(nanoseconds: 100_000_000) + task.cancel() + + // When a task waits for this semaphore, + let ex1 = expectation(description: "wait") + ex1.isInverted = true + let ex2 = expectation(description: "woken") + Task { + try await sem.wait() + ex1.fulfill() + ex2.fulfill() + } + + // Then the task is initially suspended. + wait(for: [ex1], timeout: 0.5) + + // When a signal occurs, then the suspended task is resumed. + sem.signal() + wait(for: [ex2], timeout: 0.5) + } + + func test_that_cancellation_before_suspension_increments_the_semaphore() async throws { + // Given a task cancelled before it waits on a semaphore, + let sem = Semaphore(value: 0) + let task = Task { + // Uncancellable delay + await withUnsafeContinuation { continuation in + DispatchQueue.main.asyncAfter(deadline: .now() + 0.1) { + continuation.resume() + } + } + try await sem.wait() + } + task.cancel() + + // When a task waits for this semaphore, + let ex1 = expectation(description: "wait") + ex1.isInverted = true + let ex2 = expectation(description: "woken") + Task { + try await sem.wait() + ex1.fulfill() + ex2.fulfill() + } + + // Then the task is initially suspended. + wait(for: [ex1], timeout: 0.5) + + // When a signal occurs, then the suspended task is resumed. + sem.signal() + wait(for: [ex2], timeout: 0.5) + } + func test_semaphore_as_a_resource_limiter() async { /// An actor that counts the maximum number of concurrent executions of /// the `run()` method. @@ -137,10 +239,10 @@ final class SemaphoreTests: XCTestCase { let sem = Semaphore(value: count) // Spawn many concurrent tasks - await withTaskGroup(of: Void.self) { group in + await withThrowingTaskGroup(of: Void.self) { group in for _ in 0..<(maxCount * 2) { group.addTask { - await sem.wait() + try await sem.wait() await runner.run() sem.signal() }