Skip to content

Instantly share code, notes, and snippets.

@romsar
Last active October 20, 2023 22:11
Show Gist options
  • Save romsar/cb34ef0a03bda696e1376e01521c72a8 to your computer and use it in GitHub Desktop.
Save romsar/cb34ef0a03bda696e1376e01521c72a8 to your computer and use it in GitHub Desktop.
Graceful shutdown
package graceful
import (
"context"
"errors"
"log"
"strconv"
"time"
"golang.org/x/sync/errgroup"
)
var ErrShutdownTimeout = errors.New("shutdown timeout")
// Graceful coordinates a set of named components that start together and are
// stopped together. Construct it with New; the zero value is not usable.
type Graceful struct {
	// ctx is the context derived by errgroup.WithContext in New; it is done
	// when the caller's context ends or any group goroutine returns an error.
	ctx context.Context
	// cancel is the CancelFunc handed to New; Go invokes it when a work
	// function finishes (presumably cancelling ctx's parent — confirm at caller).
	cancel context.CancelFunc
	// errGrp runs every work and shutdown goroutine started by Go.
	errGrp *errgroup.Group
	// shutdownTimeout bounds each shutdown function; <= 0 means no limit.
	shutdownTimeout time.Duration
	// logger receives progress and error messages.
	logger logger
}
// logger is the minimal logging dependency Graceful needs; *log.Logger
// satisfies it.
type logger interface {
	Printf(format string, args ...interface{})
}
// Option customizes a Graceful during New.
type Option func(*Graceful)

// WithLogger sets the logger used for progress and error messages. If the
// resulting logger is nil, New substitutes one mirroring the standard logger.
func WithLogger(l logger) Option {
	setLogger := func(g *Graceful) {
		g.logger = l
	}
	return setLogger
}

// WithShutdownTimeout limits how long each component's shutdown function may
// run once stopping begins. A non-positive duration (the default) means the
// shutdown function is awaited with no deadline.
func WithShutdownTimeout(timeout time.Duration) Option {
	setTimeout := func(g *Graceful) {
		g.shutdownTimeout = timeout
	}
	return setTimeout
}
// New returns a Graceful whose goroutines run in an errgroup derived from ctx.
//
// cancel is stored and invoked by Go whenever a work function finishes; it
// is presumably the CancelFunc of ctx (or an ancestor) so that one finished
// component stops the rest — TODO confirm against callers. Options are
// applied in order; if none supplies a logger, a logger copying the standard
// logger's current output, prefix and flags is used.
func New(ctx context.Context, cancel context.CancelFunc, opts ...Option) *Graceful {
	grp, grpCtx := errgroup.WithContext(ctx)

	g := new(Graceful)
	g.ctx = grpCtx
	g.cancel = cancel
	g.errGrp = grp

	for _, apply := range opts {
		apply(g)
	}

	if g.logger != nil {
		return g
	}
	// Snapshot the std logger's configuration at construction time.
	g.logger = log.New(log.Writer(), log.Prefix(), log.Flags())
	return g
}
// GoRun collects the per-component settings passed to Graceful.Go.
type GoRun struct {
	// workFuncs are run concurrently, one goroutine each.
	workFuncs []func() error
	// shutdownFunc, if non-nil, runs once the group context is done.
	shutdownFunc func() error
	// cancelOnFinish makes a returning work function invoke the group's cancel.
	cancelOnFinish bool
}
// GoRunOption configures a single component registered with Graceful.Go.
type GoRunOption func(*GoRun)

// WithWork adds a work function to the component. It may be given several
// times; each function gets its own goroutine.
func WithWork(f func() error) GoRunOption {
	addWork := func(r *GoRun) {
		r.workFuncs = append(r.workFuncs, f)
	}
	return addWork
}

// SetShutdownFunc sets the function run when the component is told to stop.
// A later call replaces an earlier one.
func SetShutdownFunc(f func() error) GoRunOption {
	setShutdown := func(r *GoRun) {
		r.shutdownFunc = f
	}
	return setShutdown
}
// Go registers a named component and starts its goroutines in the group.
//
// Every function added with WithWork runs in its own goroutine. When there is
// more than one, each is logged as "name #N". When a work function returns,
// the cancel passed to New is invoked, since cancelOnFinish is always true
// here. A work error other than context.Canceled is logged and returned, which
// makes errgroup cancel the group context and report it from Wait.
//
// One more goroutine waits for the group context to be done and then runs the
// function set with SetShutdownFunc. That run is bounded by
// WithShutdownTimeout when the timeout is positive. If the shutdown function
// fails or times out, the error is logged and returned, so Wait reports it.
// A timeout yields ErrShutdownTimeout.
func (g *Graceful) Go(name string, opts ...GoRunOption) {
	run := GoRun{cancelOnFinish: true}
	for _, opt := range opts {
		opt(&run)
	}

	for i, f := range run.workFuncs {
		f := f // per-iteration copy for pre-Go 1.22 loop semantics
		workName := name
		if len(run.workFuncs) > 1 {
			workName = name + " #" + strconv.Itoa(i+1)
		}
		g.errGrp.Go(func() error {
			return g.runWork(workName, f, run.cancelOnFinish)
		})
	}

	shutdown := run.shutdownFunc
	g.errGrp.Go(func() error {
		return g.runShutdown(name, shutdown)
	})
}

// runWork runs a single work function until it returns or the group context
// is done, whichever comes first.
func (g *Graceful) runWork(name string, f func() error, cancelOnFinish bool) error {
	if cancelOnFinish {
		defer g.cancel()
	}
	g.logger.Printf("[%s]: start", name)

	// f does not take a context, so run it separately and race it against
	// g.ctx. The buffer lets that goroutine deliver its result and exit even
	// after we have stopped listening.
	workCh := make(chan error, 1)
	go func() { workCh <- f() }()

	select {
	case err := <-workCh:
		if err != nil && !errors.Is(err, context.Canceled) {
			g.logger.Printf("[%s]: %s", name, err)
			return err
		}
		g.logger.Printf("[%s]: process finished", name)
	case <-g.ctx.Done():
	}
	return nil
}

// runShutdown waits for the group context to be done, then runs the optional
// shutdown function and returns its error (if any) to the group.
func (g *Graceful) runShutdown(name string, shutdown func() error) error {
	<-g.ctx.Done()
	g.logger.Printf("[%s]: stopping", name)

	if shutdown != nil {
		if err := g.awaitShutdown(name, shutdown); err != nil {
			g.logger.Printf("[%s]: graceful stop error: %s", name, err)
			return err
		}
	}
	g.logger.Printf("[%s]: graceful stopped", name)
	return nil
}

// awaitShutdown calls shutdown and waits for it, giving up with
// ErrShutdownTimeout once g.shutdownTimeout elapses (if positive). On timeout,
// shutdown keeps running in the background; it cannot be interrupted because
// it takes no context.
func (g *Graceful) awaitShutdown(name string, shutdown func() error) error {
	doneCh := make(chan error, 1)
	go func() { doneCh <- shutdown() }()

	if g.shutdownTimeout <= 0 {
		return <-doneCh
	}

	timer := time.NewTimer(g.shutdownTimeout)
	defer timer.Stop()
	select {
	case err := <-doneCh:
		return err
	case <-timer.C:
		g.logger.Printf("[%s]: shutdown timeout", name)
		return ErrShutdownTimeout
	}
}
// Wait blocks until every goroutine started through Go has returned and then
// returns the first non-nil error any of them produced (errgroup semantics),
// or nil.
func (g *Graceful) Wait() error {
	err := g.errGrp.Wait()
	return err
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment