Skip to content

Instantly share code, notes, and snippets.

@pirogoeth
Created May 31, 2024 22:16
Show Gist options
  • Save pirogoeth/c91492edcbf18b26f349bbfe3458d27d to your computer and use it in GitHub Desktop.
Save pirogoeth/c91492edcbf18b26f349bbfe3458d27d to your computer and use it in GitHub Desktop.
Things I implemented before I realized the Go stdlib had them
package goro
import (
"context"
"errors"
"sync"
)
// LimitGroup collects functions via Add and runs them concurrently via Run,
// gating execution through a SlottedSemaphore built from the configured limit.
type LimitGroup struct {
// sema is acquired by every goroutine started in Run before it calls its function.
sema *SlottedSemaphore
// fns holds the functions registered through Add.
fns []errWrapper
// lock is locked by Run and try-locked by Add.
lock sync.Mutex
}
// NewLimitGroup builds an empty LimitGroup whose semaphore is created with
// NewSlottedSemaphore(limit).
func NewLimitGroup(limit int) *LimitGroup {
	group := new(LimitGroup) // the zero-value mutex is ready to use
	group.sema = NewSlottedSemaphore(limit)
	group.fns = make([]errWrapper, 0)
	return group
}
// Add registers fn to be executed by a later call to Run.
//
// The group's lock is taken with TryLock; if it cannot be obtained, fn is not
// registered and an error is returned.
func (lg *LimitGroup) Add(fn errWrapper) error {
	if !lg.lock.TryLock() {
		return errors.New("cannot add to a LimitGroup while running")
	}
	defer lg.lock.Unlock()

	lg.fns = append(lg.fns, fn)
	return nil
}
// Run executes every function registered with Add, each in its own goroutine,
// and blocks until all of them have returned.
//
// Each goroutine acquires a slot from the group's semaphore before invoking
// its function, so the semaphore bounds how many functions run at once. Every
// function receives a context derived from parentCtx; that context is
// cancelled when Run returns.
//
// The group's lock is held for the duration of the call and released on
// return, so Add works again afterwards and Run may be called repeatedly.
//
// The returned error is errors.Join of all function results, which is nil when
// every function returned nil.
func (lg *LimitGroup) Run(parentCtx context.Context) error {
	lg.lock.Lock()
	// Without this unlock the group would stay locked forever: every later Add
	// would fail and a second Run would deadlock.
	defer lg.lock.Unlock()

	ctx, cancel := context.WithCancel(parentCtx)
	defer cancel()

	// Buffered to len(fns) so no goroutine ever blocks on its send.
	errCh := make(chan error, len(lg.fns))
	var wg sync.WaitGroup
	for _, fn := range lg.fns {
		wg.Add(1)
		go func(fn errWrapper) {
			// Registered first so it runs last: the slot is already released
			// by the time wg.Wait observes this goroutine as done.
			defer wg.Done()
			slot := lg.sema.AcquireBlocking()
			defer slot.Release()
			errCh <- fn(ctx)
		}(fn)
	}
	wg.Wait()
	close(errCh)

	errs := make([]error, 0, len(lg.fns))
	for err := range errCh {
		errs = append(errs, err)
	}
	return errors.Join(errs...)
}
package goro
import (
"context"
"crypto/rand"
"fmt"
"math/big"
"sync/atomic"
"testing"
"time"
)
// TestLimit checks that a LimitGroup holding a single succeeding function can
// be populated and run without reporting an error.
func TestLimit(t *testing.T) {
	ctx := context.Background()
	limit := NewLimitGroup(1)
	// Add can fail; check it so a registration problem is not misreported as a
	// Run problem (or silently missed).
	if err := limit.Add(func(ctx context.Context) error {
		return nil
	}); err != nil {
		t.Fatalf("Add returned an unexpected error: %v", err)
	}
	if err := limit.Run(ctx); err != nil {
		t.Fatalf("Run returned an unexpected error: %v", err)
	}
}
// TestLimitSingleConcurrent registers three functions on a LimitGroup with a
// limit of 1 and fails if any two of them are observed executing at the same
// time.
func TestLimitSingleConcurrent(t *testing.T) {
	ctx := context.Background()
	var executing atomic.Bool
	nestedWork := func(_ context.Context) error {
		// CompareAndSwap makes the check-and-mark a single atomic step. A
		// separate Load followed by Store would let two overlapping calls both
		// see "false" and both proceed undetected.
		if !executing.CompareAndSwap(false, true) {
			return fmt.Errorf("concurrent execution detected")
		}
		defer executing.Store(false)
		sleepTime, err := rand.Int(rand.Reader, big.NewInt(5))
		if err != nil {
			return fmt.Errorf("could not get random int: %w", err)
		}
		// Hold the "executing" flag for 0-4ms to give overlaps a chance to show.
		time.Sleep(time.Duration(sleepTime.Int64()) * time.Millisecond)
		return nil
	}
	limit := NewLimitGroup(1)
	for i := 1; i <= 3; i++ {
		run := i // per-iteration copy for the closure (needed before Go 1.22)
		err := limit.Add(func(ctx context.Context) error {
			t.Logf("run %d", run)
			return nestedWork(ctx)
		})
		if err != nil {
			t.Fatalf("Add #%d returned an unexpected error: %v", run, err)
		}
	}
	if err := limit.Run(ctx); err != nil {
		t.Error(err)
	}
}
package goro
import (
"errors"
"sync"
"github.com/sirupsen/logrus"
)
// Sentinel errors returned by the semaphore and its slots.
var (
// ErrNoSlotsAvailable is returned by SlottedSemaphore.Acquire when no slot is free.
ErrNoSlotsAvailable = errors.New("no slots available")
// ErrSlotAlreadyReleased is returned when releasing a slot that is no longer held.
ErrSlotAlreadyReleased = errors.New("slot already released")
)
// slot is a handle for one held position in a SlottedSemaphore.
type slot struct {
// parentSema is the semaphore that issued this slot. slot.Release sets it to
// nil, and IsReleased reports a nil value as "released".
parentSema *SlottedSemaphore
}
// Release hands the slot back to the semaphore that issued it.
//
// It returns ErrSlotAlreadyReleased when this slot has already been released.
// If the parent semaphore rejects the release, the error is logged and
// returned to the caller and the slot is left unchanged.
func (s *slot) Release() error {
	if s.parentSema == nil {
		logrus.Errorf("slot already released")
		return ErrSlotAlreadyReleased
	}
	if err := s.parentSema.Release(s); err != nil {
		// Report the failure instead of calling logrus.Fatalf: a failed release
		// is recoverable for the caller and must not terminate the process.
		logrus.Errorf("failed to release slot: %s", err.Error())
		return err
	}
	logrus.Tracef("released slot %v", s)
	s.parentSema = nil
	return nil
}
// IsReleased reports whether the slot no longer references a parent
// semaphore, which is the state Release leaves it in.
//
// NOTE(review): the field is read here without taking any lock.
func (s *slot) IsReleased() bool {
	released := s.parentSema == nil
	return released
}
// SlottedSemaphore hands out at most len(sema) slots at a time.
type SlottedSemaphore struct {
// sema has one entry per slot; a nil entry is free, a non-nil entry is held.
sema []*slot
// lock guards sema and is the Locker used by acquisitionCond.
lock *sync.Mutex
// acquisitionCond is waited on by AcquireBlocking and signalled by Release.
acquisitionCond *sync.Cond
}
// NewSlottedSemaphore creates a semaphore with limit slots, all initially
// free. The condition variable shares the semaphore's mutex.
func NewSlottedSemaphore(limit int) *SlottedSemaphore {
	ss := &SlottedSemaphore{
		sema: make([]*slot, limit),
		lock: new(sync.Mutex),
	}
	ss.acquisitionCond = sync.NewCond(ss.lock)
	return ss
}
// findFreeSlot returns the index of the first nil entry in ss.sema, or -1 if
// every entry is occupied. It takes no lock itself; the callers in this file
// hold ss.lock while calling it.
func (ss *SlottedSemaphore) findFreeSlot() int {
	for idx := 0; idx < len(ss.sema); idx++ {
		if ss.sema[idx] != nil {
			continue
		}
		logrus.Tracef("first free slot at %d", idx)
		return idx
	}
	return -1
}
// Acquire claims a free slot without waiting. It returns
// ErrNoSlotsAvailable (and a nil slot) when all slots are held.
func (ss *SlottedSemaphore) Acquire() (*slot, error) {
	ss.lock.Lock()
	defer ss.lock.Unlock()

	idx := ss.findFreeSlot()
	if idx < 0 {
		return nil, ErrNoSlotsAvailable
	}

	claimed := &slot{parentSema: ss}
	ss.sema[idx] = claimed
	logrus.Tracef("slot %d acquired", idx)
	return claimed, nil
}
// AcquireBlocking claims a slot, waiting on the semaphore's condition
// variable until one is free.
func (ss *SlottedSemaphore) AcquireBlocking() *slot {
	ss.lock.Lock()
	defer ss.lock.Unlock()

	idx := ss.findFreeSlot()
	for idx == -1 {
		logrus.Trace("no slots available, sleeping for release")
		// Wait releases ss.lock while parked and re-acquires it before
		// returning, so the re-check below runs under the lock.
		ss.acquisitionCond.Wait()
		idx = ss.findFreeSlot()
	}

	claimed := &slot{parentSema: ss}
	ss.sema[idx] = claimed
	logrus.Tracef("slot %d acquired (blocking)", idx)
	return claimed
}
// Release frees the entry currently occupied by s and signals one goroutine
// waiting in AcquireBlocking. It returns ErrSlotAlreadyReleased when s is not
// found among the semaphore's entries.
func (ss *SlottedSemaphore) Release(s *slot) error {
	ss.lock.Lock()
	defer ss.lock.Unlock()

	for idx := range ss.sema {
		if ss.sema[idx] != s {
			continue
		}
		logrus.Tracef("releasing slot at index %d", idx)
		ss.sema[idx] = nil
		ss.acquisitionCond.Signal()
		return nil
	}
	return ErrSlotAlreadyReleased
}
package goro
import (
"testing"
"time"
)
// TestSlottedSemaphore starts two goroutines competing for a one-slot
// semaphore and checks that both eventually acquire it and record their
// result.
func TestSlottedSemaphore(t *testing.T) {
	sem := NewSlottedSemaphore(1)
	var (
		results  = make([]int, 2)
		finished = make(chan bool)
	)

	// worker acquires a slot, optionally holds it for hold, stores id in its
	// own results entry and then signals completion.
	worker := func(id int, hold time.Duration) {
		t.Logf("Acquiring slot %d", id)
		held := sem.AcquireBlocking()
		t.Logf("Acquired slot %d: %v", id, held)
		defer held.Release()
		if hold > 0 {
			time.Sleep(hold)
		}
		results[id-1] = id
		finished <- true
	}

	go worker(1, 1*time.Second)
	go worker(2, 0)

	// Wait for both workers before inspecting results.
	<-finished
	<-finished

	if results[0] != 1 || results[1] != 2 {
		t.Fail()
	}
}
// TestSlottedSemaphoreDoubleRelease checks that releasing a slot once
// succeeds and that releasing it a second time yields ErrSlotAlreadyReleased.
func TestSlottedSemaphoreDoubleRelease(t *testing.T) {
	ss := NewSlottedSemaphore(1)
	slot := ss.AcquireBlocking()
	// The first release must succeed; checking it keeps a failure here from
	// being misattributed to the double-release path below.
	if err := slot.Release(); err != nil {
		t.Fatalf("first Release returned an unexpected error: %v", err)
	}
	if err := slot.Release(); err != ErrSlotAlreadyReleased {
		t.Fatalf("second Release returned %v, want %v", err, ErrSlotAlreadyReleased)
	}
}
// TestSlottedSemaphoreAsyncAcquire acquires a slot without blocking, releases
// it from another goroutine after a delay, and checks that the slot then
// reports itself as released.
func TestSlottedSemaphoreAsyncAcquire(t *testing.T) {
	ss := NewSlottedSemaphore(1)
	slot, err := ss.Acquire()
	if err != nil {
		// Stop here: slot is nil on error and using it below would panic.
		t.Fatalf("Acquire returned an unexpected error: %v", err)
	}

	// The release result travels over a channel. Receiving it synchronizes
	// with the releasing goroutine, which replaces a busy loop that read the
	// slot's state without any synchronization.
	released := make(chan error, 1)
	go func() {
		time.Sleep(1 * time.Second)
		released <- slot.Release()
	}()

	select {
	case err := <-released:
		if err != nil {
			t.Fatalf("Release returned an unexpected error: %v", err)
		}
	case <-time.After(5 * time.Second):
		t.Fatal("timed out waiting for the slot to be released")
	}

	if !slot.IsReleased() {
		t.Fatal("slot does not report released after Release returned")
	}
}
// TestSlottedSemaphoreGoroWait holds the only slot of a semaphore while a
// second goroutine repeatedly calls the non-blocking Acquire, which is
// expected to fail on every attempt.
func TestSlottedSemaphoreGoroWait(t *testing.T) {
	sem := NewSlottedSemaphore(1)
	held := sem.AcquireBlocking()
	finished := make(chan bool)

	go func() {
		const attempts = 10
		for remaining := attempts; remaining > 0; remaining-- {
			if _, err := sem.Acquire(); err == nil {
				t.Errorf("should not be able to acquire a slot")
			}
			t.Log("Waiting for slot")
			time.Sleep(500 * time.Millisecond)
		}
		finished <- true
	}()

	// Release only after the prober has finished all of its attempts.
	<-finished
	held.Release()
	t.Log("Waiting goro closed")
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment