Skip to content

Instantly share code, notes, and snippets.

@rvsrvs
Last active May 27, 2022 17:23
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rvsrvs/74a50d73bb647a721654cb378be10198 to your computer and use it in GitHub Desktop.
Save rvsrvs/74a50d73bb647a721654cb378be10198 to your computer and use it in GitHub Desktop.
CheckedExpectation for Swift concurrency
//
// CheckedExpectation+Timeout.swift
//
//
// Created by Van Simmons on 2/20/22.
//
/// Waits for a single `Void` expectation to be completed.
///
/// - Parameters:
///   - expectation: The expectation to await.
///   - timeout: Upper bound on the wait, in nanoseconds (`UInt64.max` by default).
/// - Throws: Any error raised by the sequence-based `wait(for:timeout:reducing:with:)`,
///   for example the timeout error when `timeout` elapses before completion.
public func wait(
    for expectation: CheckedExpectation<Void>,
    timeout: UInt64 = .max
) async throws -> Void {
    // Nothing to accumulate for a Void expectation, so the reducer ignores its inputs.
    let discardPartial: (inout Void, Void) throws -> Void = { _, _ in }
    let single: [CheckedExpectation<Void>] = [expectation]
    try await wait(for: single, timeout: timeout, reducing: (), with: discardPartial)
}
public extension CheckedExpectation where Arg == Void {
    /// Awaits this `Void` expectation, giving up after `timeout` nanoseconds.
    ///
    /// - Parameter timeout: Upper bound on the wait in nanoseconds; `UInt64.max` by default.
    /// - Throws: Any error raised while waiting, including the timeout error.
    func timeout(
        after timeout: UInt64 = .max
    ) async throws -> Void {
        // Wrap `self` in a one-element array and discard the (Void) partial results.
        try await wait(for: [self], timeout: timeout, reducing: (), with: { _, _ in })
    }
}
/// Waits for a single expectation and folds its value into `initialValue` with `reducer`.
///
/// - Parameters:
///   - expectation: The expectation to await.
///   - timeout: Upper bound on the wait, in nanoseconds (`UInt64.max` by default).
///   - initialValue: Starting accumulator value.
///   - reducer: Folds the expectation's value into the accumulator.
/// - Returns: The accumulator after the expectation's value has been reduced into it.
/// - Throws: Any error from `reducer`, or the timeout error if `timeout` elapses first.
public func wait<FinalResult, PartialResult>(
    for expectation: CheckedExpectation<PartialResult>,
    timeout: UInt64 = .max,
    reducing initialValue: FinalResult,
    with reducer: @escaping (inout FinalResult, PartialResult) throws -> Void
) async throws -> FinalResult {
    let single: [CheckedExpectation<PartialResult>] = [expectation]
    return try await wait(for: single, timeout: timeout, reducing: initialValue, with: reducer)
}
public extension CheckedExpectation {
    /// Awaits this expectation and folds its value into `initialValue` using `reducer`.
    ///
    /// - Parameters:
    ///   - timeout: Upper bound on the wait, in nanoseconds (`UInt64.max` by default).
    ///   - initialValue: Starting accumulator value.
    ///   - reducer: Folds this expectation's value into the accumulator.
    /// - Returns: The reduced accumulator.
    /// - Throws: Any error from `reducer`, or the timeout error if `timeout` elapses first.
    func timeout<FinalResult>(
        after timeout: UInt64 = .max,
        reducing initialValue: FinalResult,
        with reducer: @escaping (inout FinalResult, Arg) throws -> Void
    ) async throws -> FinalResult {
        let pending: [CheckedExpectation<Arg>] = [self]
        return try await wait(for: pending, timeout: timeout, reducing: initialValue, with: reducer)
    }
}
public extension Array {
    /// Waits for every expectation in this array, folding each delivered value into
    /// `initialValue` with `reducer`, subject to an overall timeout.
    ///
    /// - Parameters:
    ///   - timeout: Upper bound on the whole wait, in nanoseconds (`UInt64.max` by default).
    ///   - initialValue: Starting accumulator value.
    ///   - reducer: Folds each expectation's value into the accumulator as it is delivered.
    /// - Returns: The accumulator once all expectations have been reduced.
    /// - Throws: Any error from `reducer`, or the timeout error if `timeout` elapses first.
    func timeout<FinalResult, PartialResult>(
        after timeout: UInt64 = .max,
        reducing initialValue: FinalResult,
        with reducer: @escaping (inout FinalResult, PartialResult) throws -> Void
    ) async throws -> FinalResult where Element == CheckedExpectation<PartialResult> {
        let expectations = self
        return try await wait(for: expectations, timeout: timeout, reducing: initialValue, with: reducer)
    }
}
/// Waits for every expectation in `expectations`, folding each delivered value into
/// `initialValue` with `reducer`, subject to an overall timeout.
///
/// The work is driven by a `StateTask` over `WaitState`. Each expectation's completion is
/// delivered as an action and reduced into the accumulator. A watchdog raises a timeout
/// after `timeout` nanoseconds.
///
/// - Parameters:
///   - expectations: The expectations to await. The sequence is iterated exactly once.
///   - timeout: Upper bound on the whole wait, in nanoseconds (`UInt64.max` by default).
///   - initialValue: Starting accumulator value. It is returned unchanged when
///     `expectations` is empty.
///   - reducer: Folds each expectation's value into the accumulator.
/// - Returns: The accumulator once every expectation has been reduced.
/// - Throws: Any error from `reducer`, or the timeout error produced by `WaitState`
///   if `timeout` elapses first.
public func wait<FinalResult, PartialResult, S: Sequence>(
    for expectations: S,
    timeout: UInt64 = .max,
    reducing initialValue: FinalResult,
    with reducer: @escaping (inout FinalResult, PartialResult) throws -> Void
) async throws -> FinalResult where S.Element == CheckedExpectation<PartialResult> {
    typealias Machine = WaitState<FinalResult, PartialResult>

    // Materialize once. `S` may be single-pass, and its `underestimatedCount` may be 0
    // (e.g. lazy sequences). A 0 count would shrink the buffer below to one slot, so
    // completions could be dropped and the wait would stall until the timeout.
    let pending = Array(expectations)

    // Nothing to wait for: the condition "all expectations completed" already holds.
    // Without this the state machine would never finish and could only time out.
    guard !pending.isEmpty else { return initialValue }

    let reducingTask = Task<FinalResult, Error>.init {
        let stateTask = await StateTask<Machine, Machine.Action>.stateTask(
            initialState: { channel in
                .init(with: channel, for: pending, timeout: timeout, reducer: reducer, initialValue: initialValue)
            },
            // Each expectation yields at most one `.complete`, plus one `.timeout`.
            // Doubling leaves headroom.
            buffering: .bufferingOldest(pending.count * 2 + 1),
            reducer: Reducer(reducer: Machine.reduce)
        )
        return try await stateTask.finalState.finalResult
    }
    return try await reducingTask.value
}
//
// CheckedExpectation.swift
//
//
// Created by Van Simmons on 1/28/22.
//
/// A one-shot expectation that asynchronous code can await.
///
/// Other code resolves it exactly once, by completing it with a value, failing it with an
/// error, or cancelling it.
///
/// On construction an unstructured `Task` is started that suspends on an
/// `UnsafeContinuation`. That continuation is handed to the `State` actor, which resumes it
/// at most once, guarded by `State.status`. `value()` and `result()` await that task.
public class CheckedExpectation<Arg> {
    /// Errors produced by `CheckedExpectation` and by the timeout helpers that wait on it.
    ///
    /// This enum is nested in a generic class, so `CheckedExpectation<A>.Error` and
    /// `CheckedExpectation<B>.Error` are distinct types.
    public enum Error: Swift.Error, Equatable {
        /// A resolution was attempted on an expectation that was already cancelled.
        case alreadyCancelled
        /// A resolution was attempted on an expectation that was already completed.
        case alreadyCompleted
        /// The error delivered to awaiters when the expectation is cancelled.
        case cancelled
        /// A resolution was attempted while the expectation was not resumable
        /// (for example after it has already failed).
        case inconsistentState
        /// Not thrown by this class. Used by the timeout helpers that wait on expectations.
        case timedOut
    }

    /// Lifecycle of an expectation.
    public enum Status: Equatable {
        case cancelled
        case completed
        case waiting
        case failed
    }

    /// Serializes resolution so the stored continuation is resumed at most once.
    actor State {
        private(set) var status: Status
        private(set) var resumption: UnsafeContinuation<Arg, Swift.Error>?

        init(status: Status = .waiting) {
            self.status = status
        }

        /// Installs the continuation that the expectation's task is suspended on.
        fileprivate func set(resumption: UnsafeContinuation<Arg, Swift.Error>) {
            status = .waiting
            self.resumption = resumption
        }

        /// Returns the pending continuation, or throws if it must not be resumed again.
        private func validateState() throws -> UnsafeContinuation<Arg, Swift.Error> {
            guard status != .completed else { throw Error.alreadyCompleted }
            guard status != .cancelled else { throw Error.alreadyCancelled }
            guard let resumption = resumption, status == .waiting else { throw Error.inconsistentState }
            return resumption
        }

        /// Resumes the awaiting task with `Error.cancelled`.
        func cancel() throws {
            let resumption = try validateState()
            status = .cancelled
            resumption.resume(throwing: Error.cancelled)
        }

        /// Resumes the awaiting task with `arg`.
        func complete(_ arg: Arg) throws {
            let resumption = try validateState()
            status = .completed
            resumption.resume(returning: arg)
        }

        /// Resumes the awaiting task by throwing `error`.
        func fail(_ error: Swift.Error) throws {
            let resumption = try validateState()
            status = .failed
            resumption.resume(throwing: error)
        }
    }

    private let task: Task<Arg, Swift.Error>
    private let state: State

    /// Creates a waiting expectation.
    ///
    /// Returns only after the backing task has suspended and its continuation has been
    /// stored in `state`. That means `complete`, `fail` and `cancel` are safe to call as
    /// soon as the initializer returns.
    ///
    /// - Parameter name: Currently unused by this class.
    public init(name: String = "") async {
        let localState = State()
        // Assigned inside `withCheckedContinuation`'s body, which runs synchronously before
        // the continuation is awaited, so it is non-nil by the time it is read below.
        var localTask: Task<Arg, Swift.Error>!
        let resumption: UnsafeContinuation<Arg, Swift.Error> = await withCheckedContinuation { cc in
            localTask = Task<Arg, Swift.Error> {
                // Cancelling the task resolves the expectation as cancelled. The handler must
                // not suspend, so the actor call runs in a separate Task. Its error, if the
                // expectation is already resolved, is intentionally ignored.
                try await withTaskCancellationHandler(handler: { Task { try await localState.cancel() } }) {
                    // Hand our continuation back to `init`, then suspend until it is resumed.
                    try await withUnsafeThrowingContinuation { continuation in
                        cc.resume(returning: continuation)
                    }
                }
            }
        }
        await localState.set(resumption: resumption)
        task = localTask
        state = localState
    }

    /// Cancels the backing task when the expectation is released.
    deinit { cancel() }

    /// Whether cancellation has been requested on the backing task.
    public var isCancelled: Bool { task.isCancelled }

    /// Requests cancellation. Awaiters then receive `Error.cancelled`, unless the
    /// expectation was already resolved.
    public func cancel() -> Void { task.cancel() }

    /// The current lifecycle status.
    public func status() async -> Status { await state.status }

    /// Completes the expectation with `arg`.
    /// - Throws: `Error.alreadyCompleted`, `Error.alreadyCancelled` or
    ///   `Error.inconsistentState` if the expectation is no longer waiting.
    public func complete(_ arg: Arg) async throws -> Void { try await state.complete(arg) }

    /// Fails the expectation with one of this type's own `Error` cases.
    ///
    /// Kept as a separate overload so implicit-member calls such as `fail(.timedOut)` keep
    /// compiling.
    /// - Throws: `Error.alreadyCompleted`, `Error.alreadyCancelled` or
    ///   `Error.inconsistentState` if the expectation is no longer waiting.
    public func fail(_ error: Error) async throws -> Void { try await state.fail(error) }

    /// Fails the expectation with an arbitrary error, which awaiters then receive.
    ///
    /// Inside this class `Error` names the nested enum, so the overload above cannot accept
    /// general errors even though `State.fail` can.
    /// - Throws: `Error.alreadyCompleted`, `Error.alreadyCancelled` or
    ///   `Error.inconsistentState` if the expectation is no longer waiting.
    public func fail(_ error: Swift.Error) async throws -> Void { try await state.fail(error) }

    /// Awaits resolution and returns it as a `Result` instead of throwing.
    @discardableResult public func result() async -> Result<Arg, Swift.Error> { await task.result }

    /// Awaits resolution and returns the completed value, or rethrows the failure or
    /// cancellation error.
    @discardableResult public func value() async throws -> Arg {
        try await task.value
    }
}
extension CheckedExpectation where Arg == Void {
    /// Completes a `Void` expectation without spelling out `()` at the call site.
    ///
    /// - Throws: `Error.alreadyCompleted`, `Error.alreadyCancelled` or
    ///   `Error.inconsistentState` if the expectation is no longer waiting.
    nonisolated public func complete() async throws -> Void {
        let unit: Void = ()
        try await complete(unit)
    }
}
//
// WaitState.swift
//
//
// Created by Van Simmons on 5/12/22.
//
/// State machine, driven by a `StateTask`, that waits for a set of `CheckedExpectation`s.
///
/// Each expectation gets a task that awaits its result and yields
/// `.complete(index, value)` to `channel`. A watchdog task yields `.timeout` after
/// `timeout` nanoseconds. `reduce(action:)` folds each value into `finalResult`. When the
/// last expectation is accounted for, it cancels the watchdog and finishes the channel.
struct WaitState<FinalResult, PartialResult> {
    typealias ST = StateTask<WaitState<FinalResult, PartialResult>, WaitState<FinalResult, PartialResult>.Action>

    /// Events delivered to the state machine through `channel`.
    enum Action {
        /// The expectation at this index (its position in the input sequence) produced a value.
        case complete(Int, PartialResult)
        /// The watchdog's timeout elapsed.
        case timeout
    }

    let channel: Channel<WaitState<FinalResult, PartialResult>.Action>
    let watchdog: Task<Void, Swift.Error>
    let resultReducer: (inout FinalResult, PartialResult) throws -> Void
    /// Expectations still outstanding, keyed by their index in the input sequence.
    var expectations: [Int: CheckedExpectation<PartialResult>]
    /// Per-expectation forwarding tasks, keyed like `expectations`.
    var tasks: [Int: Task<PartialResult, Swift.Error>]
    var finalResult: FinalResult

    init<S: Sequence>(
        with channel: Channel<WaitState<FinalResult, PartialResult>.Action>,
        for expectations: S,
        timeout: UInt64,
        reducer: @escaping (inout FinalResult, PartialResult) throws -> Void,
        initialValue: FinalResult
    ) where S.Element == CheckedExpectation<PartialResult> {
        // Materialize once. `S` may be single-pass, and the task list and the expectation
        // dictionary must agree on indices. Enumerating a single-pass sequence twice would
        // leave `expectations` empty, and every completion would then hit `fatalError`.
        let pending = Array(expectations)

        let forwarders = pending.enumerated().map { index, expectation in
            Task<PartialResult, Swift.Error> {
                guard !Task.isCancelled else { throw PublisherError.cancelled }
                let pResult = await expectation.result()
                guard case let .success(pValue) = pResult else {
                    throw PublisherError.cancelled
                }
                // Re-check: the wait may have been torn down while we were suspended.
                guard !Task.isCancelled else { throw PublisherError.cancelled }
                guard case .enqueued = channel.yield(.complete(index, pValue)) else {
                    throw PublisherError.internalError
                }
                return pValue
            }
        }

        self.channel = channel
        self.expectations = Dictionary(uniqueKeysWithValues: zip(pending.indices, pending))
        self.tasks = Dictionary(uniqueKeysWithValues: zip(forwarders.indices, forwarders))
        self.watchdog = .init {
            try await Task.sleep(nanoseconds: timeout)
            if case .enqueued = channel.yield(.timeout) { return }
            // The yield can legitimately fail if the wait concluded between the sleep
            // returning and the yield. Both `reduce` and `cancel()` cancel this watchdog
            // before the channel is finished, so that case is benign. Anything else is a
            // genuine fault.
            guard Task.isCancelled else { fatalError("Unable to process timeout") }
        }
        self.resultReducer = reducer
        self.finalResult = initialValue
    }

    /// Tears the wait down: stops the watchdog and the forwarding tasks, cancels every
    /// outstanding expectation, and forgets them.
    mutating func cancel() -> Void {
        // Stop the watchdog so it cannot later yield into a channel that has been finished.
        watchdog.cancel()
        // Dropping the task references alone does not stop them, so cancel them explicitly.
        tasks.values.forEach { $0.cancel() }
        expectations.values.forEach { $0.cancel() }
        expectations.removeAll()
        tasks.removeAll()
    }

    /// Static adapter so the reducer can be handed to `Reducer(reducer:)`.
    static func reduce(`self`: inout Self, action: Self.Action) throws -> Reducer<Self, Action>.Effect {
        try `self`.reduce(action: action)
    }

    /// Applies one action to the state.
    ///
    /// - Throws: Any error from `resultReducer` (after tearing down the wait), or
    ///   `CheckedExpectation<FinalResult>.Error.timedOut` on `.timeout`.
    mutating func reduce(action: Action) throws -> Reducer<Self, Action>.Effect {
        switch action {
        case let .complete(index, partialResult):
            guard expectations.removeValue(forKey: index) != nil,
                  tasks.removeValue(forKey: index) != nil else {
                fatalError("could not find task")
            }
            if expectations.isEmpty {
                // Last expectation: stop the watchdog before finishing the channel.
                watchdog.cancel()
                channel.finish()
            }
            do {
                try resultReducer(&finalResult, partialResult)
            } catch {
                cancel()
                throw error
            }
        case .timeout:
            cancel()
            throw CheckedExpectation<FinalResult>.Error.timedOut
        }
        return .none
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment