Skip to content

Instantly share code, notes, and snippets.

@scottt
Last active August 1, 2019 14:26
Show Gist options
  • Save scottt/6037e0c6d767781359f2e8d42e1f61c1 to your computer and use it in GitHub Desktop.
Save scottt/6037e0c6d767781359f2e8d42e1f61c1 to your computer and use it in GitHub Desktop.
package network
import (
"bufio"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"log"
"math/big"
"net"
"os"
"time"
)
// Network endpoints and certificate parameters shared by Server, Client and
// Mitm.
const (
	// port0 is the port Server listens on; the MITM dials it.
	port0 = 1337
	// port1 is the port Mitm listens on; Client dials it.
	port1 = 1 + port0
	// certValidFor is the lifetime (10 * 365 days) of each generated certificate.
	certValidFor = time.Duration(10*365*24) * time.Hour
	// organizationName is placed in the Subject of every generated certificate.
	organizationName = "ThunderCore"
)
// Package-level state written by the Server/Client/Mitm goroutines and read
// afterwards by TestMitm.
var (
	// Per-role loggers, assigned when the corresponding role starts.
	serverLgr *log.Logger
	clientLgr *log.Logger
	mitmLgr   *log.Logger

	// DER-encoded public keys: each role's own key, plus the key each endpoint
	// found on its peer's certificate.
	mitmPublicKey         []byte
	serverPublicKey       []byte
	publicKeySeenByServer []byte
	clientPublicKey       []byte
	publicKeySeenByClient []byte
)
func generateCert(domainName string, lgr *log.Logger) (publicKeyRaw, certPem, privPem []byte, err error) {
priv, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
if err != nil {
lgr.Fatalf("ecdsa.GenerateKey: %s\n", err)
}
privBytes, err := x509.MarshalECPrivateKey(priv)
if err != nil {
lgr.Fatalf("MarshalECPrivateKey: %s\n", err)
}
privPem = pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: privBytes})
publicKeyRaw, err = x509.MarshalPKIXPublicKey(priv.Public())
if err != nil {
lgr.Fatalf("MarshalPKIXPublicKey: %s\n", err)
}
var (
notBefore, notAfter time.Time
)
notBefore = time.Now()
notAfter = notBefore.Add(certValidFor)
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
log.Fatalf("failed to generate serial number: %s", err)
}
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{organizationName},
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
DNSNames: []string{domainName},
}
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
if err != nil {
lgr.Fatalf("Failed to create certificate: %s", err)
}
certPem = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
return publicKeyRaw, certPem, privPem, nil
}
// Server generates a self-signed certificate for "server" and listens for TLS
// on port0. Every client must present a certificate, but it is not verified
// (tls.RequireAnyClientCert).
//
// For the first connection that completes a handshake with a peer
// certificate, it stores that certificate's public key in
// publicKeySeenByServer and sends one value on done. Every such connection is
// then served by handleConnection. A connection that fails the handshake is
// logged and dropped rather than aborting the server. Server never returns
// normally; setup and Accept failures terminate the process via lgr.Fatalf.
func Server(done chan interface{}) {
	serverLgr = log.New(os.Stderr, "server ", log.LstdFlags|log.Lshortfile)
	lgr := serverLgr

	pub, certPem, privPem, err := generateCert("server", lgr)
	if err != nil {
		lgr.Fatalf("generateCert failed: %s", err)
	}
	serverPublicKey = pub

	cert, err := tls.X509KeyPair(certPem, privPem)
	if err != nil {
		lgr.Fatalf("X509KeyPair failed: %s", err)
	}
	ln, err := tls.Listen("tcp", fmt.Sprintf(":%d", port0), &tls.Config{
		Certificates: []tls.Certificate{cert},
		ClientAuth:   tls.RequireAnyClientCert,
	})
	if err != nil {
		lgr.Fatalf("Listen: %s", err)
	}
	defer ln.Close()

	signaled := false
	for {
		conn, err := ln.Accept()
		if err != nil {
			lgr.Fatalf("Accept: %s", err)
		}
		tlsConn := conn.(*tls.Conn)
		// Handshake now so the peer certificate is available below.
		if err := tlsConn.Handshake(); err != nil {
			// One misbehaving peer must not take the whole server down.
			lgr.Printf("Handshake: %s", err)
			tlsConn.Close()
			continue
		}
		s := tlsConn.ConnectionState()
		certs := s.PeerCertificates
		if len(certs) < 1 {
			lgr.Printf("peer certificates: %v, connection state: %v", certs, s)
			tlsConn.Close()
			continue
		}
		if !signaled {
			// Report only the first peer. If done is unbuffered and nobody
			// receives a second value (as in TestMitm), a send per connection
			// would block this accept loop forever.
			publicKeySeenByServer = certs[0].RawSubjectPublicKeyInfo
			done <- nil
			signaled = true
		}
		go handleConnection(tlsConn)
	}
}
// handleConnection serves one server-side connection. It reads
// newline-terminated messages, logs each one, and answers every message with
// "world\n". It closes conn and returns when a read or write fails, including
// when the peer closes its side. It never terminates the process: one broken
// connection must not take the server down.
func handleConnection(conn net.Conn) {
	lgr := serverLgr
	defer conn.Close()
	r := bufio.NewReader(conn)
	for {
		msg, err := r.ReadString('\n')
		if err != nil {
			lgr.Println(err)
			return
		}
		lgr.Printf("msg: %q", msg)
		if n, err := conn.Write([]byte("world\n")); err != nil {
			lgr.Printf("conn.Write: %d, %s", n, err)
			return
		}
	}
}
// Client generates a self-signed certificate for "client" and dials TLS to
// 127.0.0.1:port1, retrying until the dial succeeds. After the handshake it
// stores the public key of the peer's certificate in publicKeySeenByClient and
// sends one value on done. It then writes "hello\n" and logs the first line it
// reads back.
//
// The peer certificate is deliberately NOT verified (InsecureSkipVerify); this
// is what lets Mitm impersonate the server. Setup and I/O failures terminate
// the process via lgr.Fatalf.
func Client(done chan interface{}) {
	clientLgr = log.New(os.Stderr, "client ", log.LstdFlags|log.Lshortfile)
	lgr := clientLgr

	pub, certPem, privPem, err := generateCert("client", lgr)
	if err != nil {
		lgr.Fatalf("generateCert failed: %s", err)
	}
	clientPublicKey = pub

	cert, err := tls.X509KeyPair(certPem, privPem)
	if err != nil {
		lgr.Fatalf("X509KeyPair failed: %s", err)
	}
	conf := &tls.Config{
		Certificates:       []tls.Certificate{cert},
		InsecureSkipVerify: true,
	}

	var conn *tls.Conn
	for {
		conn, err = tls.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port1), conf)
		if err == nil {
			break
		}
		// The listener may not be up yet; pause instead of busy-spinning.
		time.Sleep(10 * time.Millisecond)
	}
	defer conn.Close()

	if err := conn.Handshake(); err != nil {
		lgr.Fatalf("Handshake: %s", err)
	}
	certs := conn.ConnectionState().PeerCertificates
	if len(certs) < 1 {
		lgr.Fatalf("peer certificates: %v", certs)
	}
	publicKeySeenByClient = certs[0].RawSubjectPublicKeyInfo
	done <- nil

	n, err := conn.Write([]byte("hello\n"))
	if err != nil {
		lgr.Fatalf("conn.Write: %d, %s", n, err)
	}
	// A single Read may return only part of the reply; read the whole line.
	reply, err := bufio.NewReader(conn).ReadString('\n')
	if err != nil {
		lgr.Fatalf("conn.Read: %q, %s", reply, err)
	}
	lgr.Printf("msg: %q\n", reply)
}
// Mitm impersonates the server. It generates its own self-signed certificate
// for "mitm" and listens for TLS on port1, requiring but not verifying a
// client certificate. Each accepted connection is handed to
// handleMitmConnection together with that same certificate. Mitm never
// returns normally; setup and Accept failures call lgr.Fatalf.
func Mitm() {
	lgr := log.New(os.Stderr, "mitm ", log.LstdFlags|log.Lshortfile)
	mitmLgr = lgr

	pub, certPem, privPem, err := generateCert("mitm", lgr)
	mitmPublicKey = pub
	if err != nil {
		lgr.Fatalf("generateCert failed: %s", err)
	}

	cert, err := tls.X509KeyPair(certPem, privPem)
	if err != nil {
		lgr.Fatalf("X509KeyPair failed: %s", err)
	}

	listenCfg := &tls.Config{
		Certificates: []tls.Certificate{cert},
		ClientAuth:   tls.RequireAnyClientCert,
	}
	addr := fmt.Sprintf(":%d", port1)
	ln, err := tls.Listen("tcp", addr, listenCfg)
	if err != nil {
		lgr.Fatalf("Listen: %s", err)
	}
	defer ln.Close()

	for {
		victim, acceptErr := ln.Accept()
		if acceptErr != nil {
			lgr.Fatalf("Accept: %s", acceptErr)
		}
		go handleMitmConnection(victim, cert)
	}
}
// handleMitmConnection relays one intercepted client connection to the real
// server. It dials 127.0.0.1:port0 over TLS, presenting cert and not verifying
// the server's certificate, and retries until the dial succeeds. It then relays
// in both directions:
//   - client -> server: line by line, logging every line before forwarding it;
//   - server -> client: byte stream, so the server's replies reach the client.
//
// When either direction ends (peer closed, read or write error), the function
// logs why, closes both connections and returns. It never terminates the
// process.
func handleMitmConnection(inConn net.Conn, cert tls.Certificate) {
	lgr := mitmLgr
	defer inConn.Close()
	config := &tls.Config{
		// The MITM does not care who the server is; it only needs a session.
		InsecureSkipVerify: true,
		Certificates:       []tls.Certificate{cert},
	}

	var (
		outConn *tls.Conn
		err     error
	)
	for {
		outConn, err = tls.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port0), config)
		if err == nil {
			break
		}
		// The server may not be listening yet; pause instead of busy-spinning.
		time.Sleep(10 * time.Millisecond)
	}
	defer outConn.Close()

	// server -> client. bufio.Reader.WriteTo copies until EOF, like io.Copy.
	// When the server side ends, close inConn so the loop below stops too.
	go func() {
		if _, err := bufio.NewReader(outConn).WriteTo(inConn); err != nil {
			lgr.Printf("server->client relay: %s", err)
		}
		inConn.Close()
	}()

	// client -> server, line by line so every message can be logged.
	r := bufio.NewReader(inConn)
	for {
		msg, err := r.ReadString('\n')
		if len(msg) > 0 {
			lgr.Printf("msg: %q", msg)
			if _, werr := outConn.Write([]byte(msg)); werr != nil {
				lgr.Printf("outConn.Write: %s", werr)
				return
			}
		}
		if err != nil {
			// Includes EOF when the client hangs up: the session is over,
			// not the process.
			lgr.Printf("inConn.ReadString: %s", err)
			return
		}
	}
}
package network
import (
"bytes"
"fmt"
"testing"
)
// TestMitm wires Client -> Mitm (port1) -> Server (port0). It checks that the
// interception worked: neither endpoint shares the MITM's public key, yet both
// endpoints saw the MITM's key on their peer's certificate.
func TestMitm(t *testing.T) {
	c := make(chan interface{})
	go Server(c)
	go Mitm()
	go Client(c)
	// Server and Client each send once after recording the key they saw on
	// their peer.
	<-c
	<-c
	fmt.Printf("server public key: %+v\n\n", serverPublicKey)
	fmt.Printf("public key seen by client: %+v\n\n", publicKeySeenByClient)
	fmt.Printf("client public key: %+v\n\n", clientPublicKey)
	fmt.Printf("public key seen by server: %+v\n\n", publicKeySeenByServer)
	fmt.Printf("mitm public key: %+v\n\n", mitmPublicKey)
	if bytes.Equal(serverPublicKey, mitmPublicKey) {
		t.Error("server and MITM unexpectedly have the same public key")
	}
	if bytes.Equal(clientPublicKey, mitmPublicKey) {
		t.Error("client and MITM unexpectedly have the same public key")
	}
	if !bytes.Equal(mitmPublicKey, publicKeySeenByServer) {
		t.Error("server did not see the MITM's public key on its peer certificate")
	}
	if !bytes.Equal(mitmPublicKey, publicKeySeenByClient) {
		t.Error("client did not see the MITM's public key on its peer certificate")
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment