How to use withTaskCancellationHandler properly?

I tried using your latest solution, robert.ryan, to adapt Apollo's fetch method. The problem is that the continuation inside the tryFetch method now leaks when I cancel the request. Here is the code:

/// Tracks one Apollo fetch so that it can be started and cancelled from different tasks.
///
/// `cancel()` does not rely on Apollo calling `resultHandler` for a cancelled request: it
/// reports `CancellationError` to the handler itself, so a continuation that is resumed from
/// the handler is not left suspended after cancellation.
actor SendableQueryAdapter {
    var state: State = .ready

    /// Reports `CancellationError` for the in-flight fetch; `nil` while nothing is in flight.
    private var reportCancellation: (@Sendable () -> Void)?

    /// Starts `query` on `client` and remembers the request so that `cancel()` can stop it.
    ///
    /// If `cancel()` has already run, nothing is fetched and `resultHandler` immediately
    /// receives `CancellationError`.
    ///
    /// - Parameters:
    ///   - query: The query to fetch; forwarded to `client.fetch`.
    ///   - client: The client that performs the fetch.
    ///   - cachePolicy: Forwarded to `client.fetch`.
    ///   - contextIdentifier: Forwarded to `client.fetch`.
    ///   - queue: Forwarded to `client.fetch`.
    ///   - resultHandler: Receives the results reported by `client`. Once `cancel()` has
    ///     reported `CancellationError`, later results are dropped; if a result was already
    ///     forwarded, `cancel()` reports nothing.
    func fetch<Query: GraphQLQuery>(
        query: Query,
        client: ApolloClientProtocol,
        cachePolicy: CachePolicy,
        contextIdentifier: UUID? = nil,
        queue: DispatchQueue,
        resultHandler: @Sendable @escaping (Result<GraphQLResult<Query.Data>, Error>) -> Void
    ) {
        if case .cancelled = state {
            resultHandler(.failure(CancellationError()))
            return
        }

        // The callback below is invoked by Apollo rather than through this actor, so a small
        // lock-protected gate decides whether a result or the cancellation wins.
        let gate = DeliveryGate()

        let cancellable = client.fetch(
            query: query,
            cachePolicy: cachePolicy,
            contextIdentifier: contextIdentifier,
            queue: queue,
            resultHandler: { result in
                guard gate.admitResult() else { return }
                resultHandler(result)
            }
        )

        reportCancellation = {
            guard gate.admitCancellation() else { return }
            resultHandler(.failure(CancellationError()))
        }
        state = .executing(cancellable)
    }

    /// Cancels the in-flight request, if any, and reports `CancellationError` to its handler
    /// unless a result was already forwarded. Any later `fetch` fails immediately.
    func cancel() {
        if case .executing(let task) = state {
            task.cancel()
            reportCancellation?()
        }
        reportCancellation = nil
        state = .cancelled
    }

    /// Decides, under a lock, whether a fetch result or the cancellation may be forwarded.
    ///
    /// Results are admitted until cancellation has been admitted; cancellation is admitted only
    /// while no result has been admitted. The mutable `phase` is only touched while `lock` is
    /// held, which is what the `@unchecked Sendable` conformance relies on.
    private final class DeliveryGate: @unchecked Sendable {
        private enum Phase {
            case pending
            case resultDelivered
            case cancelled
        }

        private let lock = NSLock()
        private var phase = Phase.pending

        /// Returns `true` when a result reported by the client may be forwarded.
        func admitResult() -> Bool {
            lock.lock()
            defer { lock.unlock() }
            guard phase != .cancelled else { return false }
            phase = .resultDelivered
            return true
        }

        /// Returns `true` when nothing has been forwarded yet, so cancellation may be reported.
        func admitCancellation() -> Bool {
            lock.lock()
            defer { lock.unlock() }
            guard phase == .pending else { return false }
            phase = .cancelled
            return true
        }
    }
}

extension SendableQueryAdapter {
    /// Life-cycle of the adapter's single fetch.
    enum State {
        /// No fetch has been started and no cancellation has been requested.
        case ready
        /// A fetch was started; the associated value is the handle used to cancel it.
        case executing(Apollo.Cancellable)
        /// `cancel()` has run; a fetch requested afterwards fails with `CancellationError`.
        case cancelled
    }
}

extension ApolloClientProtocol {
    /// Async/await wrapper around the callback-based `fetch`.
    ///
    /// The request is started through a `SendableQueryAdapter`; cancelling the surrounding
    /// task asks that adapter to cancel the request.
    ///
    /// - Parameters:
    ///   - query: The query to fetch.
    ///   - cachePolicy: Forwarded to the underlying fetch.
    ///   - contextIdentifier: Forwarded to the underlying fetch.
    ///   - queue: Forwarded to the underlying fetch.
    /// - Returns: The `data` of the GraphQL result.
    /// - Throws: The first GraphQL error when the result carries no data,
    ///   `URLError(.badServerResponse)` when it carries neither data nor errors, or the error
    ///   reported by the fetch itself.
    func tryFetch<Query: GraphQLQuery>(
        query: Query,
        cachePolicy: CachePolicy,
        contextIdentifier: UUID?,
        queue: DispatchQueue
    ) async throws -> Query.Data {
        let adapter = SendableQueryAdapter()

        return try await withTaskCancellationHandler {
            try await withCheckedThrowingContinuation { continuation in
                // The adapter is an actor, so starting the fetch needs an asynchronous hop.
                Task {
                    await adapter.fetch(
                        query: query,
                        client: self,
                        cachePolicy: cachePolicy,
                        // Must be passed explicitly: the adapter's parameter defaults to `nil`,
                        // which would silently discard the caller's identifier.
                        contextIdentifier: contextIdentifier,
                        queue: queue,
                        resultHandler: { result in
                            switch result {
                            case .success(let graphQLResult):
                                if let data = graphQLResult.data {
                                    continuation.resume(returning: data)
                                } else if let error = graphQLResult.errors?.first {
                                    continuation.resume(throwing: error)
                                } else {
                                    continuation.resume(throwing: URLError(.badServerResponse))
                                }
                            case .failure(let error):
                                continuation.resume(throwing: error)
                            }
                        }
                    )
                }
            }
        } onCancel: {
            // The handler is synchronous, so the actor call is made from an unstructured task.
            Task { await adapter.cancel() }
        }
    }
}

Am I missing something?

Yes, it seems that you are right.
I just wrote a unit test and a very similar implementation, but for adapting Combine's AnyPublisher to Swift Concurrency.

/// Exercises cancellation of a task that awaits `AnyPublisher.async()`.
final class AnyPublisherExtensionsTests: XCTestCase {
    /// Cancels, after 0.1 s, a task awaiting a publisher that only emits after 0.2 s, and
    /// expects the task to fail with `CancellationError`.
    func test_cancel() async {
        let slowPublisher = makePublisher(result: .success(10), delayed: 0.2)
        let awaitingTask = Task { try await slowPublisher.async() }

        DispatchQueue.main.asyncAfter(deadline: .now() + 0.1) { awaitingTask.cancel() }

        switch await awaitingTask.result {
        case .success:
            XCTFail("Should throw CancellationError")
        case .failure(let error):
            XCTAssert(error is CancellationError)
        }
    }

    // MARK: Helpers

    /// Builds a publisher that delivers `result` on the main queue `delayed` seconds after it
    /// is subscribed to.
    func makePublisher(result: Result<Int, Error>, delayed: TimeInterval) -> AnyPublisher<Int, Error> {
        let deferredFuture = Deferred {
            Future<Int, Error> { fulfill in
                DispatchQueue.main.asyncAfter(deadline: .now() + delayed) { fulfill(result) }
            }
        }
        return deferredFuture.eraseToAnyPublisher()
    }
}

Implementation:

extension AnyPublisher {
    /// Awaits the result that `CancellableTask` reports for this publisher.
    ///
    /// The numbered `debugPrint` calls trace the order in which the individual closures run.
    func async() async throws -> Output {
        let bridge = CancellableTask()

        debugPrint("#1")
        return try await withTaskCancellationHandler(
            operation: {
                debugPrint("#2")
                return try await withCheckedThrowingContinuation { continuation in
                    debugPrint("#3")
                    // `CancellableTask` is an actor, so it is started from an unstructured task.
                    Task {
                        debugPrint("#4")
                        await bridge.start(on: self, completionHandler: { outcome in
                            debugPrint("#5")
                            continuation.resume(with: outcome)
                        })
                    }
                }
            },
            onCancel: {
                debugPrint("#6")
                // The handler is synchronous; the actor call needs its own task.
                Task {
                    debugPrint("#7")
                    await bridge.cancel()
                }
            }
        )
    }
}

private extension AnyPublisher {
    /// Subscribes to a publisher on behalf of `async()` and allows that subscription to be
    /// cancelled. The numbered `debugPrint` calls trace which paths are taken.
    actor CancellableTask {
        var state: State = .ready

        /// Subscribes to the first value of `publisher` and reports it, or a failure, through
        /// `completionHandler`. If `cancel()` already ran, reports `CancellationError` instead
        /// and does not subscribe.
        func start(on publisher: AnyPublisher, completionHandler: @Sendable @escaping (Result<Output, Error>) -> Void) {
            debugPrint("1")
            if case .cancelled = state {
                debugPrint("2")
                completionHandler(.failure(CancellationError()))
                return
            }

            // Declared before the subscription so that the sink's completion closure can refer
            // to it.
            var cancellable: AnyCancellable?
            cancellable = publisher
                .first()
                .sink { result in
                    switch result {
                    case .finished:
                        // NOTE(review): `completionHandler` is not called on this path; a
                        // publisher that finishes without emitting a value looks like it would
                        // leave the caller waiting — confirm.
                        debugPrint("3")
                        break
                    case let .failure(error):
                        debugPrint("4")
                        completionHandler(.failure(error))
                    }
                    debugPrint("5")
                    cancellable?.cancel()
                } receiveValue: { value in
                    debugPrint("6")
                    completionHandler(.success(value))
                }

            if let cancellable {
                debugPrint("7")
                // The closure stored next to the subscription lets `cancel()` report
                // `CancellationError` to the caller.
                state = .executing(cancellable, { completionHandler(.failure(CancellationError())) })
            }
        }

        /// Cancels the subscription, if one is active, reports the cancellation through the
        /// stored closure, and marks the task as cancelled.
        func cancel() {
            debugPrint("8")
            if case .executing(let cancellable, let cancelCompletion) = state {
                debugPrint("9")
                cancellable.cancel()
                // NOTE(review): `state` stays `.executing` after the sink has reported a value
                // or failure, so this may call `completionHandler` a second time if
                // cancellation arrives afterwards — confirm.
                cancelCompletion() // <-- If you comment this out, the continuation will leak
            }
            debugPrint("10")
            state = .cancelled
        }
    }

    /// Life-cycle of a `CancellableTask`.
    enum State {
        /// Not started and not cancelled.
        case ready
        /// Subscribed; carries the subscription and a closure that reports cancellation.
        case executing(AnyCancellable, () -> Void)
        /// `cancel()` has run.
        case cancelled
    }
}

As you can see, I have added a closure to @robert.ryan's solution, in case executing(AnyCancellable, () -> Void), to quickly verify whether the problem is that, once you cancel, the cancellation is never propagated back to the continuation.

Based on the prints that I added, if you comment out the line cancelCompletion(), the leak appears:

Test Case '-[AnyPublisherExtensionsTests test_cancel]' started.
"#1"
"#2"
"#3"
"#4"
"1"
"7"
"#6"
"#7"
"8"
"9"
"10"
SWIFT TASK CONTINUATION MISUSE: async() leaked its continuation!

But if cancelCompletion() is called to propagate back the cancellation to the continuation, no more leaking is observed.

Test Case '-[AnyPublisherExtensionsTests test_cancel]' started.
"#1"
"#2"
"#3"
"#4"
"1"
"7"
"#6"
"#7"
"8"
"9"
"#5"
"10"

It shows that the withTaskCancellationHandler API is very hard to get right. I hope it can somehow be improved at the Foundation or Swift level.

3 Likes

I have found how withTaskCancellationHandler is used in Foundation for the back-ported async/await URLSession.data (/Applications/Xcode.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/System/Library/Frameworks/Foundation.framework/Modules/Foundation.swiftmodule/arm64e-apple-ios.swiftinterface).

// Helpers that keep a lock-protected `(isCancelled, task)` pair shared between a request and
// its cancellation handler.
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
extension Foundation.URLSession {
  // Creates the shared state: the buffer's header starts as `(isCancelled: false, task: nil)`
  // and its element storage is initialised with a fresh `os_unfair_lock`.
  @_alwaysEmitIntoClient private func makeState() -> Swift.ManagedBuffer<(isCancelled: Swift.Bool, task: Foundation.URLSessionTask?), Darwin.os_unfair_lock> {
        ManagedBuffer<(isCancelled: Bool, task: URLSessionTask?), os_unfair_lock>.create(minimumCapacity: 1) { buffer in
            buffer.withUnsafeMutablePointerToElements { $0.initialize(to: os_unfair_lock()) }
            return (isCancelled: false, task: nil)
        }
    }
  // Marks the state as cancelled and cancels the stored task, if there is one.
  @_alwaysEmitIntoClient private func cancel(state: Swift.ManagedBuffer<(isCancelled: Swift.Bool, task: Foundation.URLSessionTask?), Darwin.os_unfair_lock>) {
        state.withUnsafeMutablePointers { state, lock in
            os_unfair_lock_lock(lock)
            let task = state.pointee.task
            state.pointee = (isCancelled: true, task: nil)
            os_unfair_lock_unlock(lock)
            // The task is cancelled only after the lock has been released.
            task?.cancel()
        }
    }
  // Stores `task` in the state, or cancels it right away when cancellation was already
  // recorded. Traps if a task has been stored before.
  @_alwaysEmitIntoClient private func activate(state: Swift.ManagedBuffer<(isCancelled: Swift.Bool, task: Foundation.URLSessionTask?), Darwin.os_unfair_lock>, task: Foundation.URLSessionTask) {
        state.withUnsafeMutablePointers { state, lock in
            os_unfair_lock_lock(lock)
            if state.pointee.task != nil {
                fatalError("Cannot activate twice")
            }
            if state.pointee.isCancelled {
                // Cancellation arrived first: release the lock, then cancel the new task.
                os_unfair_lock_unlock(lock)
                task.cancel()
            } else {
                state.pointee = (isCancelled: false, task: task)
                os_unfair_lock_unlock(lock)
            }
        }
    }
}
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
extension Foundation.URLSession {
  #if compiler(>=5.3) && $AsyncAwait
  // Async variant of `dataTask(with:)`. On macOS 12 / iOS 15 / watchOS 8 / tvOS 15 and later
  // it forwards to `data(for:delegate:)`; otherwise it bridges the completion handler through
  // a checked continuation and wires task cancellation to the data task.
  @_alwaysEmitIntoClient @_disfavoredOverload public func data(for request: Foundation.URLRequest) async throws -> (Foundation.Data, Foundation.URLResponse) {
        if #available(macOS 12.0, iOS 15.0, watchOS 8.0, tvOS 15.0, *) {
            return try await data(for: request, delegate: nil)
        }
        let cancelState = makeState()
        return try await withTaskCancellationHandler {
            // The handler calls `cancel(state:)` directly; no unstructured task is created.
            cancel(state: cancelState)
        } operation: {
            try await withCheckedThrowingContinuation { continuation in
                // NOTE(review): presumably a cancelled data task reports an error through this
                // completion handler, which is what resumes the continuation after
                // cancellation — confirm against the URLSession documentation.
                let task = dataTask(with: request) { data, response, error in
                    if let error = error {
                        continuation.resume(throwing: error)
                    } else {
                        continuation.resume(returning: (data!, response!))
                    }
                }
                task.resume()
                // Registers the task for cancellation; if cancellation was already recorded,
                // `activate` cancels the task immediately.
                activate(state: cancelState, task: task)
            }
        }
    }
  #endif
  ...
}
5 Likes

What does this look like to people here? One goal is to cut down on thread context switching when wrapping any cancellable API that is not async/await-based.

import Foundation

/// Pairs a `CheckedContinuation` with the cancel closure returned by the wrapped operation, so
/// that task cancellation can both cancel the operation and resume the continuation.
///
/// `state` is only read and written while `lock` is held; the cancel closure and the
/// continuation are always invoked after the lock has been released.
private class CancellableCheckedThrowingContinuation<T> {

    private enum State {
        /// Neither started nor cancelled.
        case ready
        /// `body` has run; holds what is needed to cancel it and to resume the caller.
        case executing(cancel: () -> Void, continuation: CheckedContinuation<T, Error>)
        /// The cancellation handler has run.
        case cancelled
    }

    private var state: State = .ready
    // Heap-allocated so the lock has a stable address for the `os_unfair_lock_*` calls.
    private var lock: UnsafeMutablePointer<os_unfair_lock>

    init() {
        lock = UnsafeMutablePointer<os_unfair_lock>.allocate(capacity: 1)
        lock.initialize(to: os_unfair_lock())
    }

    deinit {
        lock.deallocate()
    }

    /// Runs `body` with a fresh continuation and suspends until that continuation is resumed.
    ///
    /// - Parameter body: Starts the work and returns a closure that cancels it. `body` is
    ///   expected to arrange for the continuation to be resumed when the work completes.
    /// - Returns: The value the continuation is resumed with.
    /// - Throws: `CancellationError` when the task is cancelled, or the error the continuation
    ///   is resumed with.
    func operation(_ body: @escaping (CheckedContinuation<T, Error>) -> () -> Void) async throws -> T {
        try Task.checkCancellation()
        return try await withTaskCancellationHandler(operation: {
            try Task.checkCancellation()
            return try await withCheckedThrowingContinuation({ continuation in
                // `body` runs before the lock is taken, so the work is started even when the
                // cancellation handler has already run.
                let cancel = body(continuation)
                os_unfair_lock_lock(lock)
                if case .ready = state {
                    state = .executing(cancel: cancel, continuation: continuation)
                    os_unfair_lock_unlock(lock)
                } else {
                    // Cancellation was recorded first: undo the work that was just started.
                    os_unfair_lock_unlock(lock)
                    cancel()
                    continuation.resume(throwing: CancellationError())
                }
            })
        }, onCancel: {
            os_unfair_lock_lock(lock)
            if case .executing(let cancel, let continuation) = state {
                state = .cancelled
                os_unfair_lock_unlock(lock)
                cancel()
                // NOTE(review): `body` received the same continuation and nothing here records
                // whether its callback has already resumed it; if both happen, the
                // continuation looks like it would be resumed twice — confirm.
                continuation.resume(throwing: CancellationError())
            } else {
                state = .cancelled
                os_unfair_lock_unlock(lock)
            }
        })
    }
}

/// Suspends until the continuation handed to `body` is resumed, cancelling the wrapped work when
/// the surrounding task is cancelled.
///
/// - Parameter body: Starts the work with the given continuation and returns a closure that
///   cancels that work.
/// - Returns: The value the continuation is resumed with.
/// - Throws: `CancellationError` when the task is already cancelled or becomes cancelled, or the
///   error the continuation is resumed with.
public func withCancellableCheckedThrowingContinuation<T>(_ body: @escaping (CheckedContinuation<T, Error>) -> () -> Void) async throws -> T {
    guard !Task.isCancelled else {
        throw CancellationError()
    }
    let bridge = CancellableCheckedThrowingContinuation<T>()
    return try await bridge.operation(body)
}

It seems to be working, but I question a few things.

  1. Is that many calls to try Task.checkCancellation() actually worth anything?
  2. Could I actually safely invoke the body function within the lock?
    Something like this?
// Variant of the registration step: `body(continuation)` is evaluated while the lock is held,
// and only when the state is still `.ready`.
// NOTE(review): this keeps the lock held for as long as the caller-supplied `body` runs; a
// cancellation handler that takes the same lock would have to wait until `body` returns —
// confirm that this is acceptable.
os_unfair_lock_lock(lock)
if case .ready = state {
    state = .executing(cancel: body(continuation), continuation: continuation)
    os_unfair_lock_unlock(lock)
} else {
    // Cancellation was recorded first: `body` is never invoked, so there is nothing to cancel.
    os_unfair_lock_unlock(lock)
    continuation.resume(throwing: CancellationError())
}

Usage:

/// Placeholder for a callback-based, cancellable API; its bodies are elided.
class Something {
    /// Starts the (elided) work. `completion` is presumably called with the outcome once the
    /// work finishes.
    init(_ completion: (Result<Data, Error>) -> Void) {
        // ... non-async/await stuff
    }
    /// Cancels the (elided) work.
    func cancel() {
        // ... non-async/await stuff
    }
}

/// Awaits the outcome reported by `Something`, handing its `cancel` method to the wrapper as the
/// cancel closure.
///
/// - Returns: The data `Something` reports on success.
/// - Throws: The error `Something` reports on failure, or whatever the wrapper throws.
public func doSomething() async throws -> Data {
    try await withCancellableCheckedThrowingContinuation { continuation in
        let operation = Something { outcome in
            // Success resumes with the data, failure resumes by throwing the error.
            continuation.resume(with: outcome)
        }
        return { operation.cancel() }
    }
}

There was a (I think) very interesting bit of information about using withTaskCancellationHandler in the WWDC23 session, Beyond the basics of structured concurrency:

We do this by synchronously calling the "cancel" function on our sequence state machine.

Note that because the cancellation handler runs immediately, the state machine is shared mutable state between the cancellation handler and main body, which can run concurrently. We'll need to protect our state machine.

While actors are great for protecting encapsulated state, we want to modify and read individual properties on our state machine, so actors aren't quite the right tool for this.

Furthermore, we can't guarantee the order that operations run on an actor, so we can't ensure that our cancellation will run first. We'll need something else. I've decided to use atomics from the Swift Atomics package, but we could use a dispatch queue or locks.

These mechanisms allow us to synchronize the shared state, avoiding race conditions, while allowing us to cancel the running state machine without introducing an unstructured task in the cancellation handler.

6 Likes

Yes I saw that and thought it was unfortunate. Dropping to another solution outside swift concurrency to write correct code feels like a failure.

5 Likes

Perhaps, but I am still grateful for the clarity on correct use of the API. I wish the docs had contained that explanation all along.

7 Likes