Skip to content

Instantly share code, notes, and snippets.

@knight42
Created June 22, 2020 06:51
Show Gist options
  • Save knight42/6ad35ce6fbf96519259b43a8c3f37478 to your computer and use it in GitHub Desktop.
Save knight42/6ad35ce6fbf96519259b43a8c3f37478 to your computer and use it in GitHub Desktop.
package main
import (
"bytes"
"context"
"fmt"
"io"
"log"
"net"
"strconv"
"sync"
"time"
)
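// TODO: answer module-list requests (an empty module line) with the
// configured modules instead of the fixed reply in proxyUpstream.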
// Relay tunables.
const (
	// DefaultModuleName is the upstreams-map key used as the fallback for
	// any requested module that has no entry of its own.
	DefaultModuleName = "*"

	// BufferSize is the length of each pooled buffer used for handshake reads.
	BufferSize = 1500

	// ReadTimeout and WriteTimeout bound each individual handshake read and
	// write.
	ReadTimeout  = 1 * time.Minute
	WriteTimeout = 1 * time.Minute
)
// RsyncdVersionPrefix begins every rsync daemon greeting line; version lines
// received from clients and upstreams are checked against it.
var RsyncdVersionPrefix = []byte("@RSYNCD:")

// RsyncdVersion is the greeting line the relay itself sends, to clients and
// to upstreams alike, advertising protocol version 31.0.
var RsyncdVersion = []byte("@RSYNCD: 31.0\n")
// Upstream is the address of an rsync daemon that sessions are relayed to.
type Upstream struct {
	Host string // host part, combined with Port via net.JoinHostPort
	Port int    // TCP port of the daemon
}
// Relay accepts rsync daemon connections and forwards each session to the
// upstream configured for the requested module. Construct it with NewRelay:
// the zero value has no buffer-pool constructor.
type Relay struct {
	dialer     net.Dialer           // dials upstream daemons
	listenAddr string               // address Run listens on
	upstreams  map[string]*Upstream // module name -> upstream; DefaultModuleName is the fallback
	bufPool    sync.Pool            // reusable []byte handshake buffers
}
// readWithTimeout performs a single conn.Read into buf that fails if it does
// not complete within timeout.
//
// The deadline is absolute and remains set on conn after this returns, so it
// also governs later reads until it is changed. If the deadline cannot be
// set, that error is returned and no read is attempted: reading without the
// deadline could block indefinitely, defeating the point of this helper.
func readWithTimeout(conn net.Conn, timeout time.Duration, buf []byte) (n int, err error) {
	if err = conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
		return 0, err
	}
	return conn.Read(buf)
}
// writeWithTimeout performs a single conn.Write of buf that fails if it does
// not complete within timeout.
//
// The deadline is absolute and remains set on conn after this returns, so it
// also governs later writes until it is changed. If the deadline cannot be
// set, that error is returned and no write is attempted: writing without the
// deadline could block indefinitely, defeating the point of this helper.
func writeWithTimeout(conn net.Conn, timeout time.Duration, buf []byte) (n int, err error) {
	if err = conn.SetWriteDeadline(time.Now().Add(timeout)); err != nil {
		return 0, err
	}
	return conn.Write(buf)
}
// proxyUpstream relays one rsync daemon session from downConn (the client) to
// the upstream serving the requested module.
//
// Steps:
//   - Exchange "@RSYNCD:" version lines with the client.
//   - Read the requested module line and pick r.upstreams[module], falling
//     back to r.upstreams[DefaultModuleName].
//   - Repeat the version exchange with that upstream and forward the module
//     line to it.
//   - Copy bytes in both directions until the session ends.
//
// An empty module line gets a fixed module-list reply instead (see FIXME).
// downConn is always closed on return; the returned error describes which
// step failed.
func (r *Relay) proxyUpstream(ctx context.Context, downConn *net.TCPConn) error {
	defer downConn.Close()

	buf := r.bufPool.Get().([]byte)
	defer r.bufPool.Put(buf)

	// Client handshake: expect its version line, answer with ours.
	n, err := readWithTimeout(downConn, ReadTimeout, buf)
	if err != nil {
		return fmt.Errorf("read version from client: %w", err)
	}
	if data := buf[:n]; !bytes.HasPrefix(data, RsyncdVersionPrefix) {
		// %q: the bytes are client-controlled and may contain newlines.
		return fmt.Errorf("unknown version from client: %q", data)
	}
	if _, err = writeWithTimeout(downConn, WriteTimeout, RsyncdVersion); err != nil {
		return fmt.Errorf("send version to client: %w", err)
	}

	// The next line names the requested module.
	n, err = readWithTimeout(downConn, ReadTimeout, buf)
	if err != nil {
		return fmt.Errorf("read module from client: %w", err)
	}
	if n == 0 {
		return fmt.Errorf("empty request from client")
	}
	// Strip the terminator only if present; unconditionally dropping the last
	// byte would truncate an unterminated module name.
	module := string(bytes.TrimRight(buf[:n], "\r\n"))
	if module == "" {
		// FIXME: list all modules
		_, _ = writeWithTimeout(downConn, WriteTimeout, []byte("foo\nbar\n@RSYNCD: EXIT\n"))
		return nil
	}

	upstream := r.upstreams[module]
	if upstream == nil {
		log.Printf("DEBUG: unknown module: %s, fallback to default upstream", module)
		upstream = r.upstreams[DefaultModuleName]
	}
	if upstream == nil {
		// Without this check, a missing default entry would be a nil
		// dereference that panics this goroutine and kills the process.
		return fmt.Errorf("no upstream for module %q and no default upstream configured", module)
	}

	upstreamAddr := net.JoinHostPort(upstream.Host, strconv.Itoa(upstream.Port))
	conn, err := r.dialer.DialContext(ctx, "tcp", upstreamAddr)
	if err != nil {
		return fmt.Errorf("dial to upstream: %s: %w", upstreamAddr, err)
	}
	upConn := conn.(*net.TCPConn) // a "tcp" dial yields a *net.TCPConn
	defer upConn.Close()

	// Upstream handshake on the client's behalf, then request the module.
	if _, err = writeWithTimeout(upConn, WriteTimeout, RsyncdVersion); err != nil {
		return fmt.Errorf("send version to upstream: %w", err)
	}
	n, err = readWithTimeout(upConn, ReadTimeout, buf)
	if err != nil {
		return fmt.Errorf("read version from upstream: %w", err)
	}
	if data := buf[:n]; !bytes.HasPrefix(data, RsyncdVersionPrefix) {
		return fmt.Errorf("unknown version from upstream: %q", data)
	}
	if _, err = writeWithTimeout(upConn, WriteTimeout, []byte(module+"\n")); err != nil {
		return fmt.Errorf("send module to upstream: %w", err)
	}

	// Deadlines are absolute and persist across calls. If the handshake
	// deadlines were left in place, every transfer would be aborted roughly
	// ReadTimeout/WriteTimeout after it started. A zero time disables them.
	if err = downConn.SetDeadline(time.Time{}); err != nil {
		return fmt.Errorf("clear client deadline: %w", err)
	}
	if err = upConn.SetDeadline(time.Time{}); err != nil {
		return fmt.Errorf("clear upstream deadline: %w", err)
	}

	spliceConns(downConn, upConn)
	return nil
}

// spliceConns copies bytes between the client (downConn) and the upstream
// (upConn) in both directions. It returns after both copies have stopped.
// Whichever direction finishes first shuts down the read side of the
// connection feeding the other direction, so that copy ends too.
func spliceConns(downConn, upConn *net.TCPConn) {
	upClosed := make(chan struct{})
	downClosed := make(chan struct{})
	go func() {
		// client -> upstream
		_, _ = io.Copy(upConn, downConn)
		close(downClosed)
	}()
	go func() {
		// upstream -> client
		_, _ = io.Copy(downConn, upConn)
		close(upClosed)
	}()

	var waitFor chan struct{}
	select {
	case <-downClosed:
		// Linger 0: on Close, discard anything still unsent to the upstream.
		_ = upConn.SetLinger(0)
		_ = upConn.CloseRead()
		waitFor = upClosed
	case <-upClosed:
		_ = downConn.CloseRead()
		waitFor = downClosed
	}
	<-waitFor
}
// handleConn serves one accepted client connection and logs any error.
//
// Only TCP connections are supported. Any other net.Conn is logged, closed
// and dropped. An unchecked type assertion would panic in this goroutine,
// and an unrecovered panic terminates the whole program.
func (r *Relay) handleConn(ctx context.Context, downConn net.Conn) {
	tcpConn, ok := downConn.(*net.TCPConn)
	if !ok {
		log.Printf("handleConn: unsupported connection type %T from %s", downConn, downConn.RemoteAddr())
		_ = downConn.Close()
		return
	}
	if err := r.proxyUpstream(ctx, tcpConn); err != nil {
		log.Printf("handleConn: %s", err)
	}
}
// Run listens on r.listenAddr and serves each accepted connection in its own
// goroutine. It returns nil once ctx is cancelled. It returns an error only
// if the listener cannot be created.
//
// Cancellation stops accepting and closes the listener. Sessions already in
// progress receive ctx but are not closed by Run.
func (r *Relay) Run(ctx context.Context) error {
	srv, err := net.Listen("tcp", r.listenAddr)
	if err != nil {
		return fmt.Errorf("listen: %w", err)
	}
	defer srv.Close()
	log.Printf("Listening on %s", srv.Addr())

	// Accept blocks, so polling ctx between calls would not notice
	// cancellation until the next client arrived. Closing the listener
	// unblocks Accept immediately.
	done := make(chan struct{})
	defer close(done)
	go func() {
		select {
		case <-ctx.Done():
			_ = srv.Close()
		case <-done:
		}
	}()

	var delay time.Duration
	for {
		conn, err := srv.Accept()
		if err != nil {
			if ctx.Err() != nil {
				return nil
			}
			log.Printf("Accept connection: %s", err)
			// Back off on repeated failures (e.g. file-descriptor exhaustion)
			// instead of spinning in a tight loop.
			if delay == 0 {
				delay = 5 * time.Millisecond
			} else {
				delay *= 2
			}
			if delay > time.Second {
				delay = time.Second
			}
			time.Sleep(delay)
			continue
		}
		delay = 0
		go r.handleConn(ctx, conn)
	}
}
// NewRelay returns a Relay that will listen on listenAddr and route modules
// using upstreams. The map is stored as given, not copied. Its
// DefaultModuleName entry, if any, serves modules with no entry of their own.
func NewRelay(listenAddr string, upstreams map[string]*Upstream) *Relay {
	r := &Relay{
		listenAddr: listenAddr,
		upstreams:  upstreams,
	}
	// The zero dialer is ready to use. The pool needs a constructor because
	// callers type-assert every Get result to []byte.
	r.bufPool.New = func() interface{} {
		return make([]byte, BufferSize)
	}
	return r
}
func main() {
r := NewRelay("127.0.0.1:9988", map[string]*Upstream{
"foo": {Host: "127.0.0.1", Port: 1234},
DefaultModuleName: {Host: "127.0.0.1", Port: 1235},
})
err := r.Run(context.Background())
if err != nil {
log.Fatalln(err)
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment