Skip to content

Instantly share code, notes, and snippets.

@korc
Last active December 21, 2023 02:46
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save korc/48b183723eecf5d1537e4822e6ca57b5 to your computer and use it in GitHub Desktop.
Save korc/48b183723eecf5d1537e4822e6ca57b5 to your computer and use it in GitHub Desktop.
PKCS11-authenticated TLS socket proxy
package main
import (
"crypto/tls"
"flag"
"fmt"
"io"
"log"
"net"
"net/textproto"
"os"
"sync"
"syscall"
"github.com/ThalesIgnite/crypto11"
"github.com/miekg/pkcs11"
"golang.org/x/crypto/ssh/terminal"
)
func main() {
tokenSerial := flag.String("token-serial", os.Getenv("TOKEN_SERIAL"), "token serial number")
remoteAddr := flag.String("remote", "", "remote server address as host:port")
pkcs11Lib := flag.String("pkcs11lib", "/usr/lib/x86_64-linux-gnu/opensc-pkcs11.so", "PKCS#11 library")
certLabel := flag.String("cert-label", "Certificate for PIV Authentication", "PKCS#11 certificate label")
listenAddr := flag.String("listen", "", "listen address")
startTLS := flag.String("starttls", "", "Start TLS with protocol (empty or 'smtp')")
keyPairIndex := flag.Int("key-idx", 0, "Key index (default: 0)")
flag.Parse()
if *tokenSerial == "" {
log.Fatal("Need to set -token-serial via commandline option or $TOKEN_SERIAL environment variable")
}
if *remoteAddr == "" {
log.Fatal("Need remote server via -remote option")
}
if *listenAddr == "" {
log.Fatal("Need listener address via -listen option")
}
tokenPin := os.Getenv("TOKEN_PIN")
if tokenPin == "" {
print("PIN for " + *tokenSerial + ": ")
bPin, err := terminal.ReadPassword(syscall.Stdin)
if err != nil {
log.Fatal("Could not read PIN code: ", err)
}
tokenPin = string(bPin)
if tokenPin == "" {
log.Fatal("Need to enter PIN or set via $TOKEN_PIN environment variable")
} else {
print("[OK]\n")
}
}
ctx, err := crypto11.Configure(&crypto11.Config{
Path: *pkcs11Lib,
TokenSerial: *tokenSerial,
Pin: tokenPin,
})
if err != nil {
log.Fatal("Could not configure crypto11: ", err)
}
defer ctx.Close()
kp, err := ctx.FindAllKeyPairs()
if err != nil {
log.Fatalf("Could not find any key pairs: %s", err)
}
if len(kp) == 0 {
log.Fatal("Could not find any key pairs")
}
log.Printf("Found %d key pairs:", len(kp))
for _, keyPair := range kp {
log.Printf(" %#v", keyPair)
}
if len(kp) <= *keyPairIndex {
log.Fatalf("Selected key-pair index (%d) not available.", *keyPairIndex)
}
crt, err := ctx.FindCertificate(nil, []byte(*certLabel), nil)
if err != nil {
log.Fatal("Could not search for certificates: ", err)
}
if crt == nil {
log.Fatal("No certificate found. Perhaps wrong -cert-label ?")
} else {
log.Printf("Certificate Serial: %x Subject: %s", crt.SerialNumber, crt.Subject)
}
remoteHost, _, err := net.SplitHostPort(*remoteAddr)
if err != nil {
log.Fatal("Could not split remote address: ", err)
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{
{
Certificate: [][]byte{crt.Raw},
PrivateKey: kp[*keyPairIndex],
},
},
ServerName: remoteHost,
}
listener, err := net.Listen("tcp", *listenAddr)
if err != nil {
log.Fatal("Could not listen: ", err)
}
for {
client, err := listener.Accept()
if err != nil {
log.Fatal("Could not accept client: ", err)
}
clientAddr := client.RemoteAddr()
log.Printf("Accepted client: %s", clientAddr)
go func(client net.Conn) {
defer client.Close()
conn, err := net.Dial("tcp", *remoteAddr)
if err != nil {
log.Printf("Could not connect to remote %s: %s", *remoteAddr, err)
return
}
defer conn.Close()
if *startTLS == "smtp" {
text := textproto.NewConn(conn)
bannerCode, banner, err := text.ReadResponse(220)
if err != nil {
log.Printf("Could not read SMTP banner: %s", err)
return
}
log.Printf("Got code: %d, message: %#v", bannerCode, banner)
cmdId, err := text.Cmd("STARTTLS")
if err != nil {
log.Printf("Could not STARTTLS: %s", err)
return
}
text.StartResponse(cmdId)
code, msg, err := text.ReadResponse(220)
if err != nil {
log.Printf("Could read status for STARTTLS: %s", err)
return
}
log.Printf("Response to STARTTLS: %d, %#v", code, msg)
text.EndResponse(cmdId)
client.Write([]byte(fmt.Sprintf("%d %s\r\n", bannerCode, banner)))
}
tlsConn := tls.Client(conn, tlsConfig)
if err := tlsConn.Handshake(); err != nil {
log.Printf("Could not do TLS handshake: [%T] %s", err, err)
if pErr, ok := err.(pkcs11.Error); ok {
if pErr == pkcs11.CKR_GENERAL_ERROR {
log.Fatal("PKCS#11 CKR_GENERAL_ERROR detected, bailing out")
}
}
return
}
defer tlsConn.Close()
wg := sync.WaitGroup{}
wg.Add(2)
go func() {
defer wg.Done()
defer client.Close()
defer tlsConn.Close()
io.Copy(tlsConn, client)
}()
go func() {
defer wg.Done()
defer client.Close()
defer tlsConn.Close()
io.Copy(client, tlsConn)
}()
wg.Wait()
log.Printf("Finished client: %s", clientAddr)
}(client)
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment