Skip to content

Instantly share code, notes, and snippets.

@kirk91
Last active November 30, 2022 01:31
Show Gist options
Save kirk91/ec25703848172e8f56f671e0e1c73751 to your computer and use it in GitHub Desktop.
Pass File Descriptor over Unix Domain Socket
package main
import (
"fmt"
"log"
"net"
"net/http"
"os"
"syscall"
"golang.org/x/sys/unix"
)
// udsPath is the Unix domain socket this process listens on while waiting for
// a peer to hand over its TCP listener; the sender dials the same path.
const udsPath = "/tmp/fd-pass-example.sock"
// main waits on the Unix domain socket at udsPath for a peer to hand over a
// listening socket, then serves HTTP on that inherited listener.
func main() {
	// Remove a stale socket file left by a previous run; a missing file is fine.
	os.Remove(udsPath) //nolint: errcheck
	lis, err := net.Listen("unix", udsPath)
	if err != nil {
		panic(err)
	}
	defer lis.Close()
	log.Println("Wait receiving listener ...")
	conn, err := lis.Accept()
	if err != nil {
		panic(err)
	}
	defer conn.Close()
	httpLis := receiveListener(conn.(*net.UnixConn))
	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintf(w, "[server2] Hello, world!")
	})
	log.Printf("Server is listening on %s ...\n", httpLis.Addr())
	// http.Serve only returns on failure; report why instead of exiting silently.
	if err := http.Serve(httpLis, nil); err != nil {
		log.Printf("Serve: %v", err)
	}
}
func receiveListener(conn *net.UnixConn) net.Listener {
connFd, err := getConnFd(conn)
if err != nil {
panic(err)
}
// receive socket control message
b := make([]byte, unix.CmsgSpace(4))
_, _, _, _, err = unix.Recvmsg(connFd, nil, b, 0)
if err != nil {
panic(err)
}
// parse socket control message
cmsgs, err := unix.ParseSocketControlMessage(b)
if err != nil {
panic(err)
}
fds, err := unix.ParseUnixRights(&cmsgs[0])
if err != nil {
panic(err)
}
fd := fds[0]
log.Printf("Got socket fd %d\n", fd)
// construct net listener
f := os.NewFile(uintptr(fd), "listener")
defer f.Close()
l, err := net.FileListener(f)
if err != nil {
panic(err)
}
return l
}
// getConnFd reports the OS file descriptor behind conn. The value is read
// inside RawConn.Control but escapes the callback. It is only meaningful while
// conn stays open, and callers should not close it themselves.
func getConnFd(conn syscall.Conn) (connFd int, err error) {
	raw, sysErr := conn.SyscallConn()
	if sysErr != nil {
		return 0, sysErr
	}
	ctrlErr := raw.Control(func(fd uintptr) {
		connFd = int(fd)
	})
	return connFd, ctrlErr
}
package main
import (
"context"
"fmt"
"log"
"net"
"net/http"
"syscall"
"time"
"golang.org/x/sys/unix"
)
// serverAddr is the TCP address this server listens on; its listener is the
// one handed to the peer by the /passfd handler.
const serverAddr = "127.0.0.1:8080"
// main serves HTTP on serverAddr. A request to /passfd sends the listening
// socket to the peer waiting on the Unix domain socket (see sendListener). It
// then gracefully shuts this server down so the peer takes over the address.
func main() {
	lis, err := net.Listen("tcp", serverAddr)
	if err != nil {
		panic(err)
	}
	var s http.Server
	// shutdownScheduled holds at most one token. The first successful hand-off
	// claims it; later requests find it full and do not schedule another
	// shutdown.
	shutdownScheduled := make(chan struct{}, 1)
	// shutdownDone is closed once Shutdown has returned.
	shutdownDone := make(chan struct{})
	mux := new(http.ServeMux)
	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintf(w, "[server1] Hello, world!")
	})
	mux.HandleFunc("/passfd", func(w http.ResponseWriter, r *http.Request) {
		if err := sendListener(lis.(*net.TCPListener)); err != nil {
			fmt.Fprintf(w, "Error: %v", err)
			return
		}
		fmt.Fprintf(w, "Success")
		select {
		case shutdownScheduled <- struct{}{}:
			// Presumably delayed so this response is written before shutdown begins.
			time.AfterFunc(time.Millisecond*50, func() {
				defer close(shutdownDone)
				log.Println("Shutdown server ...")
				if err := s.Shutdown(context.Background()); err != nil {
					log.Printf("Shutdown: %v", err)
				}
			})
		default:
			// A shutdown is already pending; nothing more to schedule.
		}
	})
	s.Handler = mux
	log.Printf("Server is listening on %s ...\n", serverAddr)
	if err := s.Serve(lis); err == http.ErrServerClosed {
		// Serve returns as soon as Shutdown starts. Wait for Shutdown to finish
		// draining in-flight requests before the process exits.
		<-shutdownDone
	} else {
		log.Printf("Serve: %v", err)
	}
	log.Println("Bye bye")
}
func sendListener(lis *net.TCPListener) error {
// connect to the unix socket
const udsPath = "/tmp/fd-pass-example.sock"
conn, err := net.Dial("unix", udsPath)
if err != nil {
return err
}
defer conn.Close()
connFd, err := getConnFd(conn.(*net.UnixConn))
if err != nil {
return err
}
// pass listener fd
lisFd, err := getConnFd(lis)
if err != nil {
return err
}
rights := unix.UnixRights(int(lisFd))
return unix.Sendmsg(connFd, nil, rights, nil, 0)
}
// getConnFd reports the OS file descriptor behind conn. The value is read
// inside RawConn.Control but escapes the callback. It is only meaningful while
// conn stays open, and callers should not close it themselves.
func getConnFd(conn syscall.Conn) (connFd int, err error) {
	raw, sysErr := conn.SyscallConn()
	if sysErr != nil {
		return 0, sysErr
	}
	ctrlErr := raw.Control(func(fd uintptr) {
		connFd = int(fd)
	})
	return connFd, ctrlErr
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment