Skip to content

Instantly share code, notes, and snippets.

@majek
Last active April 16, 2021 07:07
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save majek/778021e4f95f3e77ada5afcecacbd819 to your computer and use it in GitHub Desktop.
Save majek/778021e4f95f3e77ada5afcecacbd819 to your computer and use it in GitHub Desktop.
netstack from gVisor for netns
package main
import (
	"flag"
	"fmt"
	"io"
	"math/rand"
	"net"
	"os"
	"os/signal"
	"runtime"
	"syscall"
	"time"

	specs "github.com/opencontainers/runtime-spec/specs-go"
	"gvisor.dev/gvisor/pkg/log"
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
	"gvisor.dev/gvisor/pkg/tcpip/header"
	"gvisor.dev/gvisor/pkg/tcpip/link/fdbased"
	"gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
	"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
	"gvisor.dev/gvisor/pkg/tcpip/link/tun"
	"gvisor.dev/gvisor/pkg/tcpip/network/arp"
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
	"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
	"gvisor.dev/gvisor/pkg/waiter"
	"gvisor.dev/gvisor/runsc/specutils"
)
// joinNetNS locks the calling goroutine to its OS thread and moves that
// thread into the network namespace at nsPath via specutils.ApplyNS.
//
// On success it returns a function that invokes the restore callback
// returned by ApplyNS and then unlocks the thread; callers must call it
// from the same goroutine once they are done inside the namespace. On
// failure the thread is unlocked and the ApplyNS error is returned wrapped
// (use errors.Is / errors.As to inspect it).
func joinNetNS(nsPath string) (func(), error) {
	// Namespace membership is per OS thread, so the goroutine must not
	// migrate to another thread while it is inside the namespace.
	runtime.LockOSThread()
	restoreNS, err := specutils.ApplyNS(specs.LinuxNamespace{
		Type: specs.NetworkNamespace,
		Path: nsPath,
	})
	if err != nil {
		runtime.UnlockOSThread()
		return nil, fmt.Errorf("joining net namespace %q: %w", nsPath, err)
	}
	return func() {
		restoreNS()
		runtime.UnlockOSThread()
	}, nil
}
// ProxyError records how a bidirectional proxy session ended.
//
// Each error field holds the error (if any) observed on one half of one
// side of the session. First is the index of the error that ended the
// session first, using the field order:
// 0 = LocalRead, 1 = LocalWrite, 2 = RemoteRead, 3 = RemoteWrite.
type ProxyError struct {
	LocalRead   error
	LocalWrite  error
	RemoteRead  error
	RemoteWrite error
	First       int
}

// String renders the session result as "l=<read>/<write> r=<read>/<write>".
// A nil error is shown as "0", and the entry selected by First is wrapped in
// square brackets. An out-of-range First is ignored instead of panicking.
func (pe ProxyError) String() string {
	errs := [4]error{pe.LocalRead, pe.LocalWrite, pe.RemoteRead, pe.RemoteWrite}
	var x [4]string
	for i, err := range errs {
		if err == nil {
			x[i] = "0"
		} else {
			x[i] = err.Error()
		}
	}
	if pe.First >= 0 && pe.First < len(x) {
		x[pe.First] = "[" + x[pe.First] + "]"
	}
	return fmt.Sprintf("l=%s/%s r=%s/%s", x[0], x[1], x[2], x[3])
}
// Size bounds, in bytes, for the per-direction copy buffer in proxyOneFlow:
// the buffer starts at MINPROXYBUFSIZE and grows by doubling, never past
// MAXPROXYBUFSIZE.
const MINPROXYBUFSIZE = 2 << 10 // 2 KiB

