// Package batchgetter implements a request batcher: concurrent Get calls
// that arrive within a short window are coalesced into a single,
// deduplicated call to an underlying Getter.
//
// Originally published by @CAFxX as GitHub gist
// a6fca31790e0dcc773390c1faf2e9f86 ("Request batcher", last active July 27, 2023).
package batchgetter

import (
	"context"
	"fmt"
	"sync"
	"time"
)
// Getter fetches the values associated with a list of ids.
//
// NOTE(review): BatchGetter treats the returned slice as positionally aligned
// with ids (the i-th value belongs to ids[i]) — confirm implementations honor this.
type Getter[I, T any] interface {
	Get(ctx context.Context, ids []I) ([]T, error)
}
// BatchGetter is a Getter that coalesces concurrent Get calls into batches.
//
// The first Get call while no batch is pending starts a timer of batchWait.
// Every id requested before that timer fires joins the same batch. A single
// deduplicated request is then issued to parent. Each caller receives the
// values for its own ids, in the order it requested them. A non-positive
// batchWait disables batching and forwards calls directly to parent.
//
// Use New to construct a BatchGetter; the zero value is not usable.
type BatchGetter[I comparable, T any] struct {
	parent    Getter[I, T]
	batchWait time.Duration

	mu         sync.Mutex        // guards the fields below
	ctx        []context.Context // contexts of the callers in the pending batch
	batch      []I               // ids requested in the pending batch (may contain duplicates)
	batchTimer *time.Timer       // non-nil while a batch is pending
	resCh      chan struct{}     // closed once the pending batch's result is ready
	res        *result[I, T]     // outcome of the pending batch
}

// New returns a BatchGetter that batches requests to parent. It dispatches
// each batch batchWait after the batch's first request.
func New[I comparable, T any](parent Getter[I, T], batchWait time.Duration) *BatchGetter[I, T] {
	return &BatchGetter[I, T]{parent: parent, batchWait: batchWait}
}
// Compile-time check that *BatchGetter implements Getter. Type parameters are
// not in scope at package level, so arbitrary concrete type arguments are used.
var _ Getter[string, int] = (*BatchGetter[string, int])(nil)
// result holds the outcome of one dispatched batch. Its fields are written by
// BatchGetter.get and read by waiting Get callers only after the batch's
// resCh has been closed.
type result[I comparable, T any] struct {
	m   map[I]T // value for each distinct id in the batch
	err error   // error (or recovered panic) from the parent call
}
// Get returns the values for id, in the same order as id.
//
// If batching is disabled (batchWait <= 0), the call is forwarded to the
// parent unchanged. Otherwise the ids are added to the pending batch,
// starting a new batch and its timer if none is pending. Get then blocks
// until that batch has been fetched or ctx is done, whichever comes first.
// If the batch fails, every caller in it receives the same error.
func (g *BatchGetter[I, T]) Get(ctx context.Context, id []I) ([]T, error) {
	if g.batchWait <= 0 {
		return g.parent.Get(ctx, id)
	}
	if len(id) == 0 {
		// Nothing to fetch: don't start or join a batch for it.
		return []T{}, nil
	}
	if err := ctx.Err(); err != nil {
		// The caller has already given up; don't add its ids to the batch.
		return nil, err
	}

	g.mu.Lock()
	if g.batchTimer == nil {
		// First request of a new batch: allocate its result and schedule dispatch.
		g.resCh = make(chan struct{})
		g.res = new(result[I, T])
		g.batchTimer = time.AfterFunc(g.batchWait, g.get)
	}
	g.batch = append(g.batch, id...)
	g.ctx = append(g.ctx, ctx)
	res, resCh := g.res, g.resCh
	g.mu.Unlock()

	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-resCh:
	}
	// resCh is closed only after get has finished writing res, so res can be
	// read here without holding the lock.
	if res.err != nil {
		return nil, res.err
	}
	r := make([]T, 0, len(id))
	for _, e := range id {
		r = append(r, res.m[e])
	}
	return r, nil
}
// get dispatches the pending batch. It is scheduled by Get through
// time.AfterFunc, so it runs in its own goroutine.
//
// Steps:
//   - Detach the pending batch from g, so that new Get calls start a fresh one.
//   - Deduplicate the requested ids.
//   - Issue a single parent.Get call under a context that stays alive while
//     any caller's context does.
//   - Publish the outcome by closing the batch's resCh.
//
// A panic in the parent is recovered and reported to all callers as an error.
func (g *BatchGetter[I, T]) get() {
	g.mu.Lock()
	batch, ctxs, res, resCh := g.batch, g.ctx, g.res, g.resCh
	// Drop the references rather than re-slicing. This keeps the next batch
	// from sharing this batch's backing arrays or keeping them alive.
	g.batchTimer, g.batch, g.ctx, g.res, g.resCh = nil, nil, nil, nil, nil
	g.mu.Unlock()

	// Deferred calls run LIFO: the recover below records any panic in res
	// before resCh is closed and waiters read it.
	defer close(resCh)
	defer func() {
		if r := recover(); r != nil {
			res.m = nil
			res.err = fmt.Errorf("panic: %v", r)
		}
	}()

	// Deduplicate the requested ids; every distinct id starts at the zero value.
	// TODO: filter out entries for which the context has already expired.
	var zero T
	res.m = make(map[I]T, len(batch))
	for _, e := range batch {
		res.m[e] = zero
	}
	ids := make([]I, 0, len(res.m))
	for k := range res.m {
		ids = append(ids, k)
	}

	actx, cancel := anyCtx(ctxs)
	if cancel != nil {
		defer cancel()
	}
	vals, err := g.parent.Get(actx, ids)
	if err != nil {
		res.err = err
		return
	}
	if len(vals) != len(ids) {
		res.err = fmt.Errorf("batchgetter: parent returned %d values for %d ids", len(vals), len(ids))
		return
	}
	// The parent's results are positional: vals[i] is the value for ids[i].
	for i, v := range vals {
		res.m[ids[i]] = v
	}
}
// anyCtx returns a context that is done only once every context in ctxs is
// done. In other words, it stays alive while any of them is still alive.
//
// Return values by case:
//   - ctxs is empty, or one of them can never be canceled (its Done returns
//     nil): context.Background() and a nil cancel func.
//   - ctxs has a single context: that context, unchanged, and a nil cancel func.
//   - Otherwise: a new context derived from context.Background(), so it
//     carries none of the values or deadlines of ctxs. The caller must call
//     the returned cancel func when finished, to stop the watcher goroutine.
//
// anyCtx panics if any element of ctxs is nil.
func anyCtx(ctxs []context.Context) (context.Context, func()) {
	for _, ctx := range ctxs {
		if ctx == nil {
			panic("nil context")
		}
	}
	if len(ctxs) == 0 {
		return context.Background(), nil
	}
	for _, ctx := range ctxs {
		if ctx.Done() == nil {
			return context.Background(), nil
		}
	}
	if len(ctxs) == 1 {
		return ctxs[0], nil
	}
	actx, cancel := context.WithCancel(context.Background())
	go func() {
		defer cancel()
		for _, ctx := range ctxs {
			select {
			case <-actx.Done():
				// Canceled by the caller: stop watching. A bare `break` here
				// would only exit the select, leaving the goroutine blocked on
				// the remaining contexts.
				return
			case <-ctx.Done():
			}
		}
	}()
	return actx, cancel
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment