Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
package main
import (
"bytes"
"encoding/gob"
"fmt"
"net"
"os"
"os/exec"
"os/signal"
"strings"
"syscall"
"time"
)
// VERSION tags this build: log prefixes every line with it and writeLoop
// embeds it in each message, so output from an old and a new process can be
// told apart during a hand-off (presumably it is bumped between builds).
const VERSION = 1
// Process-wide state shared by main and the writeLoop goroutines.
//
// NOTE(review): both variables are read and written from several goroutines
// without any synchronization, which is a data race under `go test -race`.
// Fixing it requires changing every user in this file together.
var (
	// pos is the next message number. writeLoop increments it, and it is
	// carried across a hand-off inside the gob-encoded content message.
	pos int
	// isRunning keeps main's signal loop and the writeLoop goroutines going.
	isRunning bool
)
// main runs one generation of a TCP connection hand-off demo.
//
// When PARENT_PID is unset, this is the first generation: it listens on :1122
// and accepts exactly two connections. When PARENT_PID is set (startChild sets
// it), this is a re-executed child: it takes the two connections over from its
// parent through the Unix socket inherited as fd 3. The fds arrive together
// with the gob-encoded content message, which restores pos.
//
// Either way, it then runs a writeLoop on each connection and waits for
// signals:
//   - SIGUSR2 starts a child (startChild) that inherits the read end of this
//     process's socket pair.
//   - SIGINT/SIGTERM stops the write loops, sends both connection fds plus the
//     content message down that socket pair, and returns. A freshly started
//     child sends its parent SIGTERM to trigger this.
//
// NOTE(review): isRunning and pos are shared with the writeLoop goroutines
// without synchronization.
func main() {
	log("Start")
	parentPID := os.Getppid()
	log(fmt.Sprintf("Parent: %d", parentPID))

	var c1, c2 net.Conn
	if os.Getenv("PARENT_PID") != "" {
		// We're the child: take the connections over from the parent.
		c1, c2 = receiveConns(parentPID)
	} else {
		// We're the parent: open the listener and accept two connections.
		var l net.Listener
		l, c1, c2 = acceptConns()
		defer l.Close()
	}
	defer c1.Close()
	defer c2.Close()

	// Must be set before the goroutines start. writeLoop exits as soon as it
	// sees isRunning == false, and a go statement only guarantees visibility
	// of writes made before it.
	isRunning = true
	go writeLoop(c1)
	go writeLoop(c2)

	// Build the domain socket pair to communicate with a child.
	writeConn, readFile := newHandoffPipe()
	defer writeConn.Close()
	defer readFile.Close()

	// signal.Notify does not block when delivering. With an unbuffered channel,
	// a signal that arrives while this loop is busy would be dropped.
	incoming := make(chan os.Signal, 1)
	signal.Notify(incoming,
		syscall.SIGINT,
		syscall.SIGTERM,
		syscall.SIGUSR2,
		os.Interrupt)

	for isRunning {
		sig := <-incoming
		fmt.Println(sig)
		switch sig {
		// SIGKILL cannot be caught, so it is not handled here.
		case syscall.SIGINT, syscall.SIGTERM:
			isRunning = false
			if err := handOffConns(writeConn, c1, c2); err != nil {
				fmt.Println(err)
				return
			}
			// Parent done - quit
			log("Bye")
		case syscall.SIGUSR2:
			startChild(readFile)
		}
	}
	log("End")
}

// receiveConns takes over the two connections held by the parent process.
//
// It expects the parent's end of the hand-off socket pair as inherited fd 3
// (startChild passes it as the first ExtraFiles entry). It asks the parent to
// hand off by sending it SIGTERM. It then reads one message: the payload is
// the gob-encoded content message (decoding it restores pos), and the
// ancillary data must carry exactly two file descriptors. Any failure exits
// the process.
func receiveConns(parentPID int) (net.Conn, net.Conn) {
	// Connect to the domain socket pipe.
	f := os.NewFile(3, "domain socket")
	if _, err := f.Stat(); err != nil {
		log("Domain pipe err")
		fmt.Println(err)
		os.Exit(1)
	}
	defer f.Close()
	netc, err := net.FileConn(f)
	if err != nil {
		log("Domain pipe err")
		fmt.Println(err)
		os.Exit(1)
	}
	defer netc.Close()
	uc, ok := netc.(*net.UnixConn)
	if !ok {
		log("Domain pipe is not UnixConn")
		os.Exit(1)
	}

	// Signal the parent to send us the fds. Without the signal, the read below
	// would block forever.
	parent, err := os.FindProcess(parentPID)
	if err != nil {
		log("FindProcess Err")
		fmt.Println(err)
		os.Exit(1)
	}
	if err := parent.Signal(syscall.SIGTERM); err != nil {
		log("Signal Err")
		fmt.Println(err)
		os.Exit(1)
	}

	// Receive the fds.
	buf := make([]byte, 1024) // connections metadata
	oob := make([]byte, 32)   // one SCM_RIGHTS with two fds is 24 bytes on 64-bit Linux
	n, oobn, _, _, err := uc.ReadMsgUnix(buf, oob)
	if err != nil {
		fmt.Printf("ReadMsgUnix: %v\n", err)
		os.Exit(1)
	}
	scms, err := syscall.ParseSocketControlMessage(oob[:oobn])
	if err != nil {
		fmt.Printf("ParseSocketControlMessage: %v\n", err)
		os.Exit(1)
	}
	if len(scms) != 1 {
		fmt.Printf("expected 1 SocketControlMessage; got scms = %#v\n", scms)
		os.Exit(1)
	}
	gotFds, err := syscall.ParseUnixRights(&scms[0])
	if err != nil {
		fmt.Printf("syscall.ParseUnixRights: %v\n", err)
		os.Exit(1)
	}
	if len(gotFds) != 2 {
		fmt.Printf("wanted 2 fd; got %#v\n", gotFds)
		os.Exit(1)
	}
	// Decode only the bytes actually received.
	contentMsg := decodeContentMessage(buf[:n])
	fmt.Printf("Message received: %s (%d)\n", contentMsg, len(contentMsg))

	// Rebuild the net.Conn(s).
	c1 := fileConn(gotFds[0], "fd-from-parent-1")
	c2 := fileConn(gotFds[1], "fd-from-parent-2")
	fmt.Println("c1.RemoteAddr():", c1.RemoteAddr())
	return c1, c2
}

// fileConn wraps a received raw fd as a net.Conn. net.FileConn duplicates the
// descriptor, so the temporary *os.File (which owns fd) is closed before
// returning. It exits the process on failure.
func fileConn(fd int, name string) net.Conn {
	f := os.NewFile(uintptr(fd), name)
	defer f.Close()
	c, err := net.FileConn(f)
	if err != nil {
		log("FileConn Err")
		fmt.Println(err)
		os.Exit(1)
	}
	return c
}

// acceptConns listens on TCP :1122 and accepts exactly two connections, in
// order. The listener is returned so the caller decides when to close it. Any
// failure exits the process.
func acceptConns() (net.Listener, net.Conn, net.Conn) {
	l, err := net.Listen("tcp", ":1122")
	if err != nil {
		log("Listen Err")
		fmt.Println(err)
		os.Exit(1)
	}
	c1, err := l.Accept()
	if err != nil {
		log("Accept Err")
		fmt.Println(err)
		os.Exit(1)
	}
	fmt.Println("Got first connection")
	c2, err := l.Accept()
	if err != nil {
		log("Accept Err")
		fmt.Println(err)
		os.Exit(1)
	}
	fmt.Println("Got second connection")
	return l, c1, c2
}

// newHandoffPipe creates a connected Unix stream socket pair. It returns the
// write end as a *net.UnixConn (used with WriteMsgUnix) and the read end as an
// *os.File suitable for exec.Cmd.ExtraFiles. It exits the process on failure.
func newHandoffPipe() (*net.UnixConn, *os.File) {
	fds, err := syscall.Socketpair(syscall.AF_LOCAL, syscall.SOCK_STREAM, 0)
	if err != nil {
		fmt.Printf("Socketpair: %v\n", err)
		os.Exit(1)
	}
	// Socketpair does not set close-on-exec, so without this every exec'd
	// child would inherit stray copies of both ends. ExtraFiles entries are
	// still passed on, because the child-side dup2 clears the flag.
	syscall.CloseOnExec(fds[0])
	syscall.CloseOnExec(fds[1])

	// The *os.File values take ownership of the raw fds. They must not also
	// be closed with syscall.Close, or the fd could be closed twice, and the
	// second close might hit an unrelated, reused descriptor.
	writeFile := os.NewFile(uintptr(fds[0]), "write-end")
	readFile := os.NewFile(uintptr(fds[1]), "read-end")

	// net.FileConn dups the fd, so writeFile is not needed after conversion.
	defer writeFile.Close()
	conn, err := net.FileConn(writeFile)
	if err != nil {
		fmt.Printf("FileConn: %v\n", err)
		os.Exit(1)
	}
	writeConn, ok := conn.(*net.UnixConn)
	if !ok {
		fmt.Printf("unexpected FileConn type; expected UnixConn, got %T\n", conn)
		os.Exit(1)
	}
	return writeConn, readFile
}

// handOffConns sends both connections' fds, plus the encoded content message
// (which carries pos), as a single message on writeConn.
//
// It exits the process if either connection's fd cannot be obtained. It
// returns an error if the write fails or the ancillary data was not fully
// written.
func handOffConns(writeConn *net.UnixConn, c1, c2 net.Conn) error {
	connFile1 := tcpFile(c1)
	defer connFile1.Close()
	connFile2 := tcpFile(c2)
	defer connFile2.Close()

	// Send socket fd(s) down the pipe. With SCM_RIGHTS the kernel passes the
	// receiver its own descriptors, so closing ours afterwards is safe.
	rights := syscall.UnixRights(int(connFile1.Fd()), int(connFile2.Fd()))
	n, oobn, err := writeConn.WriteMsgUnix(encodeContentMessage(), rights, nil)
	if err != nil {
		return fmt.Errorf("WriteMsgUnix: %w", err)
	}
	if oobn != len(rights) {
		return fmt.Errorf("WriteMsgUnix = %d, %d", n, oobn)
	}
	return nil
}

// tcpFile returns a duplicate *os.File for a TCP connection so its fd can be
// passed to another process. It exits the process on failure.
func tcpFile(c net.Conn) *os.File {
	tc, ok := c.(*net.TCPConn)
	if !ok {
		log("TCPConn File Err")
		fmt.Printf("expected *net.TCPConn, got %T\n", c)
		os.Exit(1)
	}
	f, err := tc.File()
	if err != nil {
		log("TCPConn File Err")
		fmt.Println(err)
		os.Exit(1)
	}
	return f
}
// encodeContentMessage builds the payload sent alongside the connection fds
// during a hand-off. The payload is a gob stream holding the string "LISTEN"
// followed by the current value of pos. decodeContentMessage reads the two
// values back in that order.
//
// Encode errors are not checked. NOTE(review): encoding a string and an int
// into an in-memory buffer is not expected to fail, but if it did, a
// truncated payload would be sent silently.
func encodeContentMessage() []byte {
	out := new(bytes.Buffer)
	enc := gob.NewEncoder(out)
	for _, v := range []interface{}{"LISTEN", pos} {
		enc.Encode(v)
	}
	return out.Bytes()
}
// decodeContentMessage is the inverse of encodeContentMessage. It gob-decodes
// a string followed by an int from msg. The int is written directly into the
// package-level pos, so this process continues the message numbering. The
// string is returned.
//
// Decode errors are ignored. NOTE(review): on a malformed msg, the returned
// string is presumably empty and pos presumably keeps its previous value.
// Confirm before relying on it.
func decodeContentMessage(msg []byte) string {
	dec := gob.NewDecoder(bytes.NewReader(msg))
	var kind string
	for _, dst := range []interface{}{&kind, &pos} {
		dec.Decode(dst)
	}
	return kind
}
// log prints msg to stdout on its own line, prefixed with this process's PID
// and VERSION, e.g. "[1234] v1 Start". Because of this package-level name, the
// standard library log package cannot be imported here under its default
// name.
func log(msg string) {
	pid := os.Getpid()
	fmt.Printf("[%d] v%d %s\n", pid, VERSION, msg)
}
// writeLoop writes a numbered line to c roughly every 500ms. It stops after a
// write fails (the error is printed) or once isRunning is false, and then
// prints "Exit writeLoop".
//
// Each line is "[v<VERSION>] Message <pos>\n", and pos is incremented after it
// is formatted. main runs one writeLoop per connection, so the loops share a
// single counter.
//
// NOTE(review): pos and isRunning are accessed here without synchronization
// while other goroutines also use them. This is a data race, and two loops can
// emit the same number.
func writeLoop(c net.Conn) {
	for isRunning {
		line := fmt.Sprintf("[v%d] Message %d\n", VERSION, pos)
		pos++
		_, werr := c.Write([]byte(line))
		if werr != nil {
			fmt.Println(werr)
		}
		time.Sleep(500 * time.Millisecond)
		if werr != nil {
			break
		}
	}
	fmt.Println("Exit writeLoop")
}
// startChild launches a fresh copy of this program as a child process that
// will take over the connections.
//
// The program is located with exec.LookPath(os.Args[0]). A result starting
// with "./" is rewritten as an absolute path under the current working
// directory. The child:
//   - inherits stdin/stdout/stderr;
//   - gets the current environment plus PARENT_PID=<this pid>;
//   - receives pipe as its first ExtraFiles entry, which appears as fd 3 in
//     the child (main reads the hand-off socket from fd 3 when PARENT_PID is
//     set).
//
// On any failure, the error is logged and startChild returns. The current
// process keeps running and keeps serving its connections, so a failed
// restart does not take them down. A successfully started child is reaped in
// the background, so if it exits while this process is still alive it does
// not linger as a zombie.
func startChild(pipe *os.File) {
	path, err := exec.LookPath(os.Args[0])
	if path == "" {
		// Newer Go releases can return a usable relative path together with
		// exec.ErrDot, so only an empty path is treated as failure.
		log("LookPath Err")
		fmt.Println(err)
		return
	}
	if strings.HasPrefix(path, "./") {
		pwd, err := os.Getwd()
		if err != nil {
			log("Getwd Err")
			fmt.Println(err)
			return
		}
		path = pwd + string(os.PathSeparator) + strings.TrimPrefix(path, "./")
	}

	cmd := exec.Command(path)
	cmd.Stdin = os.Stdin
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	cmd.ExtraFiles = []*os.File{pipe}
	cmd.Env = append(os.Environ(), fmt.Sprintf("PARENT_PID=%d", os.Getpid()))
	if err := cmd.Start(); err != nil {
		log("Start Err")
		fmt.Println(err)
		return
	}
	// Reap the child whenever it exits. Its exit status is not used.
	go func() { _ = cmd.Wait() }()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.