const MAXPROXYBUFSIZE = 256 << 10 // 256 KiB
// Closer is satisfied by connections that can shut down their read and
// write directions independently (for example *net.TCPConn). proxyOneFlow
// checks for it so it can half-close a connection instead of closing it
// outright.
type Closer interface {
	CloseWrite() error
	CloseRead() error
}
// proxyOneFlow copies bytes from in to out until a read or write fails,
// then shuts both connections down and reports completion on doneCh.
//
// The read error that ended the flow is stored in *readErrPtr, or the write
// error in *writeErrPtr. At most one of the two is set by a call. scDir is
// an opaque tag sent on doneCh once everything else is done, so the
// caller can tell which direction finished. The caller should read the
// error slots only after receiving from doneCh.
func proxyOneFlow(
	in, out net.Conn,
	readErrPtr *error,
	writeErrPtr *error,
	doneCh chan int,
	scDir int,
) {
	buf := make([]byte, MINPROXYBUFSIZE)
	for {
		n, rerr := in.Read(buf)
		// Per the io.Reader contract, bytes returned together with an error
		// (e.g. a final chunk alongside io.EOF) are valid and must be
		// forwarded before the error is acted on.
		if n > 0 {
			m, werr := out.Write(buf[:n])
			// io.Writer must report an error on a short write. Guard anyway,
			// but record it instead of panicking the whole process.
			if werr == nil && m != n {
				werr = io.ErrShortWrite
			}
			if werr != nil {
				*writeErrPtr = werr
				break
			}
		}
		if rerr != nil {
			*readErrPtr = rerr
			break
		}
		// Heuristic: start with a small buffer and double it after every
		// full read, up to MAXPROXYBUFSIZE.
		if n == len(buf) && len(buf) < MAXPROXYBUFSIZE {
			buf = make([]byte, len(buf)*2)
		}
	}
	// Half-close when supported so the opposite direction can keep flowing;
	// otherwise fall back to a full Close. Teardown is best-effort, so
	// errors here are deliberately ignored.
	if c, ok := in.(Closer); ok {
		_ = c.CloseRead()
	} else {
		_ = in.Close()
	}
	if c, ok := out.(Closer); ok {
		_ = c.CloseWrite()
	} else {
		_ = out.Close()
	}
	// Synchronize with parent.
	doneCh <- scDir
}
// proxyTcp dials ra over network ("tcp" or "udp") and shuttles data between
// local and the dialed connection in both directions until both directions
// have finished. It returns a ProxyError describing how the session ended;
// its First field identifies which error ended the session first. If the
// dial fails, the dial error is reported as RemoteRead.
//
// Both connections are fully closed before returning. proxyOneFlow only
// half-closes connections that implement CloseRead/CloseWrite, which on its
// own would leave the host socket and the netstack endpoint open.
func proxyTcp(network string, local *gonet.Conn, ra net.Addr) ProxyError {
	var pe ProxyError
	// Capacity 2: one completion per direction. This goroutine runs the
	// second direction itself, so an unbuffered channel would deadlock.
	doneCh := make(chan int, 2)
	defer local.Close()

	fmt.Printf("going to %q\n", ra)
	remote, err := net.Dial(network, ra.String())
	if err != nil {
		pe.RemoteRead = err
		pe.First = 2
		return pe
	}
	defer remote.Close()

	go proxyOneFlow(local, remote, &pe.LocalRead, &pe.RemoteWrite, doneCh, 0)
	proxyOneFlow(remote, local, &pe.RemoteRead, &pe.LocalWrite, doneCh, 1)

	// Both receives establish happens-before with the writes to pe made by
	// each flow.
	first := <-doneCh
	<-doneCh

	switch {
	case first == 0 && pe.LocalRead != nil:
		pe.First = 0
	case first == 0 && pe.RemoteWrite != nil:
		pe.First = 3
	case first == 1 && pe.RemoteRead != nil:
		pe.First = 2
	case first == 1 && pe.LocalWrite != nil:
		pe.First = 1
	}
	return pe
}
// debug raises the gVisor log level from Warning to Info.
var debug = flag.Bool("debug", false, "enable debug logging.")

// debugLog, when non-empty, is passed to specutils.DebugLogFile as an extra
// log destination.
var debugLog = flag.String("debug-log", "", "additional location for logs. If it ends with '/', log files are created inside the directory with default names. The following variables are available: %TIMESTAMP%, %COMMAND%.")

// netNsPath, when non-empty, names the network namespace the interface is
// opened in.
var netNsPath = flag.String("netns", "", "path to network namespace")

// ifName is the TUN/TAP interface the netstack attaches to.
var ifName = flag.String("interface", "tun0", "interface name within netns")
// main runs a user-space TCP/UDP proxy built on gVisor's netstack.
//
// It opens the TUN interface named by -interface (falling back to TAP),
// optionally inside the network namespace given by -netns. It attaches a
// netstack to that interface and forwards every TCP connection and UDP flow
// the stack receives to its original destination using ordinary host
// sockets (net.Dial). Connections to the stack's own address 10.0.2.2 are
// redirected to 127.0.0.1. Packet counters are printed every 15 seconds
// until SIGINT or SIGTERM arrives.
func main() {
	sigCh := make(chan os.Signal, 4)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)

	flag.Parse()

	if *debug {
		log.SetLevel(log.Info)
	} else {
		log.SetLevel(log.Warning)
	}
	if *debugLog != "" {
		f, err := specutils.DebugLogFile(*debugLog, "slirp", "" /* name */)
		if err != nil {
			fmt.Fprintf(os.Stderr, "error opening debug log file in %q: %v\n", *debugLog, err)
			os.Exit(-1)
		}
		log.SetTarget(&log.GoogleEmitter{&log.Writer{Next: f}})
	}

	var (
		restore func()
		err     error
	)
	if *netNsPath != "" {
		// Only the interface lookup and the TUN/TAP open happen inside the
		// target namespace. restore() is called right after, so the proxy's
		// outbound net.Dial calls are made from the original namespace.
		fmt.Fprintf(os.Stderr, "[.] Joining netns %s\n", *netNsPath)
		restore, err = joinNetNS(*netNsPath)
		if err != nil {
			fmt.Fprintf(os.Stderr, "[!] Can't join netns %s: %s\n", *netNsPath, err)
			os.Exit(-2)
		}
	}

	fmt.Fprintf(os.Stderr, "[.] Opening tun interface %s\n", *ifName)
	mtu, err := rawfile.GetMTU(*ifName)
	if err != nil {
		fmt.Fprintf(os.Stderr, "[!] GetMTU(%s) = %s\n", *ifName, err)
		os.Exit(-1)
	}
	tapMode := false
	fd, err := tun.Open(*ifName)
	if err != nil {
		// tun.Open failed; retry opening the interface in TAP mode.
		tapMode = true
		fd, err = tun.OpenTAP(*ifName)
		if err != nil {
			fmt.Fprintf(os.Stderr, "[!] open(%s) = %s\n", *ifName, err)
			os.Exit(-1)
		}
	}
	if restore != nil {
		fmt.Fprintf(os.Stderr, "[.] Restoring root netns\n")
		restore()
	}

	rand.Seed(time.Now().UnixNano())

	// Create the stack with IPv4/IPv6/ARP and TCP/UDP/ICMP, then add a
	// NIC backed by the TUN/TAP fd.
	opts := stack.Options{
		NetworkProtocols:   []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()},
		TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol(), icmp.NewProtocol4(), icmp.NewProtocol6()},
		HandleLocal:        true,
	}
	s := stack.New(opts)
	s.SetForwarding(true)
	s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(true))
	s.SetNetworkProtocolOption(ipv4.ProtocolNumber, tcpip.DefaultTTLOption(64))
	s.SetNetworkProtocolOption(ipv6.ProtocolNumber, tcpip.DefaultTTLOption(64))
	s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true))
	// We expect no packet loss, therefore we can bump buffers. Buffers that
	// are too large thrash the cache, so these values may be wrong;
	// benchmarking is required.
	s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, 4 * tcp.DefaultReceiveBufferSize, 8 * tcp.DefaultReceiveBufferSize})
	s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, 4 * tcp.DefaultSendBufferSize, 8 * tcp.DefaultSendBufferSize})

	// Fixed MAC for the stack side of a TAP link. The literal is valid, so
	// the parse error is ignored.
	mac, _ := net.ParseMAC("70:71:aa:4b:29:aa")
	parms := fdbased.Options{
		FDs: []int{fd},
		MTU: mtu,
	}
	if tapMode {
		// TAP frames carry an Ethernet header; TUN packets do not.
		parms.EthernetHeader = true
		parms.Address = tcpip.LinkAddress(mac)
	}
	linkEP, err := fdbased.New(&parms)
	if err != nil {
		fmt.Fprintf(os.Stderr, "[!] fdbased.New(%s) = %s\n", *ifName, err)
		os.Exit(-1)
	}
	if err := s.CreateNIC(1, sniffer.New(linkEP)); err != nil {
		fmt.Fprintf(os.Stderr, "[!] CreateNIC(%s) = %s\n", *ifName, err)
		os.Exit(-1)
	}

	{
		// Assign L2 and L3 addresses. The stack's own IPv4 address is
		// 10.0.2.2.
		// NOTE(review): the empty-subnet address ranges presumably let the
		// NIC accept packets for any destination so the forwarders see all
		// traffic; confirm against the gVisor version in use.
		s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress)
		tenAddr := tcpip.Address(net.IPv4(10, 0, 2, 2).To4())
		s.AddAddress(1, ipv4.ProtocolNumber, tenAddr)
		s.AddAddressRange(1, ipv4.ProtocolNumber, header.IPv4EmptySubnet)
		s.AddAddressRange(1, ipv6.ProtocolNumber, header.IPv6EmptySubnet)
	}

	// Destination rewrites: connections to the stack's own address are sent
	// to the host loopback instead. Read-only after this point, so it is
	// safe to share between handler goroutines.
	hostfwd := map[string]string{
		"10.0.2.2": "127.0.0.1",
	}

	fwdTcp := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
		id := r.ID()
		var wq waiter.Queue
		ep, err := r.CreateEndpoint(&wq)
		if err != nil {
			fmt.Printf("r.CreateEndpoint() = %v\n", err)
			// Every request must be completed. gVisor's tcp.Forwarder keeps
			// uncompleted requests in its in-flight set (bounded by the
			// maxInFlight argument, 10 here), so returning without Complete
			// leaks a slot. true = reply with RST.
			r.Complete(true)
			return
		}
		ep.SetSockOptInt(tcpip.DelayOption, 0)
		r.Complete(false)
		c := gonet.NewConn(&wq, ep)
		fwdDst := net.TCPAddr{
			IP:   net.ParseIP(id.LocalAddress.String()),
			Port: int(id.LocalPort),
		}
		// Capture the addresses up front: they may no longer be available
		// once the proxy has shut the connection down.
		localAddr, remoteAddr := c.LocalAddr(), c.RemoteAddr()
		fmt.Printf("[+] %s/%s New conn\n", remoteAddr, fwdDst.String())
		if y, ok := hostfwd[id.LocalAddress.String()]; ok {
			fwdDst.IP = net.ParseIP(y)
		}
		pe := proxyTcp("tcp", c, &fwdDst)
		fmt.Printf("[-] %s/%s Conn done: %s\n", localAddr, remoteAddr, pe)
	})
	s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwdTcp.HandlePacket)

	fwdUdp := udp.NewForwarder(s, func(r *udp.ForwarderRequest) {
		id := r.ID()
		var wq waiter.Queue
		ep, err := r.CreateEndpoint(&wq)
		if err != nil {
			fmt.Printf("r.CreateEndpoint() = %v\n", err)
			return
		}
		c := gonet.NewConn(&wq, ep)
		fwdDst := net.TCPAddr{
			IP:   net.ParseIP(id.LocalAddress.String()),
			Port: int(id.LocalPort),
		}
		remoteAddr := c.RemoteAddr()
		fmt.Printf("[+] %s/%s New conn\n", remoteAddr, fwdDst.String())
		if y, ok := hostfwd[id.LocalAddress.String()]; ok {
			fwdDst.IP = net.ParseIP(y)
		}
		// Unlike tcp.Forwarder, gVisor's udp.Forwarder invokes this handler
		// synchronously from HandlePacket. Run the long-lived proxy in its
		// own goroutine so packet delivery is not blocked.
		// NOTE(review): UDP has no end-of-stream, so these flows may never
		// finish; consider an idle timeout.
		go func() {
			pe := proxyTcp("udp", c, &fwdDst)
			fmt.Printf("[-] %s/%s Conn done: %s\n", remoteAddr, fwdDst.String(), pe)
		}()
	})
	s.SetTransportProtocolHandler(udp.ProtocolNumber, fwdUdp.HandlePacket)

	ticker := time.NewTicker(15 * time.Second)
	defer ticker.Stop()
loop:
	for {
		select {
		case sig := <-sigCh:
			// Restore the default action so a second SIGINT/SIGTERM
			// terminates the process.
			signal.Reset(sig)
			fmt.Fprintf(os.Stderr, "[-] Closing\n")
			break loop
		case <-ticker.C:
			stats := s.Stats()
			fmt.Printf("rx=%d tx=%d\n", stats.IP.PacketsReceived.Value(),
				stats.IP.PacketsSent.Value())
		}
	}
	s.Wait()
}
# setup tuntap
#
# Creates a TAP device named eth0 inside the network namespace of process
# $NSPID. It gives the device 10.0.2.100/24 and routes all IPv4 traffic via
# 10.0.2.2, the address the Go proxy assigns to its netstack NIC. The
# proxy's -interface flag defaults to tun0, so run it with -interface eth0
# for this setup (and, e.g., -netns /proc/$NSPID/ns/net).
: "${NSPID:?NSPID must be set to the PID of a process in the target netns}"
nsenter -n -t "${NSPID}" ip link set lo up
nsenter -n -t "${NSPID}" ip tuntap add mode tap name eth0
nsenter -n -t "${NSPID}" ip link set eth0 mtu $((65536-4096+40))
nsenter -n -t "${NSPID}" ip link set dev eth0 up
nsenter -n -t "${NSPID}" ip addr add 10.0.2.100/24 dev eth0
nsenter -n -t "${NSPID}" ip route add 0.0.0.0/0 via 10.0.2.2 dev eth0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment