Skip to content

Instantly share code, notes, and snippets.

@sineemore
Created May 13, 2021 12:36
Show Gist options
  • Save sineemore/08d85deba2aa710035c0b0ce4fba682d to your computer and use it in GitHub Desktop.
Save sineemore/08d85deba2aa710035c0b0ce4fba682d to your computer and use it in GitHub Desktop.
SNI patch proxy
package main
import (
"errors"
"flag"
"io"
"log"
"math"
"net"
"os"
"sync"
)
// SNI is the extension type number of the TLS server_name extension; Serve
// compares each ClientHello extension type against it.
const SNI = 0

// Program configuration. The string values are filled in from command-line
// flags in main, and sni is built there from domain.
var (
	proto    string // network handed to net.Listen (tcp, tcp4 or tcp6)
	listen   string // local address the proxy listens on
	upstream string // address Serve dials for every client
	domain   string // host name written into the injected extension
	sni      []byte // complete server_name extension Serve appends
)
// Serve reads TLS records with ClientHello message,
// adds missing SNI, writes patched ClientHello to upstream
// and forwards the rest of the traffic
func Serve(conn io.ReadWriter) (err error) {
record := make([]byte, 5)
_, err = io.ReadFull(conn, record)
if err != nil {
return
}
recordLen := Uint16(record[3:])
hello := make([]byte, recordLen)
_, err = io.ReadFull(conn, hello)
if err != nil {
return
}
helloLen := Uint24(hello[1:])
var recordData []byte
// If handshake message spans multiple records (unlikely), read them
for len(hello)-4 < int(helloLen) {
_, err = io.ReadFull(conn, record)
if err != nil {
return
}
recordLen = Uint16(record[3:])
if recordLen == 0 {
continue
}
recordData = make([]byte, recordLen)
_, err = io.ReadFull(conn, recordData)
if err != nil {
return
}
hello = append(hello, recordData...)
}
if len(hello)-4 != int(helloLen) {
return errors.New("hello not equal helloLen")
}
cursor := 0
cursor += 38 // skip random
cursor += 1 + int(hello[cursor]) // skip session
cursor += 2 + int(Uint16(hello[cursor:])) // skip cipher
cursor += 1 + int(hello[cursor]) // skip compression
extLenCursor := cursor // cursor at extensions length
extLen := Uint16(hello[cursor:])
cursor += 2
for cursor < extLenCursor+2+int(extLen) {
eType := Uint16(hello[cursor:])
cursor += 2
if eType == SNI {
// Already has SNI extension
goto forwardConnection
}
eLen := Uint16(hello[cursor:])
cursor += 2 + int(eLen)
}
// Add missing SNI
hello = append(hello, sni...)
// Update extensions and hello length
PutUint16(hello[extLenCursor:], extLen+uint16(len(sni)))
PutUint24(hello[1:], helloLen+uint32(len(sni)))
forwardConnection:
upstreamConn, err := net.Dial("tcp", upstream)
if err != nil {
return
}
defer upstreamConn.Close()
// Write records with hello to upstream
for len(hello) > 0 {
recordLen = uint16(math.Min(float64(len(hello)), math.MaxUint16))
PutUint16(record[3:], recordLen)
_, err = upstreamConn.Write(record)
if err != nil {
return
}
recordData, hello = hello[:recordLen], hello[recordLen:]
_, err = upstreamConn.Write(recordData)
if err != nil {
return
}
}
// Forward the rest of the traffic
var err2 error
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
_, err2 = io.Copy(upstreamConn, conn)
upstreamConn.Close()
}()
_, err = io.Copy(conn, upstreamConn)
wg.Wait()
if err2 != nil {
err = err2
}
return
}
// main parses the command-line flags and builds the server_name extension
// that Serve appends to ClientHellos lacking one. It then accepts
// connections on the listen address and serves each in its own goroutine.
//
// It prints usage and exits with status 1 if -upstream or -domain is
// missing. It exits via log.Fatal if the domain cannot be encoded in an SNI
// extension or the listener cannot be created.
func main() {
	flag.StringVar(&proto, "proto", "tcp", "tcp4, tcp6 or tcp for both")
	flag.StringVar(&listen, "listen", ":443", "listen addr with port")
	flag.StringVar(&upstream, "upstream", "", "upstream host with port")
	// NOTE(review): RFC 6066 section 3 sends the host name without a
	// trailing dot; confirm the upstream really wants one before following
	// the advice in this help text.
	flag.StringVar(&domain, "domain", "", "FQDN to add in SNI extension, probably should end with dot (.)")
	flag.Parse()
	if upstream == "" || domain == "" {
		flag.Usage()
		os.Exit(1)
	}

	// Layout of the server_name extension built below (RFC 6066 section 3):
	//   sni[0:2] extension type 0 (server_name), left zero
	//   sni[2:4] length of the extension data that follows
	//   sni[4:6] length of the server_name_list
	//   sni[6]   name_type 0 (host_name), left zero
	//   sni[7:9] length of the host name
	//   sni[9:]  the host name itself
	// Every length is a 16-bit field. The largest, the extension data length,
	// is len(domain)+5; reject domains that would make it wrap silently.
	if len(domain)+5 > math.MaxUint16 {
		log.Fatalf("domain is %d bytes, too long for an SNI extension", len(domain))
	}
	sni = make([]byte, 9+len(domain))
	PutUint16(sni[2:], uint16(len(sni)-4))
	PutUint16(sni[4:], uint16(len(sni)-6))
	PutUint16(sni[7:], uint16(len(domain)))
	copy(sni[9:], domain)

	l, err := net.Listen(proto, listen)
	if err != nil {
		log.Fatal(err)
	}
	for {
		conn, err := l.Accept()
		if err != nil {
			// NOTE(review): there is no backoff here, so an Accept error
			// that keeps recurring makes this loop spin and flood the log.
			log.Println(err)
			continue
		}
		go func(conn net.Conn) {
			defer conn.Close()
			if err := Serve(conn); err != nil {
				// Prefix the client address so failures can be attributed.
				log.Printf("%v: %v", conn.RemoteAddr(), err)
			}
		}(conn)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment