Skip to content

Instantly share code, notes, and snippets.

@sunghyun-k
Created October 20, 2024 11:09
Show Gist options
  • Select an option

  • Save sunghyun-k/551aeafa87b64d9acc91f31e8450772f to your computer and use it in GitHub Desktop.

Select an option

Save sunghyun-k/551aeafa87b64d9acc91f31e8450772f to your computer and use it in GitHub Desktop.
import Foundation
/// Deduplicating, cancellable token provider.
///
/// Concurrent `getToken()` callers share a single in-flight `fetch()`; once a token is
/// obtained it is cached and returned directly to later callers. Each caller can be
/// cancelled independently, and the shared fetch is cancelled only when every waiting
/// caller has gone away.
///
/// All mutable state lives in `state` and is only read or written inside `lock.withLock`.
/// That is what makes the `@unchecked Sendable` conformance sound: the instance is captured
/// by the spawned fetch task and by cancellation handlers that may run on other threads.
class AuthImpl: @unchecked Sendable {
    /// Outcome delivered to callers. Used instead of throwing `CancellationError` so that
    /// cancellation is easy to distinguish with a plain pattern match.
    enum TaskResult {
        case response(String)
        case cancelled
    }

    typealias Continuation = CheckedContinuation<TaskResult, Never>

    /// Lifecycle of the shared token request (`state == nil` means idle).
    enum RequestState {
        /// A token has been fetched and is served without further work.
        case cached(String)
        /// A fetch is in flight. The dictionary holds one continuation per waiting caller,
        /// keyed by a per-call ID so a cancelled caller can remove only its own entry.
        case loading(
            Task<Void, Never>,
            [UUID: Continuation] // Passing by reference would be better; kept by value for simplicity
        )
    }

    /// `nil` means idle: nothing cached and nothing in flight. Guarded by `lock`.
    private var state: RequestState?

    // Every access to `state` goes through `lock.withLock`, and the lock is never held across
    // an `await`, so it is always released on the thread that acquired it and a cancellation
    // handler that fires synchronously can never find it already held by its own thread.
    private let lock = NSLock()

    /// Returns the cached token, or waits for the shared in-flight fetch (starting one if idle).
    ///
    /// - Returns: `.response(token)` on success; `.cancelled` if the calling task is cancelled
    ///   before a token is delivered, or if the shared fetch finishes without a token.
    func getToken() async -> TaskResult {
        if Task.isCancelled {
            return .cancelled
        }
        // Fast path: no continuation needed when the token is already cached.
        if case .cached(let token) = lock.withLock({ state }) {
            return .response(token)
        }
        let id = UUID()
        return await withTaskCancellationHandler {
            await withCheckedContinuation { continuation in
                enqueue(continuation, id: id)
            }
        } onCancel: {
            // May run before `enqueue` (immediately, if the task is already cancelled on entry)
            // or concurrently with it; `enqueue` re-checks cancellation under the lock.
            removeContinuation(id: id)
        }
    }

    /// Atomically serves or registers one waiting caller.
    ///
    /// `continuation` is resumed exactly once: here, by `removeContinuation`, or by the
    /// completion of the task created in `startFetch` — each removes it from `state` (or
    /// never stores it) under the lock before resuming.
    private func enqueue(_ continuation: Continuation, id: UUID) {
        lock.withLock {
            switch state {
            case .cached(let token):
                // The token arrived between the fast-path check and now.
                continuation.resume(returning: .response(token))
            case _ where Task.isCancelled:
                // The cancellation handler may already have run and found nothing to remove;
                // registering now would leave this continuation suspended forever.
                continuation.resume(returning: .cancelled)
            case .loading(let task, var continuations):
                continuations[id] = continuation
                state = .loading(task, continuations)
            case nil:
                state = .loading(startFetch(), [id: continuation])
            }
        }
    }

    /// Starts the shared fetch and, when it completes, hands its result to every waiting caller.
    ///
    /// Called with `lock` held; the returned task is stored in `state` before the lock is
    /// released, so the completion (which also takes the lock) always observes it.
    private func startFetch() -> Task<Void, Never> {
        Task {
            let result = await fetch()
            lock.withLock {
                // `removeContinuation` cancels this task only after resuming every waiter and
                // resetting `state`, both under the lock. A cancelled fetch therefore has nobody
                // to notify, and `state` may already belong to a newer request: leave it alone.
                guard !Task.isCancelled, case .loading(_, let continuations) = state else {
                    return
                }
                if case .response(let token) = result {
                    state = .cached(token)
                } else {
                    // No token without being cancelled (e.g. a real fetch that failed): go back to
                    // idle so a later call can retry, instead of leaving these waiters suspended.
                    state = nil
                }
                for continuation in continuations.values {
                    continuation.resume(returning: result)
                }
            }
        }
    }

    /// Cancels one waiting caller, and the shared fetch if that caller was the last one.
    ///
    /// Invoked from the cancellation handler, which can fire after the caller was already
    /// resumed (e.g. `state` is now `.cached`) or before it was registered, so a missing
    /// entry is expected and ignored.
    private func removeContinuation(id: UUID) {
        lock.withLock {
            guard case .loading(let task, var continuations) = state else { return }
            guard let continuation = continuations.removeValue(forKey: id) else { return }
            continuation.resume(returning: .cancelled)
            // Cancel the shared fetch once nobody is waiting for it.
            if continuations.isEmpty {
                task.cancel()
                state = nil
            } else {
                state = .loading(task, continuations)
            }
        }
    }

    /// Stand-in for the real token request.
    ///
    /// - Returns: `.response("Sample Value")` after about 3 seconds, or `.cancelled` if
    ///   `Task.sleep` throws (it throws when the task is cancelled).
    private func fetch() async -> TaskResult {
        do {
            // Substitute for actual work
            try await Task.sleep(for: .seconds(3))
            return .response("Sample Value")
        } catch {
            return .cancelled
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment