Created
May 31, 2024 22:16
-
-
Save pirogoeth/c91492edcbf18b26f349bbfe3458d27d to your computer and use it in GitHub Desktop.
Things I implemented before I realized the Go standard library already had them.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package goro | |
import ( | |
"context" | |
"errors" | |
"sync" | |
) | |
type LimitGroup struct { | |
sema *SlottedSemaphore | |
fns []errWrapper | |
lock sync.Mutex | |
} | |
func NewLimitGroup(limit int) *LimitGroup { | |
return &LimitGroup{ | |
sema: NewSlottedSemaphore(limit), | |
fns: make([]errWrapper, 0), | |
lock: sync.Mutex{}, | |
} | |
} | |
func (lg *LimitGroup) Add(fn errWrapper) error { | |
if ok := lg.lock.TryLock(); !ok { | |
return errors.New("cannot add to a LimitGroup while running") | |
} | |
lg.fns = append(lg.fns, fn) | |
lg.lock.Unlock() | |
return nil | |
} | |
func (lg *LimitGroup) Run(parentCtx context.Context) error { | |
lg.lock.Lock() | |
ctx, cancel := context.WithCancel(parentCtx) | |
errCh := make(chan error, len(lg.fns)) | |
wg := sync.WaitGroup{} | |
for _, fn := range lg.fns { | |
wg.Add(1) | |
go func(fn errWrapper) { | |
slot := lg.sema.AcquireBlocking() | |
defer slot.Release() | |
errCh <- fn(ctx) | |
wg.Done() | |
}(fn) | |
} | |
wg.Wait() | |
close(errCh) | |
errs := make([]error, 0) | |
for err := range errCh { | |
errs = append(errs, err) | |
} | |
cancel() | |
return errors.Join(errs...) | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package goro | |
import ( | |
"context" | |
"crypto/rand" | |
"fmt" | |
"math/big" | |
"sync/atomic" | |
"testing" | |
"time" | |
) | |
func TestLimit(t *testing.T) { | |
ctx := context.Background() | |
limit := NewLimitGroup(1) | |
limit.Add(func(ctx context.Context) error { | |
return nil | |
}) | |
err := limit.Run(ctx) | |
if err != nil { | |
t.Fail() | |
} | |
} | |
func TestLimitSingleConcurrent(t *testing.T) { | |
ctx := context.Background() | |
executing := atomic.Bool{} | |
nestedWork := func(_ context.Context) error { | |
if executing.Load() { | |
return fmt.Errorf("concurrent execution detected") | |
} | |
executing.Store(true) | |
defer executing.Store(false) | |
sleepTime, err := rand.Int(rand.Reader, big.NewInt(5)) | |
if err != nil { | |
return fmt.Errorf("could not get random int: %w", err) | |
} | |
time.Sleep(time.Duration(sleepTime.Int64()) * time.Millisecond) | |
return nil | |
} | |
limit := NewLimitGroup(1) | |
limit.Add(func(ctx context.Context) error { | |
t.Log("run 1") | |
return nestedWork(ctx) | |
}) | |
limit.Add(func(ctx context.Context) error { | |
t.Log("run 2") | |
return nestedWork(ctx) | |
}) | |
limit.Add(func(ctx context.Context) error { | |
t.Log("run 3") | |
return nestedWork(ctx) | |
}) | |
if err := limit.Run(ctx); err != nil { | |
t.Error(err) | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package goro | |
import ( | |
"errors" | |
"sync" | |
"github.com/sirupsen/logrus" | |
) | |
// ErrNoSlotsAvailable is returned by SlottedSemaphore.Acquire when every slot
// is currently held.
var ErrNoSlotsAvailable = errors.New("no slots available")

// ErrSlotAlreadyReleased is returned when releasing a slot handle that no
// longer holds a slot in its semaphore.
var ErrSlotAlreadyReleased = errors.New("slot already released")
type slot struct { | |
parentSema *SlottedSemaphore | |
} | |
func (s *slot) Release() error { | |
if s.parentSema == nil { | |
logrus.Errorf("slot already released") | |
return ErrSlotAlreadyReleased | |
} | |
if err := s.parentSema.Release(s); err != nil { | |
logrus.Fatalf("failed to release slot: %s", err.Error()) | |
} | |
logrus.Tracef("released slot %v", s) | |
s.parentSema = nil | |
return nil | |
} | |
func (s *slot) IsReleased() bool { | |
return s.parentSema == nil | |
} | |
type SlottedSemaphore struct { | |
sema []*slot | |
lock *sync.Mutex | |
acquisitionCond *sync.Cond | |
} | |
func NewSlottedSemaphore(limit int) *SlottedSemaphore { | |
lock := &sync.Mutex{} | |
return &SlottedSemaphore{ | |
sema: make([]*slot, limit), | |
lock: lock, | |
acquisitionCond: sync.NewCond(lock), | |
} | |
} | |
func (ss *SlottedSemaphore) findFreeSlot() int { | |
for i, slot := range ss.sema { | |
if slot == nil { | |
logrus.Tracef("first free slot at %d", i) | |
return i | |
} | |
} | |
return -1 | |
} | |
func (ss *SlottedSemaphore) Acquire() (*slot, error) { | |
ss.lock.Lock() | |
defer ss.lock.Unlock() | |
slotIdx := ss.findFreeSlot() | |
if slotIdx == -1 { | |
return nil, ErrNoSlotsAvailable | |
} | |
slot := &slot{ | |
parentSema: ss, | |
} | |
ss.sema[slotIdx] = slot | |
logrus.Tracef("slot %d acquired", slotIdx) | |
return slot, nil | |
} | |
func (ss *SlottedSemaphore) AcquireBlocking() *slot { | |
ss.lock.Lock() | |
for { | |
slotIdx := ss.findFreeSlot() | |
if slotIdx != -1 { | |
slot := &slot{ | |
parentSema: ss, | |
} | |
ss.sema[slotIdx] = slot | |
logrus.Tracef("slot %d acquired (blocking)", slotIdx) | |
ss.lock.Unlock() | |
return slot | |
} | |
logrus.Trace("no slots available, sleeping for release") | |
ss.acquisitionCond.Wait() | |
} | |
} | |
func (ss *SlottedSemaphore) Release(s *slot) error { | |
ss.lock.Lock() | |
defer ss.lock.Unlock() | |
for i, slot := range ss.sema { | |
if slot == s { | |
logrus.Tracef("releasing slot at index %d", i) | |
ss.sema[i] = nil | |
ss.acquisitionCond.Signal() | |
return nil | |
} | |
} | |
return ErrSlotAlreadyReleased | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package goro | |
import ( | |
"testing" | |
"time" | |
) | |
func TestSlottedSemaphore(t *testing.T) { | |
ss := NewSlottedSemaphore(1) | |
results := make([]int, 2) | |
doneCh := make(chan bool) | |
go func() { | |
t.Log("Acquiring slot 1") | |
slot := ss.AcquireBlocking() | |
t.Logf("Acquired slot 1: %v", slot) | |
defer slot.Release() | |
time.Sleep(1 * time.Second) | |
results[0] = 1 | |
doneCh <- true | |
}() | |
go func() { | |
t.Log("Acquiring slot 2") | |
slot := ss.AcquireBlocking() | |
t.Logf("Acquired slot 2: %v", slot) | |
defer slot.Release() | |
results[1] = 2 | |
doneCh <- true | |
}() | |
<-doneCh | |
<-doneCh | |
if results[0] != 1 || results[1] != 2 { | |
t.Fail() | |
} | |
} | |
func TestSlottedSemaphoreDoubleRelease(t *testing.T) { | |
ss := NewSlottedSemaphore(1) | |
slot := ss.AcquireBlocking() | |
slot.Release() | |
err := slot.Release() | |
if err != ErrSlotAlreadyReleased { | |
t.Fail() | |
} | |
} | |
func TestSlottedSemaphoreAsyncAcquire(t *testing.T) { | |
ss := NewSlottedSemaphore(1) | |
slot, err := ss.Acquire() | |
if err != nil { | |
t.Fail() | |
} | |
go func() { | |
time.Sleep(1 * time.Second) | |
slot.Release() | |
}() | |
for { | |
if slot.IsReleased() { | |
break | |
} | |
} | |
} | |
func TestSlottedSemaphoreGoroWait(t *testing.T) { | |
ss := NewSlottedSemaphore(1) | |
slot := ss.AcquireBlocking() | |
done := make(chan bool) | |
go func() { | |
for i := 0; i < 10; i++ { | |
_, err := ss.Acquire() | |
if err == nil { | |
t.Errorf("should not be able to acquire a slot") | |
} | |
t.Log("Waiting for slot") | |
time.Sleep(500 * time.Millisecond) | |
} | |
done <- true | |
}() | |
<-done | |
slot.Release() | |
t.Log("Waiting goro closed") | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment