Last active: October 20, 2023 22:11
-
-
Save romsar/cb34ef0a03bda696e1376e01521c72a8 to your computer and use it in GitHub Desktop.
Graceful shutdown
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package graceful | |
import ( | |
"context" | |
"errors" | |
"log" | |
"strconv" | |
"time" | |
"golang.org/x/sync/errgroup" | |
) | |
var (
	// ErrShutdownTimeout reports that a component's shutdown function did not
	// return within the timeout configured with WithShutdownTimeout.
	ErrShutdownTimeout = errors.New("shutdown timeout")
)
type Graceful struct { | |
ctx context.Context | |
cancel context.CancelFunc | |
errGrp *errgroup.Group | |
shutdownTimeout time.Duration | |
logger logger | |
} | |
// logger is the minimal printf-style logging dependency used by Graceful.
// *log.Logger satisfies it; New uses one when no logger is configured.
type logger interface {
	Printf(format string, v ...interface{})
}
type Option func(*Graceful) | |
func WithLogger(logger logger) Option { | |
return func(s *Graceful) { | |
s.logger = logger | |
} | |
} | |
func WithShutdownTimeout(timeout time.Duration) Option { | |
return func(s *Graceful) { | |
s.shutdownTimeout = timeout | |
} | |
} | |
func New(ctx context.Context, cancel context.CancelFunc, opts ...Option) *Graceful { | |
errGrp, errCtx := errgroup.WithContext(ctx) | |
g := &Graceful{ | |
ctx: errCtx, | |
errGrp: errGrp, | |
cancel: cancel, | |
} | |
for _, opt := range opts { | |
opt(g) | |
} | |
if g.logger == nil { | |
g.logger = log.New(log.Writer(), log.Prefix(), log.Flags()) | |
} | |
return g | |
} | |
// GoRun holds the settings for one Graceful.Go call. It is filled in by
// GoRunOption values such as WithWork and SetShutdownFunc.
type GoRun struct {
	// workFuncs are the component's work functions, each started in its
	// own errgroup goroutine by Graceful.Go.
	workFuncs []func() error
	// shutdownFunc, when non-nil, is called once the group's context is done.
	shutdownFunc func() error
	// cancelOnFinish makes a work goroutine call the Graceful cancel
	// function when its work function returns.
	cancelOnFinish bool
}
type GoRunOption func(*GoRun) | |
func WithWork(f func() error) GoRunOption { | |
return func(run *GoRun) { | |
run.workFuncs = append(run.workFuncs, f) | |
} | |
} | |
func SetShutdownFunc(f func() error) GoRunOption { | |
return func(run *GoRun) { | |
run.shutdownFunc = f | |
} | |
} | |
func (g *Graceful) Go(name string, opts ...GoRunOption) { | |
run := GoRun{ | |
cancelOnFinish: true, | |
} | |
for _, opt := range opts { | |
opt(&run) | |
} | |
for i, f := range run.workFuncs { | |
f, name := f, name | |
if len(run.workFuncs) > 1 { | |
name = name + " #" + strconv.Itoa(i+1) | |
} | |
g.errGrp.Go(func() error { | |
if run.cancelOnFinish { | |
defer g.cancel() | |
} | |
g.logger.Printf("[%s]: start", name) | |
workCh := make(chan error, 1) | |
go func() { | |
select { | |
case <-g.ctx.Done(): | |
workCh <- g.ctx.Err() | |
case workCh <- f(): | |
} | |
}() | |
select { | |
case err := <-workCh: | |
if err != nil && !errors.Is(err, context.Canceled) { | |
g.logger.Printf("[%s]: %s", name, err) | |
return err | |
} | |
g.logger.Printf("[%s]: process finished", name) | |
case <-g.ctx.Done(): | |
} | |
return nil | |
}) | |
} | |
g.errGrp.Go(func() error { | |
<-g.ctx.Done() | |
g.logger.Printf("[%s]: stopping", name) | |
if run.shutdownFunc != nil { | |
shutdownCh := make(chan error, 1) | |
go func() { | |
shutdownCh <- run.shutdownFunc() | |
close(shutdownCh) | |
}() | |
var err error | |
if g.shutdownTimeout > 0 { | |
timer := time.NewTimer(g.shutdownTimeout) | |
defer timer.Stop() | |
select { | |
case <-timer.C: | |
g.logger.Printf("[%s]: shutdown timeout", name) | |
err = ErrShutdownTimeout | |
case err = <-shutdownCh: | |
} | |
} else { | |
err = <-shutdownCh | |
} | |
if err != nil { | |
g.logger.Printf("[%s]: graceful stop error", name) | |
return nil | |
} | |
} | |
g.logger.Printf("[%s]: graceful stopped", name) | |
return nil | |
}) | |
} | |
func (g *Graceful) Wait() error { | |
return g.errGrp.Wait() | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.