Skip to content

Instantly share code, notes, and snippets.

@kprotty
Last active March 20, 2024 16:11
Show Gist options
  • Save kprotty/5b95914aaffc5767a6cd4ae7d6217fbd to your computer and use it in GitHub Desktop.
Save kprotty/5b95914aaffc5767a6cd4ae7d6217fbd to your computer and use it in GitHub Desktop.
Cross-platform futex and rwlock for Go
package main
import (
"sync/atomic"
"unsafe"
)
// mutex is handed to the runtime's internal lock/unlock through go:linkname,
// so it stands in for runtime.mutex.
// NOTE(review): this assumes runtime.mutex is a single uintptr-sized key for
// the Go version in use — verify when upgrading the toolchain.
type mutex struct {
key uintptr
}
// lock acquires m via the runtime's internal lock (linked to runtime.lock).
//go:noescape
//go:linkname lock runtime.lock
func lock(m *mutex)
// unlock releases m via the runtime's internal unlock (linked to runtime.unlock).
//go:noescape
//go:linkname unlock runtime.unlock
func unlock(m *mutex)
// waitReason stands in for the runtime's wait-reason type; waitReasonZero is
// the reason Wait passes to gopark.
type waitReason uint8
const waitReasonZero = 0
// traceEvent stands in for the runtime's trace event type; traceEvGoBlockSync
// is the event code Wait passes to gopark.
// NOTE(review): the value 25 is assumed to match the runtime's
// traceEvGoBlockSync for the targeted Go version — confirm.
type traceEvent byte
const traceEvGoBlockSync = 25
// g is an opaque stand-in for runtime.g; this file only ever handles *g.
type g struct{}
// goready makes the parked goroutine gp runnable again (linked to runtime.goready).
//go:noescape
//go:linkname goready runtime.goready
func goready(gp *g, traceskip int)
// gopark parks the calling goroutine (linked to runtime.gopark). Wait passes an
// unlockf that records the parked g and releases the bucket lock given as lock.
// NOTE(review): gopark's parameter list has changed across Go releases (the
// trace parameter in particular); this declaration must match the toolchain.
//go:noescape
//go:linkname gopark runtime.gopark
func gopark(
unlockf func(*g, unsafe.Pointer) bool,
lock unsafe.Pointer,
reason waitReason,
traceEv traceEvent,
traceskip int,
)
// sudog mirrors the field layout of the runtime's internal sudog, so values
// obtained from acquireSudog can be manipulated directly. Roles of the fields
// this file uses:
//   gp         goroutine parked in Wait (recorded by Wait's unlockf)
//   prev, next left and right treap children, ordered by elem address
//   elem       address being waited on (the treap key)
//   ticket     random treap priority; smaller tickets sit closer to the root
//   parent     treap parent
//   waitlink   next waiter on the same address; Wake reuses it to chain the
//              sudogs it is about to ready
//   waittail   on a queue head, the last waiter in that address's list
// acquiretime, releasetime, isSelect, success and c are not used here and
// exist only to keep the layout.
// NOTE(review): this must match runtime.sudog field-for-field for the Go
// version in use; the compiler cannot check it.
type sudog struct {
gp *g
next *sudog
prev *sudog
elem unsafe.Pointer
acquiretime int64
releasetime int64
ticket uint32
isSelect bool
success bool
parent *sudog
waitlink *sudog
waittail *sudog
c unsafe.Pointer
}
// acquireSudog obtains a sudog from the runtime (linked to runtime.acquireSudog).
// Wait pairs every acquireSudog with a releaseSudog.
//go:noescape
//go:linkname acquireSudog runtime.acquireSudog
func acquireSudog() *sudog
// releaseSudog hands s back to the runtime (linked to runtime.releaseSudog).
// NOTE(review): the runtime is believed to reject sudogs whose elem, prev, next
// or waitlink are still set — confirm for the Go version in use.
//go:noescape
//go:linkname releaseSudog runtime.releaseSudog
func releaseSudog(s *sudog)
// fastrand returns a pseudo-random uint32 from the runtime; insert uses it as
// the treap ticket.
// NOTE(review): newer Go releases renamed the internal generator; confirm the
// runtime.fastrand linkname target still exists for the toolchain in use.
//go:noescape
//go:linkname fastrand runtime.fastrand
func fastrand() uint32
// treap is a randomized binary search tree of sudogs keyed by the address they
// wait on (sudog.elem). Each node is the head of the FIFO list of waiters for
// its address. prev/next are the left/right children and parent is the parent
// node. Tickets are kept in min-heap order: insert and remove rotate so that a
// node's ticket is never greater than its children's.
type treap struct {
	root *sudog
}

// find searches the treap for the node waiting on addr.
//
// It returns the matching node (nil if none) and that node's parent. When no
// node matches, the returned parent is the last node visited, i.e. the node
// under which a new node for addr should be attached (nil for an empty treap).
func (self *treap) find(addr *uint32) (*sudog, *sudog) {
	var parent *sudog
	link := &self.root
	// s must be re-read from link after each step; without the post statement
	// the loop would keep examining the same node forever.
	for s := *link; s != nil && s.elem != unsafe.Pointer(addr); s = *link {
		parent = s
		link = self.child(s, uintptr(unsafe.Pointer(addr)) > uintptr(s.elem))
	}
	return *link, parent
}
// insert links s into the treap as the node for addr, attaching it as a child
// of parent (nil for an empty treap), as returned by find for a missing addr.
// s gets a random ticket and is then rotated upward while its parent's ticket
// is greater than its own, restoring the heap order on tickets.
func (self *treap) insert(s *sudog, addr *uint32, parent *sudog) {
	s.elem = unsafe.Pointer(addr)
	s.ticket = fastrand()
	s.prev = nil
	s.next = nil
	self.setParent(parent, nil, s)
	for s.parent != nil && s.parent.ticket > s.ticket {
		// rotate(p, true) lifts p's left child (prev) and rotate(p, false)
		// lifts its right child (next), so pass true exactly when s is the
		// left child.
		self.rotate(s.parent, s.parent.prev == s)
	}
}
// replace puts new into old's position in the treap. new takes over old's key,
// ticket, children and parent link, and old is then cleared. pop uses this to
// promote the next waiter on an address to queue head.
func (self *treap) replace(old, new *sudog) {
	new.elem, new.ticket = old.elem, old.ticket
	left, right := old.prev, old.next
	new.prev, new.next = left, right
	if left != nil {
		left.parent = new
	}
	if right != nil {
		right.parent = new
	}
	self.setParent(old.parent, old, new)
	self.clear(old)
}
// remove unlinks s from the treap. s is first rotated down until it is a leaf,
// each time lifting whichever child has the smaller ticket (or the only child).
// It is then detached from its parent and cleared.
func (self *treap) remove(s *sudog) {
	for {
		left, right := s.prev, s.next
		if left == nil && right == nil {
			break
		}
		// rotate(s, true) lifts s.prev; rotate(s, false) lifts s.next.
		liftLeft := right == nil || (left != nil && left.ticket < right.ticket)
		self.rotate(s, liftLeft)
	}
	self.setParent(s.parent, s, nil)
	self.clear(s)
}
// clear resets the treap key, ticket and links of s after it has been
// unlinked (by remove or replace). The parent link is reset as well, so the
// sudog carries no stale pointer into the treap once its goroutine is readied
// and the sudog is released back to the runtime.
func (self *treap) clear(s *sudog) {
	s.elem = nil
	s.ticket = 0
	s.prev = nil
	s.next = nil
	s.parent = nil
}
// rotate performs a single tree rotation at s. With right == true, s's left
// child (prev) is lifted into s's place and s becomes its right child. With
// right == false, s's right child (next) is lifted and s becomes its left
// child. The lifted node's inner subtree moves across to s, which preserves
// in-order (address) ordering.
func (self *treap) rotate(s *sudog, right bool) {
	parent := s.parent
	target := *self.child(s, !right)
	// The subtree that changes owner is the lifted node's inner child, not
	// one of s's own children.
	adjacent := *self.child(target, right)
	*self.child(target, right) = s
	s.parent = target
	*self.child(s, !right) = adjacent
	if adjacent != nil {
		adjacent.parent = s
	}
	self.setParent(parent, s, target)
}
// setParent makes new occupy the child slot of parent that old occupied, or
// the root when parent is nil, and points new.parent at parent. When old is
// nil (a fresh insertion of a non-nil new), the side is picked by comparing
// new's key address with parent's.
func (self *treap) setParent(parent, old, new *sudog) {
	if new != nil {
		new.parent = parent
	}
	if parent == nil {
		self.root = new
		return
	}
	var right bool
	if old == nil && new != nil {
		right = uintptr(new.elem) > uintptr(parent.elem)
	} else {
		right = parent.next == old
	}
	*self.child(parent, right) = new
}
// child returns the address of s's right (next) child slot when right is
// true and of its left (prev) slot otherwise, so callers can read or overwrite
// the link in place.
func (self *treap) child(s *sudog, right bool) **sudog {
	link := &s.prev
	if right {
		link = &s.next
	}
	return link
}
// queue is a cursor over the FIFO list of sudogs waiting on one address inside
// a treap. head is the first waiter (the node stored in the treap), or nil when
// nobody waits on addr. parent is the treap parent reported by find, which is
// the insertion point used when head is nil.
type queue struct {
	t            *treap
	addr         *uint32
	head, parent *sudog
}

// findqueue looks up the wait queue for addr in t.
func findqueue(addr *uint32, t *treap) queue {
	q := queue{t: t, addr: addr}
	q.head, q.parent = t.find(addr)
	return q
}
// push appends s to the tail of the queue. If the queue was empty, s becomes
// its head, records itself as the tail, and is inserted into the treap under
// the queue's parent.
func (self *queue) push(s *sudog) {
	s.waitlink = nil
	head := self.head
	if head == nil {
		s.waittail = s
		self.head = s
		self.t.insert(s, self.addr, self.parent)
		return
	}
	tail := head.waittail
	tail.waitlink = s
	head.waittail = s
}
// pop removes and returns the head waiter of the queue, or nil if it is empty.
// If another waiter follows, it inherits the tail pointer and takes the old
// head's place in the treap. Otherwise the address's node is removed from the
// treap. The returned sudog has its wait-list links cleared.
func (self *queue) pop() *sudog {
	s := self.head
	if s == nil {
		return nil
	}
	next := s.waitlink
	self.head = next
	if next == nil {
		self.t.remove(s)
	} else {
		next.waittail = s.waittail
		self.t.replace(s, next)
	}
	s.waitlink = nil
	s.waittail = nil
	return s
}
// bucket is one shard of the global wait table. m guards t, the treap of
// per-address wait queues. pending counts goroutines that have announced
// themselves in Wait (incremented before m is taken) and have been neither
// dequeued by Wake nor withdrawn after a value mismatch.
type bucket struct {
m mutex
t treap
pending uint32
}
// numBuckets is the number of wait-table shards; findbucket maps addresses onto them.
const numBuckets = 64
// cacheLinePadSize is the size each wait-table entry is padded to.
// NOTE(review): 128 bytes is assumed to cover a cache line (or an adjacent-line
// prefetch pair) on targeted CPUs — confirm.
const cacheLinePadSize = 128
// buckets is the global wait table. Each entry is padded up to cacheLinePadSize
// bytes, presumably so buckets under contention do not share a cache line.
var buckets [numBuckets]struct {
b bucket
pad [cacheLinePadSize - unsafe.Sizeof(bucket{})]byte
}
// findbucket maps addr to its wait-table shard. It drops the low three address
// bits and takes the remainder modulo numBuckets.
func findbucket(addr *uint32) *bucket {
	slot := (uintptr(unsafe.Pointer(addr)) >> 3) % numBuckets
	return &buckets[slot].b
}
// Wait blocks the calling goroutine on addr while *addr == expect.
//
// The comparison is made under the lock of addr's bucket. If *addr no longer
// equals expect, Wait returns immediately. Otherwise the goroutine is queued on
// addr and parked until a Wake on addr dequeues it. Callers should re-check
// their condition after Wait returns (RwMutex.acquireSlow does so in a loop).
func Wait(addr *uint32, expect uint32) {
b := findbucket(addr)
// Announce a potential waiter before locking. Wake skips the bucket entirely
// when it observes pending == 0, so the count is raised before *addr is
// re-checked below.
atomic.AddUint32(&b.pending, 1)
lock(&b.m)
if atomic.LoadUint32(addr) != expect {
// The value already changed: withdraw the announcement and return without parking.
atomic.AddUint32(&b.pending, ^uint32(0))
unlock(&b.m)
return
}
s := acquireSudog()
q := findqueue(addr, &b.t)
q.push(s)
// unlockf is handed to gopark, which (per the runtime's contract, not visible
// here) calls it with the parked g. The bucket lock is held until this closure
// releases it, so s.gp is recorded before any Wake can pop s.
unlockf := func(gp *g, m unsafe.Pointer) bool {
s.gp = gp
unlock((*mutex)(m))
return true
}
gopark(unlockf, unsafe.Pointer(&b.m), waitReasonZero, traceEvGoBlockSync, 3)
// Readied by Wake, which has already popped s from the queue and subtracted
// it from pending.
releaseSudog(s)
}
// Wake dequeues up to maxWaiters goroutines parked in Wait on addr and makes
// them runnable. Waiters are taken from the front of addr's FIFO queue.
// Waiters on other addresses that share addr's bucket are not affected.
func Wake(addr *uint32, maxWaiters uint32) {
	b := findbucket(addr)
	// Lock-free fast path: no goroutine has announced itself in this bucket.
	if atomic.AddUint32(&b.pending, 0) == 0 {
		return
	}
	lock(&b.m)
	if atomic.LoadUint32(&b.pending) == 0 {
		unlock(&b.m)
		return
	}
	q := findqueue(addr, &b.t)
	// Detach the selected waiters under the lock, chaining them through
	// waitlink so they can be readied after the lock is dropped.
	var notified *sudog
	var removed uint32
	for q.head != nil && removed < maxWaiters {
		s := q.pop()
		removed++
		s.waitlink = notified
		notified = s
	}
	if removed > 0 {
		// Subtract removed (adds the two's-complement of removed).
		atomic.AddUint32(&b.pending, ^uint32(0)-(removed-1))
	}
	unlock(&b.m)
	for s := notified; s != nil; s = notified {
		notified = s.waitlink
		// The woken goroutine passes s to releaseSudog, which requires
		// waitlink to be nil. Reset it before goready; after goready, s may
		// already be released and must not be touched.
		s.waitlink = nil
		goready(s.gp, 4)
	}
}
// RwMutex is a reader/writer lock built on this file's Wait/Wake primitives.
// The zero value is an unlocked mutex. state holds the lock word (see the rw*
// constants), and epoch is a counter that blocked writers sleep on.
type RwMutex struct {
state uint32
epoch uint32
}
// RwMutexAccess selects the mode in which an RwMutex is acquired.
type RwMutexAccess uint8
const (
// Exclusive requests write access: at most one holder, and no readers.
Exclusive RwMutexAccess = iota
// Shared requests read access, which many goroutines may hold at once.
Shared
)
// TryLock attempts to acquire the lock in the given access mode without
// blocking, and reports whether it succeeded.
func (self *RwMutex) TryLock(access RwMutexAccess) bool {
	isWriter := access == Exclusive
	return self.tryAcquire(isWriter)
}

// Lock acquires the lock in the given access mode, blocking until it is
// available.
func (self *RwMutex) Lock(access RwMutexAccess) {
	isWriter := access == Exclusive
	self.acquire(isWriter)
}

// Unlock releases the lock, whichever mode it is held in. The mode is read
// from the current state: a reader/writer field equal to rwWriter means a
// writer holds it, and any other non-zero value means readers do.
// Unlock panics if the mutex is not locked.
func (self *RwMutex) Unlock() {
	switch atomic.LoadUint32(&self.state) & rwMask {
	case rwUnlocked:
		panic("RwMutex.Unlock() called when not locked")
	case rwWriter:
		self.release(true)
	default:
		self.release(false)
	}
}
// Layout of RwMutex.state:
//   bit 0      rwPendingReaders: at least one reader may be sleeping on state
//   bit 1      rwPendingWriters: at least one writer may be sleeping on epoch
//   bits 2-31  reader/writer field (state & rwMask): zero when unlocked,
//              n*rwReader while n readers hold the lock, and all ones
//              (rwWriter == rwMask) while a writer holds it.
const (
rwUnlocked = 0
rwPendingReaders = 1 << 0
rwPendingWriters = 1 << 1
rwValue = 1 << 2
rwMask = ^uint32(rwValue - 1)
rwReader = rwValue
rwWriter = rwMask
)
// tryAcquire attempts to take the lock in the requested mode without blocking,
// and reports whether it succeeded.
//
// A writer succeeds whenever no reader or writer holds the lock. Pending bits
// are preserved (as acquireSlow does), so the eventual release still sees them
// and wakes the sleepers. A reader succeeds while no writer holds the lock and
// one more reader fits in the count without aliasing the writer state.
func (self *RwMutex) tryAcquire(isWriter bool) bool {
	for {
		state := atomic.LoadUint32(&self.state)
		masked := state & rwMask
		var newState uint32
		if isWriter {
			if masked != rwUnlocked {
				return false
			}
			newState = state | rwWriter
		} else {
			// masked == rwWriter means a writer holds the lock. At
			// rwWriter-rwReader, one more reader would make the field equal
			// rwWriter and be mistaken for a writer.
			if masked >= rwWriter-rwReader {
				return false
			}
			newState = state + rwReader
		}
		if atomic.CompareAndSwapUint32(&self.state, state, newState) {
			return true
		}
	}
}
// acquire takes the lock in the requested mode. It tries a single CAS from the
// fully unlocked state first and falls back to acquireSlow when that fails.
func (self *RwMutex) acquire(isWriter bool) {
	fast := uint32(rwReader)
	if isWriter {
		fast = rwWriter
	}
	if atomic.CompareAndSwapUint32(&self.state, rwUnlocked, fast) {
		return
	}
	self.acquireSlow(isWriter)
}
// runtime_canSpin reports whether the caller should keep spinning at spin
// iteration i; acquireSlow spins only while it returns true. It is linked to
// the symbol sync.runtime_canSpin, presumably the same heuristic sync.Mutex
// uses — confirm for the toolchain in use.
//go:noescape
//go:linkname runtime_canSpin sync.runtime_canSpin
func runtime_canSpin(i int) bool
// runtime_doSpin presumably performs one short busy-wait step (linked to
// sync.runtime_doSpin).
//go:noescape
//go:linkname runtime_doSpin sync.runtime_doSpin
func runtime_doSpin()
// acquireSlow is the contended path of acquire. It loops until the lock is
// taken in the requested mode.
//   - A writer may take the lock whenever the reader/writer field is zero,
//     regardless of pending bits.
//   - A reader may take it while the field is below rwWriter-1, i.e. while no
//     writer holds it. Readers do not yield to pending writers.
//     NOTE(review): at a field of rwWriter-rwReader, adding rwReader yields
//     rwWriter, so an (astronomically large) reader count would alias the
//     writer state — confirm this bound is intended.
// Otherwise, while its pending bit is clear, the caller spins as long as
// runtime_canSpin allows, then sets its pending bit and sleeps. Readers sleep
// on the state word itself; writers sleep on the epoch counter, which
// releaseSlow bumps before waking one writer.
func (self *RwMutex) acquireSlow(isWriter bool) {
var pendingMask uint32 = rwPendingReaders
if isWriter {
pendingMask = rwPendingWriters
}
var spin int = 0
// acquireWith holds extra bits a writer ORs in when it finally acquires.
// After sleeping on epoch it is rwPendingWriters, presumably so the next
// release still sees the bit and wakes any other sleeping writer.
var acquireWith uint32
for {
state := atomic.LoadUint32(&self.state)
masked := state & rwMask
// newState stays 0 when the lock cannot be taken in this mode; any
// successful transition below is non-zero.
var newState uint32
if isWriter && (masked == rwUnlocked) {
newState = state | rwWriter | acquireWith
} else if !isWriter && (masked < rwWriter-1) {
newState = state + rwReader
}
if newState != 0 {
if atomic.CompareAndSwapUint32(&self.state, state, newState) {
return
}
// Lost the race; readers back off briefly before retrying.
if !isWriter {
runtime_doSpin()
}
continue
}
// Spinning happens only before this mode's pending bit is set; once the
// bit is present the caller goes straight to sleep.
if state&pendingMask == 0 {
if runtime_canSpin(spin) {
spin++
runtime_doSpin()
continue
}
// Advertise a sleeper; if the state moved, re-evaluate from scratch.
if !atomic.CompareAndSwapUint32(&self.state, state, state|pendingMask) {
continue
}
}
spin = 0
if !isWriter {
// Sleep on the state word; Wait returns at once if it no longer
// equals the value this goroutine just observed or installed.
Wait(&self.state, state|pendingMask)
continue
}
for {
// Load epoch before state. If a release bumps epoch after the state
// check below, Wait sees the changed epoch and returns immediately.
epoch := atomic.LoadUint32(&self.epoch)
state = atomic.LoadUint32(&self.state)
// Retry the outer loop once the lock is free or the writer-pending
// bit has been cleared (and must be re-set).
if (state&rwMask == rwUnlocked) || (state&pendingMask == 0) {
break
}
Wait(&self.epoch, epoch)
acquireWith = pendingMask
}
}
}
// release drops one hold of the given mode. A writer swaps the state back to
// unlocked; a reader subtracts one rwReader. releaseSlow is called only when
// the writer's swap returned more than a bare rwWriter (pending bits were
// set), or when a reader's decrement left exactly rwPendingWriters behind.
func (self *RwMutex) release(isWriter bool) {
	var state uint32
	var needSlow bool
	if isWriter {
		state = atomic.SwapUint32(&self.state, rwUnlocked)
		needSlow = state != rwWriter
	} else {
		state = atomic.AddUint32(&self.state, ^uint32(rwReader-1))
		needSlow = state == rwUnlocked|rwPendingWriters
	}
	if needSlow {
		self.releaseSlow(isWriter, state)
	}
}
// releaseSlow wakes sleepers after a release observed pending bits. state is
// the lock word returned by the releasing operation: the pre-swap value for a
// writer, the post-decrement value for a reader.
//
// When a writer releases with rwPendingReaders set, every reader sleeping on
// the state word is woken. If state has rwPendingWriters set, epoch is bumped
// and one writer sleeping on epoch is woken. The bump makes a writer that is
// about to sleep on the old epoch return from Wait immediately.
//
// The writer decision deliberately uses the passed-in state. A writer's swap
// has already cleared rwPendingWriters in the live word, so re-reading it
// after waking readers would usually find the bit gone and leave sleeping
// writers parked with nobody left to wake them.
func (self *RwMutex) releaseSlow(isWriter bool, state uint32) {
	if isWriter && state&rwPendingReaders != 0 {
		Wake(&self.state, ^uint32(0))
	}
	if state&rwPendingWriters != 0 {
		atomic.AddUint32(&self.epoch, 1)
		Wake(&self.epoch, 1)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment