While researching how Swift Concurrency could replace Combine for a lot of use cases one case I was interested in was the ability to share a single task across multiple awaiters. I came up with a possible solution but being new to Swift Concurrency I'm worried about gotchas (threading is hard). Can anyone find any downsides or tell me if this is just a terrible idea in general?
actor SharedTaskWrapper {
private var sharedTask: Task<Result, Error>?
private var continuations: [CheckedContinuation<Result, Error>] = []
func fetch(request: Request) async throws -> Result {
let unsharedTask: Task<Result, Error> = Task {
try await withCheckedThrowingContinuation { continuation in
continuations.append(continuation)
if sharedTask == nil {
sharedTask = Task {
do {
let result = try await actuallyFetch(request: request)
let notifyContinuations = continuations
continuations.removeAll(keepingCapacity: true)
sharedTask = nil
notifyContinuations.forEach({ $0.resume(returning: result) })
return result
}
catch {
let notifyContinuations = continuations
continuations.removeAll(keepingCapacity: true)
sharedTask = nil
notifyContinuations.forEach({ $0.resume(throwing: error) })
throw error
}
}
}
}
}
return try await unsharedTask.value
}
private actuallyFetch(request: Request) async throws -> Result {
// Do async work here
}
}
Second try with hopefully more safety (still not sure it doesn't have issues):
protocol DataFetcher {
associatedtype Request : Hashable
associatedtype Response
func fetch(request: Request) async throws -> Response
}
extension DataFetcher {
func share() -> SharedDataFetcher<Self> {
return SharedDataFetcher(dataFetcher: self)
}
}
actor SharedDataFetcher<SharedFetcher : DataFetcher> : DataFetcher {
typealias Request = SharedFetcher.Request
typealias Response = SharedFetcher.Response
typealias FetchID = UInt
actor SharedState {
var sharedTask: Task<Response, Error>?
var continuations: [FetchID : CheckedContinuation<Response, Error>]
init() {
self.sharedTask = nil
self.continuations = [FetchID : CheckedContinuation<Response, Error>](minimumCapacity: 4)
}
func join(sharedFetcher: SharedFetcher, request: Request, fetchID: FetchID, continuation: CheckedContinuation<Response, Error>) {
continuations[fetchID] = continuation
if sharedTask == nil {
sharedTask = Task {
do {
let response = try await sharedFetcher.fetch(request: request)
success(response: response)
return response
}
catch {
failure(error: error)
throw error
}
}
}
}
func success(response: Response) {
continuations.values.forEach({ $0.resume(returning: response) })
continuations.removeAll()
sharedTask = nil
}
func failure(error: Error) {
continuations.values.forEach({ $0.resume(throwing: error) })
continuations.removeAll()
sharedTask = nil
}
func cancel(fetchID: FetchID) {
continuations[fetchID]?.resume(throwing: CancellationError())
continuations.removeValue(forKey: fetchID)
if continuations.isEmpty {
sharedTask?.cancel()
sharedTask = nil
}
}
func complete(_ execute: () -> Void) {
guard sharedTask == nil else {
return
}
execute()
}
}
private var nextFetchID: FetchID
private var requestSharedState: [Request : SharedState]
private let sharedFetcher: SharedFetcher
init(dataFetcher: SharedFetcher) {
self.nextFetchID = 0
self.requestSharedState = [Request : SharedState](minimumCapacity: 128)
self.sharedFetcher = dataFetcher
}
func fetch(request: Request) async throws -> Response {
let sharedState: SharedState
if let preexistingSharedState = requestSharedState[request] {
sharedState = preexistingSharedState
}
else {
sharedState = SharedState()
requestSharedState[request] = sharedState
}
let fetchID = nextFetchID
nextFetchID = (nextFetchID == .max) ? 0 : nextFetchID + 1
let response = try await withTaskCancellationHandler {
try await withCheckedThrowingContinuation { continuation in
Task {
await sharedState.join(sharedFetcher: sharedFetcher, request: request, fetchID: fetchID, continuation: continuation)
}
}
} onCancel: {
Task {
await sharedState.cancel(fetchID: fetchID)
}
}
await requestSharedState[request]!.complete {
precondition(sharedState === requestSharedState[request]!)
requestSharedState.removeValue(forKey: request)
}
return response
}
}