Created
May 13, 2021 12:36
-
-
Save sineemore/08d85deba2aa710035c0b0ce4fba682d to your computer and use it in GitHub Desktop.
SNI patch proxy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"errors" | |
"flag" | |
"io" | |
"log" | |
"math" | |
"net" | |
"os" | |
"sync" | |
) | |
// SNI is the TLS extension type code of the server_name extension
// (RFC 6066, section 3).
const SNI = 0

// Process-wide configuration, populated by main before any connection is
// accepted.
var (
	// Network passed to net.Listen: "tcp4", "tcp6" or "tcp" for both.
	proto string
	// Local address (with port) on which client connections are accepted.
	listen string
	// Address (host with port) that every client connection is relayed to.
	upstream string
	// Host name that main encodes into sni.
	domain string

	// sni is the pre-encoded server_name extension appended to ClientHello
	// messages that do not already carry one.
	sni []byte
)
// Serve reads TLS records with ClientHello message, | |
// adds missing SNI, writes patched ClientHello to upstream | |
// and forwards the rest of the traffic | |
func Serve(conn io.ReadWriter) (err error) { | |
record := make([]byte, 5) | |
_, err = io.ReadFull(conn, record) | |
if err != nil { | |
return | |
} | |
recordLen := Uint16(record[3:]) | |
hello := make([]byte, recordLen) | |
_, err = io.ReadFull(conn, hello) | |
if err != nil { | |
return | |
} | |
helloLen := Uint24(hello[1:]) | |
var recordData []byte | |
// If handshake message spans multiple records (unlikely), read them | |
for len(hello)-4 < int(helloLen) { | |
_, err = io.ReadFull(conn, record) | |
if err != nil { | |
return | |
} | |
recordLen = Uint16(record[3:]) | |
if recordLen == 0 { | |
continue | |
} | |
recordData = make([]byte, recordLen) | |
_, err = io.ReadFull(conn, recordData) | |
if err != nil { | |
return | |
} | |
hello = append(hello, recordData...) | |
} | |
if len(hello)-4 != int(helloLen) { | |
return errors.New("hello not equal helloLen") | |
} | |
cursor := 0 | |
cursor += 38 // skip random | |
cursor += 1 + int(hello[cursor]) // skip session | |
cursor += 2 + int(Uint16(hello[cursor:])) // skip cipher | |
cursor += 1 + int(hello[cursor]) // skip compression | |
extLenCursor := cursor // cursor at extensions length | |
extLen := Uint16(hello[cursor:]) | |
cursor += 2 | |
for cursor < extLenCursor+2+int(extLen) { | |
eType := Uint16(hello[cursor:]) | |
cursor += 2 | |
if eType == SNI { | |
// Already has SNI extension | |
goto forwardConnection | |
} | |
eLen := Uint16(hello[cursor:]) | |
cursor += 2 + int(eLen) | |
} | |
// Add missing SNI | |
hello = append(hello, sni...) | |
// Update extensions and hello length | |
PutUint16(hello[extLenCursor:], extLen+uint16(len(sni))) | |
PutUint24(hello[1:], helloLen+uint32(len(sni))) | |
forwardConnection: | |
upstreamConn, err := net.Dial("tcp", upstream) | |
if err != nil { | |
return | |
} | |
defer upstreamConn.Close() | |
// Write records with hello to upstream | |
for len(hello) > 0 { | |
recordLen = uint16(math.Min(float64(len(hello)), math.MaxUint16)) | |
PutUint16(record[3:], recordLen) | |
_, err = upstreamConn.Write(record) | |
if err != nil { | |
return | |
} | |
recordData, hello = hello[:recordLen], hello[recordLen:] | |
_, err = upstreamConn.Write(recordData) | |
if err != nil { | |
return | |
} | |
} | |
// Forward the rest of the traffic | |
var err2 error | |
wg := sync.WaitGroup{} | |
wg.Add(1) | |
go func() { | |
defer wg.Done() | |
_, err2 = io.Copy(upstreamConn, conn) | |
upstreamConn.Close() | |
}() | |
_, err = io.Copy(conn, upstreamConn) | |
wg.Wait() | |
if err2 != nil { | |
err = err2 | |
} | |
return | |
} | |
func main() { | |
flag.StringVar(&proto, "proto", "tcp", "tcp4, tcp6 or tcp for both") | |
flag.StringVar(&listen, "listen", ":443", "listen addr with port") | |
flag.StringVar(&upstream, "upstream", "", "upstream host with port") | |
flag.StringVar(&domain, "domain", "", "FQDN to add in SNI extension, probably should end with dot (.)") | |
flag.Parse() | |
if upstream == "" || domain == "" { | |
flag.Usage() | |
os.Exit(1) | |
} | |
// Create SNI extension blob | |
sni = make([]byte, 9+len(domain)) | |
PutUint16(sni[2:], uint16(len(sni)-4)) | |
PutUint16(sni[4:], uint16(len(sni)-6)) | |
PutUint16(sni[7:], uint16(len(domain))) | |
copy(sni[9:], domain) | |
l, err := net.Listen(proto, listen) | |
if err != nil { | |
log.Fatal(err) | |
} | |
for { | |
conn, err := l.Accept() | |
if err != nil { | |
log.Println(err) | |
continue | |
} | |
go func(conn net.Conn) { | |
defer conn.Close() | |
if err := Serve(conn); err != nil { | |
log.Println(err) | |
} | |
}(conn) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment