Skip to content

Instantly share code, notes, and snippets.

@KaQuMiQ
Created March 14, 2022 16:02
Show Gist options
  • Save KaQuMiQ/8ed1f8494fad7afb1b5474fd927a8a91 to your computer and use it in GitHub Desktop.
Save KaQuMiQ/8ed1f8494fad7afb1b5474fd927a8a91 to your computer and use it in GitHub Desktop.
SpinLock with async/await
import libkern
/// A spin lock for use from Swift concurrency code.
///
/// The lock state is a single C11 `atomic_flag`. Contended acquisitions call
/// `Task.yield()` between attempts rather than blocking the calling thread.
///
/// Re-entrancy is tracked through a task-local list of locks that is bound by
/// `withLock(_:)` / `withShieldedLock(_:)` for the duration of their closure.
/// Bare `lock()` / `unlock()` pairs do not update that list.
///
/// NOTE(review): task-local values are inherited by child tasks, so a child
/// task started inside a `withLock` closure is treated as an owner as well —
/// confirm this is acceptable for the call sites.
@usableFromInline
internal final class AsyncSpinLock {
    /// Locks owned by the `withLock` / `withShieldedLock` scopes that enclose
    /// the current task. Empty outside of any such scope.
    @TaskLocal private static var taskContext: [AsyncSpinLock] = []

    /// Heap storage for the flag; allocated in `init`, released in `deinit`.
    @usableFromInline
    internal let atomicFlagPointer: UnsafeMutablePointer<atomic_flag>

    /// Creates an unlocked lock.
    internal init() {
        self.atomicFlagPointer = .allocate(capacity: 1)
        // Freshly allocated memory is uninitialized: initialize it instead of
        // assigning through `pointee`.
        self.atomicFlagPointer.initialize(to: atomic_flag())
    }

    deinit {
        self.atomicFlagPointer.deinitialize(count: 1)
        self.atomicFlagPointer.deallocate()
    }

    /// `true` when an enclosing `withLock` / `withShieldedLock` scope of the
    /// current task already owns this lock. Compared by identity so the check
    /// does not rely on any `Equatable` conformance.
    private var isHeldByCurrentTask: Bool {
        Self.taskContext.contains(where: { $0 === self })
    }

    /// Acquires the lock, yielding between attempts.
    ///
    /// Returns immediately, without touching the flag, when the current task
    /// already owns the lock through an enclosing `withLock` /
    /// `withShieldedLock` scope. Such a re-entrant call must not be paired
    /// with `unlock()`, because that would release the outer scope's lock.
    ///
    /// - Throws: the error thrown by `Task.checkCancellation()` when the task
    ///   is cancelled while waiting.
    @usableFromInline
    @inline(__always)
    internal func lock() async throws {
        guard !isHeldByCurrentTask else { return }
        while !tryLock() {
            await Task.yield()
            try Task.checkCancellation()
        }
    }

    /// Acquires the lock even if the task is cancelled.
    ///
    /// Like `lock()`, returns immediately when the current task already owns
    /// the lock through an enclosing scope; without that check a nested call
    /// would spin forever.
    @usableFromInline
    @inline(__always)
    internal func shieldedLock() async {
        guard !isHeldByCurrentTask else { return }
        while !tryLock() {
            await Task.yield()
        }
    }

    /// Attempts to acquire the lock without waiting.
    ///
    /// - Returns: `true` when the lock was free and is now held by the caller,
    ///   `false` when it was already held.
    @usableFromInline
    @inline(__always)
    internal func tryLock() -> Bool {
        // `atomic_flag_test_and_set` returns the PREVIOUS value of the flag:
        // `false` means the flag was clear and this call just set it.
        !atomic_flag_test_and_set(self.atomicFlagPointer)
    }

    /// Releases the lock by clearing the flag. Does not verify ownership.
    @usableFromInline
    @inline(__always)
    internal func unlock() {
        atomic_flag_clear(self.atomicFlagPointer)
    }

    /// Runs `execute` while holding the lock and releases it afterwards.
    ///
    /// Re-entrant: when the current task already owns the lock through an
    /// enclosing scope, `execute` runs directly and the outermost scope keeps
    /// responsibility for unlocking.
    ///
    /// - Parameter execute: The critical section.
    /// - Returns: The value returned by `execute`.
    /// - Throws: A cancellation error if the task is cancelled while waiting
    ///   for the lock, or whatever `execute` throws.
    @usableFromInline
    @inline(__always)
    internal func withLock<T>(
        _ execute: () async throws -> T
    ) async throws -> T {
        let ownedLocks = Self.taskContext
        guard !ownedLocks.contains(where: { $0 === self }) else {
            return try await execute()
        }
        try await lock()
        defer { unlock() }
        // Record ownership for the duration of the closure so nested scopes
        // on the same lock are recognised as re-entrant.
        return try await Self.$taskContext.withValue(ownedLocks + [self]) {
            try await execute()
        }
    }

    /// Runs `execute` while holding the lock, waiting for the lock even if the
    /// task is cancelled, and releases it afterwards.
    ///
    /// Re-entrant in the same way as `withLock(_:)`.
    ///
    /// - Parameter execute: The critical section.
    /// - Returns: The value returned by `execute`.
    /// - Throws: Only what `execute` throws.
    @usableFromInline
    @inline(__always)
    internal func withShieldedLock<T>(
        _ execute: () async throws -> T
    ) async rethrows -> T {
        let ownedLocks = Self.taskContext
        guard !ownedLocks.contains(where: { $0 === self }) else {
            return try await execute()
        }
        await shieldedLock()
        defer { unlock() }
        return try await Self.$taskContext.withValue(ownedLocks + [self]) {
            try await execute()
        }
    }
}
// MARK: - Equatable

extension AsyncSpinLock: Equatable {
    /// Two locks compare equal when their `atomicFlagPointer` values are the
    /// same address.
    @usableFromInline
    @inline(__always)
    internal static func == (
        _ lhs: AsyncSpinLock,
        _ rhs: AsyncSpinLock
    ) -> Bool {
        let leftFlag = lhs.atomicFlagPointer
        let rightFlag = rhs.atomicFlagPointer
        return leftFlag == rightFlag
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment