Skip to content

Instantly share code, notes, and snippets.

@russjones russjones/restart.go Secret
Last active Aug 27, 2019

Embed
What would you like to do?
Gracefully Restarting a Go Program
package main
import (
	"context"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"net"
	"net/http"
	"os"
	"os/signal"
	"path/filepath"
	"strings"
	"syscall"
	"time"
)
// listener is the metadata a parent process hands to a forked child so the
// child can rebuild the inherited listening socket. forkChild marshals it to
// JSON into the LISTENER environment variable; importListener decodes it.
type listener struct {
// Addr is the address the listener was created for. importListener only
// adopts the socket when this equals the address it was asked for.
Addr string `json:"addr"`
// FD is the descriptor number under which the socket is available in the
// child process (forkChild always passes it as 3).
FD int `json:"fd"`
// Filename is the name given to the *os.File rebuilt around FD.
Filename string `json:"filename"`
}
// importListener rebuilds a net.Listener from a socket inherited from a parent
// process (see forkChild). The parent describes the socket as JSON in the
// LISTENER environment variable: the address it was bound for, the descriptor
// number in this process, and the name to give the wrapping *os.File.
//
// It returns an error if LISTENER is unset or empty, is not valid JSON, was
// recorded for an address other than addr, names an invalid descriptor, or
// the descriptor cannot be turned into a listener.
func importListener(addr string) (net.Listener, error) {
	listenerEnv := os.Getenv("LISTENER")
	if listenerEnv == "" {
		return nil, fmt.Errorf("unable to find LISTENER environment variable")
	}

	var l listener
	if err := json.Unmarshal([]byte(listenerEnv), &l); err != nil {
		return nil, fmt.Errorf("decoding LISTENER environment variable: %w", err)
	}
	if l.Addr != addr {
		return nil, fmt.Errorf("unable to find listener for %v", addr)
	}

	// The descriptor is already open in this process; wrap it in an *os.File
	// so net.FileListener can use it. os.NewFile signals an invalid
	// descriptor only by returning nil, so report the descriptor itself
	// rather than an unrelated (nil) error value.
	listenerFile := os.NewFile(uintptr(l.FD), l.Filename)
	if listenerFile == nil {
		return nil, fmt.Errorf("unable to create listener file: invalid file descriptor %d", l.FD)
	}
	// net.FileListener works on its own copy of the descriptor, so this
	// handle can be released once the listener has been built.
	defer listenerFile.Close()

	ln, err := net.FileListener(listenerFile)
	if err != nil {
		return nil, fmt.Errorf("creating listener from %s: %w", l.Filename, err)
	}
	return ln, nil
}
// createListener binds a brand-new TCP listening socket on addr. Any error
// from net.Listen is passed straight through, with a nil listener.
func createListener(addr string) (net.Listener, error) {
	return net.Listen("tcp", addr)
}
// createOrImportListener returns a listener for addr. It first tries to adopt
// a socket handed down by a parent process via importListener. When that is
// not possible, for any reason, it binds a fresh socket with createListener.
// A line reporting which path was taken is printed to stdout.
func createOrImportListener(addr string) (net.Listener, error) {
	if inherited, importErr := importListener(addr); importErr == nil {
		fmt.Printf("Imported listener file descriptor for %v.\n", addr)
		return inherited, nil
	}

	// Nothing usable was inherited, so this process binds the address itself.
	fresh, err := createListener(addr)
	if err != nil {
		return nil, err
	}
	fmt.Printf("Created listener file descriptor for %v.\n", addr)
	return fresh, nil
}
// getListenerFile returns the *os.File for the socket behind ln, obtained via
// the listener's File method; the caller is responsible for closing it. Only
// *net.TCPListener and *net.UnixListener are supported; any other value,
// including nil, yields an "unsupported listener" error.
func getListenerFile(ln net.Listener) (*os.File, error) {
	if tl, isTCP := ln.(*net.TCPListener); isTCP {
		return tl.File()
	}
	if ul, isUnix := ln.(*net.UnixListener); isUnix {
		return ul.File()
	}
	return nil, fmt.Errorf("unsupported listener: %T", ln)
}
// forkChild starts a new copy of the running executable and hands it ln's
// socket so the child can keep serving on addr.
//
// The socket is passed as the child's file descriptor 3, after stdin, stdout
// and stderr. It is described in the LISTENER environment variable as a
// JSON-encoded listener value, which importListener decodes. The child gets
// the same command-line arguments as this process, so flags such as -addr
// carry over. The child's working directory is the executable's directory.
// The started process is returned; waiting for it is the caller's choice.
func forkChild(addr string, ln net.Listener) (*os.Process, error) {
	lnFile, err := getListenerFile(ln)
	if err != nil {
		return nil, fmt.Errorf("getting listener file: %w", err)
	}
	// Once StartProcess has returned, the child holds its own descriptor, so
	// this process's handle can be closed.
	defer lnFile.Close()

	listenerEnv, err := json.Marshal(listener{
		Addr:     addr,
		FD:       3, // index of lnFile in files below
		Filename: lnFile.Name(),
	})
	if err != nil {
		return nil, fmt.Errorf("encoding listener metadata: %w", err)
	}

	// files[i] becomes descriptor i in the child.
	files := []*os.File{
		os.Stdin,
		os.Stdout,
		os.Stderr,
		lnFile,
	}

	// Copy the environment minus any LISTENER entry inherited from our own
	// parent. Go resolves a duplicated environment key to its first
	// occurrence, so a stale entry would otherwise shadow the one appended
	// here.
	const listenerPrefix = "LISTENER="
	inherited := os.Environ()
	environment := make([]string, 0, len(inherited)+1)
	for _, kv := range inherited {
		if !strings.HasPrefix(kv, listenerPrefix) {
			environment = append(environment, kv)
		}
	}
	environment = append(environment, listenerPrefix+string(listenerEnv))

	execName, err := os.Executable()
	if err != nil {
		return nil, fmt.Errorf("locating executable: %w", err)
	}
	execDir := filepath.Dir(execName)

	// Forward this process's arguments so the child is configured identically.
	// In particular, the child must see the same -addr, or importListener
	// rejects the inherited socket.
	argv := append([]string{execName}, os.Args[1:]...)

	p, err := os.StartProcess(execName, argv, &os.ProcAttr{
		Dir:   execDir,
		Env:   environment,
		Files: files,
		Sys:   &syscall.SysProcAttr{},
	})
	if err != nil {
		return nil, fmt.Errorf("starting child process: %w", err)
	}
	return p, nil
}
// waitForSignals blocks until a signal tells the process to stop. It then
// returns the result of shutting server down.
//
//   - SIGHUP: fork a child that inherits ln, then shut this server down
//     (hand-off restart). If the fork fails, the process keeps running.
//   - SIGUSR2: fork a child that inherits ln and keep running, so parent and
//     child both serve on the same socket.
//   - SIGINT, SIGQUIT: shut down without forking.
//
// Shutdown gets 5 seconds to finish before its context expires.
func waitForSignals(addr string, ln net.Listener, server *http.Server) error {
	const shutdownTimeout = 5 * time.Second

	signalCh := make(chan os.Signal, 1024)
	signal.Notify(signalCh, syscall.SIGHUP, syscall.SIGUSR2, syscall.SIGINT, syscall.SIGQUIT)
	// Stop redirecting these signals to signalCh once we return.
	defer signal.Stop(signalCh)

	// shutdown stops server, bounded by shutdownTimeout.
	shutdown := func() error {
		ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
		defer cancel()
		return server.Shutdown(ctx)
	}

	for {
		s := <-signalCh
		fmt.Printf("%v signal received.\n", s)
		switch s {
		case syscall.SIGHUP:
			p, err := forkChild(addr, ln)
			if err != nil {
				fmt.Printf("Unable to fork child: %v.\n", err)
				continue
			}
			fmt.Printf("Forked child %v.\n", p.Pid)
			return shutdown()
		case syscall.SIGUSR2:
			p, err := forkChild(addr, ln)
			if err != nil {
				fmt.Printf("Unable to fork child: %v.\n", err)
				continue
			}
			fmt.Printf("Forked child %v.\n", p.Pid)
			// This process keeps running. Wait on the child in the
			// background so it is reaped when it exits instead of
			// lingering as a zombie.
			go func() { _, _ = p.Wait() }()
		case syscall.SIGINT, syscall.SIGQUIT:
			return shutdown()
		}
	}
}
// handler answers every request with a greeting containing the PID of the
// serving process. This makes it visible which process handled the request
// after a fork.
func handler(w http.ResponseWriter, r *http.Request) {
	pid := os.Getpid()
	fmt.Fprintf(w, "Hello from %v!\n", pid)
}
// startServer serves HTTP on ln in a background goroutine and returns the
// *http.Server so the caller can shut it down later. Requests to /hello are
// answered by handler.
func startServer(addr string, ln net.Listener) *http.Server {
	// Use a private mux rather than http.DefaultServeMux. Registering the same
	// pattern twice on the global mux panics, so repeated calls would crash.
	mux := http.NewServeMux()
	mux.HandleFunc("/hello", handler)

	httpServer := &http.Server{
		Addr:    addr,
		Handler: mux,
		// Bound how long slow or idle clients can hold connections open.
		ReadHeaderTimeout: 5 * time.Second,
		ReadTimeout:       10 * time.Second,
		WriteTimeout:      10 * time.Second,
		IdleTimeout:       60 * time.Second,
	}
	go func() {
		// Serve returns http.ErrServerClosed after Shutdown, which is the
		// normal way out. Anything else is a real failure worth reporting.
		if err := httpServer.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
			fmt.Printf("HTTP server error: %v.\n", err)
		}
	}()
	return httpServer
}
// main serves HTTP on the -addr address (default ":8080"). It uses a listener
// inherited from a parent process when one is available, then waits for
// signals that either fork a successor or shut the server down. It exits
// with status 1 if the listener cannot be obtained or shutdown fails.
func main() {
	var addr string
	flag.StringVar(&addr, "addr", ":8080", "Address to listen on.")
	// Without Parse, -addr on the command line would be silently ignored.
	flag.Parse()

	ln, err := createOrImportListener(addr)
	if err != nil {
		fmt.Printf("Unable to create or import a listener: %v.\n", err)
		os.Exit(1)
	}
	server := startServer(addr, ln)

	// Block until a signal asks us to exit (forking a successor first on
	// SIGHUP).
	if err := waitForSignals(addr, ln, server); err != nil {
		fmt.Printf("Exiting: %v\n", err)
		// A failed or timed-out shutdown is an abnormal exit.
		os.Exit(1)
	}
	fmt.Printf("Exiting.\n")
}
@Gurpartap

This comment has been minimized.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.