Created
February 27, 2024 22:40
-
-
Save JadenGeller/2c819d15ebbbf2a3e3f501764dc6c8b8 to your computer and use it in GitHub Desktop.
Request batcher using Swift async
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// Coalesces individual requests into batches and hands each batch to a single
/// `dispatchBatch` call.
///
/// A batch is sent as soon as it holds `maxBatch` requests, or once `maxDelay` has elapsed
/// since its first request was enqueued, whichever happens first.
actor RequestBatcher<Request, Response> {
    /// A queued request together with the continuation of the caller awaiting its response.
    typealias PendingRequest = (request: Request, continuation: CheckedContinuation<Response, Error>)

    /// Number of pending requests that triggers an immediate send.
    var maxBatch: Int
    /// How long the first request of a batch may wait before the batch is sent anyway.
    var maxDelay: Duration
    /// Priority passed to the unstructured tasks that run the timer and send batches.
    var priority: TaskPriority?
    /// Sends one batch. Must return exactly one result per request, in request order.
    var dispatchBatch: ([Request]) async -> [Result<Response, Error>]

    /// Creates a batcher.
    /// - Parameters:
    ///   - maxBatch: Batch size at which pending requests are sent immediately.
    ///   - maxDelay: Maximum time a request waits for its batch to fill up.
    ///   - priority: Priority for the internal tasks; defaults to `nil`.
    ///   - dispatchBatch: Sends a batch and returns one result per request, in order.
    init(maxBatch: Int, maxDelay: Duration, priority: TaskPriority? = nil, dispatchBatch: @escaping ([Request]) async -> [Result<Response, Error>]) {
        self.maxBatch = maxBatch
        self.maxDelay = maxDelay
        self.priority = priority
        self.dispatchBatch = dispatchBatch
    }

    /// Timer task that sends the current batch once `maxDelay` has elapsed.
    var flushTask: Task<Void, Never>?
    /// Requests waiting to be sent, in arrival order.
    var batch: [PendingRequest] = []

    /// Enqueues `request` and suspends until the batch containing it has been dispatched.
    /// - Returns: The response `dispatchBatch` produced for this request.
    /// - Throws: The error `dispatchBatch` produced for this request.
    func dispatch(_ request: Request) async throws -> Response {
        try await withCheckedThrowingContinuation { continuation in
            batch.append((request: request, continuation: continuation))
            if batch.count >= maxBatch {
                // Snapshot the batch right away: requests arriving before the task below
                // gets to run must start a new batch instead of overfilling this one.
                let pending = takePending()
                Task(priority: priority) { await send(pending) }
            } else if batch.count == 1 {
                flushTask = Task(priority: priority) {
                    // `cancel()` can land after the sleep has already finished, so the
                    // cancellation flag has to be re-checked once we are back on the actor.
                    guard (try? await Task.sleep(for: maxDelay)) != nil, !Task.isCancelled else { return }
                    // Clear the reference first so `flush()` does not cancel the very task
                    // that is about to run `dispatchBatch`.
                    flushTask = nil
                    await flush()
                }
            }
        }
    }

    /// Sends whatever is currently pending, without waiting for the batch to fill up or
    /// for the timer to fire. Does nothing when no request is pending.
    func flush() async {
        await send(takePending())
    }

    /// Removes and returns all pending requests and disarms the timer that belonged to them.
    private func takePending() -> [PendingRequest] {
        flushTask?.cancel()
        flushTask = nil
        let pending = batch
        batch.removeAll()
        return pending
    }

    /// Dispatches `pending` as one batch and resumes every waiting caller with its result.
    private func send(_ pending: [PendingRequest]) async {
        guard !pending.isEmpty else { return }
        let responses = await dispatchBatch(pending.map(\.request))
        precondition(pending.count == responses.count, "dispatchBatch must return exactly one result per request")
        for (entry, response) in zip(pending, responses) {
            entry.continuation.resume(with: response)
        }
    }
}
extension RequestBatcher {
    /// Creates a batcher whose `dispatchBatch` returns its results keyed by request identifier
    /// instead of positionally.
    /// - Parameters:
    ///   - maxBatch: Batch size at which pending requests are sent immediately.
    ///   - maxDelay: Maximum time a request waits for its batch to fill up.
    ///   - id: Key path to the value that identifies a request in the returned dictionary.
    ///   - missingResponseError: Error delivered to a request whose identifier is absent
    ///     from the returned dictionary.
    ///   - dispatchBatch: Sends a batch and returns the results keyed by request identifier.
    init<ID: Hashable>(maxBatch: Int, maxDelay: Duration, id: KeyPath<Request, ID>, missingResponseError: Error = BatchError.missingResponse, dispatchBatch: @escaping ([Request]) async -> [ID: Result<Response, Error>]) {
        self.init(maxBatch: maxBatch, maxDelay: maxDelay) { requests in
            let responses = await dispatchBatch(requests)
            // Re-establish request order; requests without an entry fail individually.
            return requests.map { request in
                responses[request[keyPath: id]] ?? .failure(missingResponseError)
            }
        }
    }
}
extension RequestBatcher where Request: Hashable {
    /// Creates a batcher for hashable requests whose `dispatchBatch` returns its results
    /// keyed by the request itself.
    ///
    /// Delegates to the `id:`-based initializer, using `\.self` as the identifying key path.
    /// - Parameters:
    ///   - maxBatch: Forwarded unchanged to the `id:`-based initializer.
    ///   - maxDelay: Forwarded unchanged to the `id:`-based initializer.
    ///   - missingResponseError: Forwarded unchanged; defaults to `BatchError.missingResponse`.
    ///   - dispatchBatch: Sends a batch and returns the results keyed by request.
    init(
        maxBatch: Int,
        maxDelay: Duration,
        missingResponseError: Error = BatchError.missingResponse,
        dispatchBatch: @escaping ([Request]) async -> [Request: Result<Response, Error>]
    ) {
        self.init(
            maxBatch: maxBatch,
            maxDelay: maxDelay,
            id: \.self,
            missingResponseError: missingResponseError,
            dispatchBatch: dispatchBatch
        )
    }
}
/// Errors a batcher reports on behalf of a batch dispatcher.
enum BatchError: Error {
    /// The dispatcher's result contained no entry for a request.
    case missingResponse
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment