Skip to content

Instantly share code, notes, and snippets.

@jstangroome
Created September 26, 2023 07:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jstangroome/d9c74c92380d43d1558f75a33591702a to your computer and use it in GitHub Desktop.
Save jstangroome/d9c74c92380d43d1558f75a33591702a to your computer and use it in GitHub Desktop.
package main
import (
"crypto/hmac"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"flag"
"fmt"
"log"
"net"
"net/netip"
"os"
"time"
)
// calculatePortKnockHash returns HMAC-SHA256(secret, ip || minute), where ip
// is the 4-byte IPv4 address as returned by netip.Addr.AsSlice and minute is
// the whole number of minutes since the Unix epoch, truncated to 32 bits and
// encoded little-endian. Any ts within the same minute yields the same digest.
//
// Only IPv4 addresses are accepted; any other address yields an error.
func calculatePortKnockHash(ip netip.Addr, ts time.Time, secret []byte) ([]byte, error) {
	if !ip.Is4() {
		return nil, fmt.Errorf("only ipv4 supported")
	}

	// The minute counter is the time-varying part of the message; the byte
	// order here is part of the protocol and must match the receiver.
	var minuteField [4]byte
	binary.LittleEndian.PutUint32(minuteField[:], uint32(ts.UTC().Unix()/60))

	mac := hmac.New(sha256.New, secret)
	if _, err := mac.Write(ip.AsSlice()); err != nil {
		return nil, fmt.Errorf("failed to write ip to hmac: %w", err)
	}
	if _, err := mac.Write(minuteField[:]); err != nil {
		return nil, fmt.Errorf("failed to write time to hmac: %w", err)
	}
	return mac.Sum(nil), nil
}
// sendPacket sends payload as a single UDP datagram to dstAddr from a socket
// bound to srcIP on an OS-chosen port. The socket is closed before returning.
//
// It returns an error if srcIP is the zero netip.Addr, if the local socket
// cannot be bound, if the write fails, or if fewer than len(payload) bytes
// were written.
func sendPacket(srcIP netip.Addr, dstAddr net.Addr, payload []byte) error {
	if !srcIP.IsValid() {
		return fmt.Errorf("invalid source IP address")
	}
	// AddrPortFrom(...).String() brackets IPv6 literals ("[::1]:0"); plain
	// string concatenation would yield "::1:", which net cannot parse.
	// Port 0 asks the OS to pick an ephemeral local port.
	localAddr := netip.AddrPortFrom(srcIP, 0).String()
	conn, err := net.ListenPacket("udp", localAddr)
	if err != nil {
		return fmt.Errorf("failed to create connection context: %w", err)
	}
	defer conn.Close()

	n, err := conn.WriteTo(payload, dstAddr)
	if err != nil {
		return fmt.Errorf("failed to write packet: %w", err)
	}
	if n != len(payload) {
		return fmt.Errorf("only wrote %d of %d bytes", n, len(payload))
	}
	return nil
}
// main computes the port-knock HMAC for the -src address and the current
// time, keyed with the PKNOCK_SECRET environment variable.
//
// Without -dst it prints the hex-encoded digest to stdout. With -dst (a
// host:port resolved as a UDP address) it sends the hex-encoded digest as a
// single UDP datagram from -src to that destination. Any failure is logged
// with an "[ERROR]" prefix and terminates the process via log.Fatalf.
func main() {
	var srcIPText string
	var dstAddrText string
	flag.StringVar(&srcIPText, "src", "", "source IP address")
	flag.StringVar(&dstAddrText, "dst", "", "destination IP address and port, colon delimited")
	// Parse flags before validating the environment so -h/-help still prints
	// usage when PKNOCK_SECRET is unset.
	flag.Parse()

	secret := os.Getenv("PKNOCK_SECRET")
	if secret == "" {
		log.Fatalf("[ERROR] PKNOCK_SECRET environment variable required.")
	}
	if srcIPText == "" {
		flag.PrintDefaults()
		log.Fatalf("[ERROR] -src required.")
	}
	srcIP, err := netip.ParseAddr(srcIPText)
	if err != nil {
		log.Fatalf("[ERROR] %v", err)
	}

	// Key the HMAC with the PKNOCK_SECRET value read above, not a fixed key.
	hashBytes, err := calculatePortKnockHash(srcIP, time.Now(), []byte(secret))
	if err != nil {
		log.Fatalf("[ERROR] %v", err)
	}
	hashHex := hex.EncodeToString(hashBytes)

	if dstAddrText == "" {
		fmt.Println(hashHex)
		return
	}
	dstAddr, err := net.ResolveUDPAddr("udp", dstAddrText)
	if err != nil {
		log.Fatalf("[ERROR] %v", err)
	}
	if err := sendPacket(srcIP, dstAddr, []byte(hashHex)); err != nil {
		log.Fatalf("[ERROR] %v", err)
	}
}
package main
import (
"encoding/hex"
"net/netip"
"testing"
"time"
)
// TestNetworkByteOrderOfIPAddress checks that netip.Addr.AsSlice returns an
// IPv4 address as its four octets in written order (most significant first).
// This is the byte sequence calculatePortKnockHash feeds into the HMAC.
func TestNetworkByteOrderOfIPAddress(t *testing.T) {
	want := []byte{192, 0, 2, 1}
	got := netip.MustParseAddr("192.0.2.1").AsSlice()
	// Comparing as strings checks both length and every byte in one step.
	if string(got) != string(want) {
		t.Fatalf("got %v, want %v", got, want)
	}
}
// TestCalculatePortKnockHash pins the hex digest produced for a fixed IPv4
// address, timestamp and secret. Any change to how the HMAC input is encoded
// (field order, byte order, minute truncation) changes this value.
func TestCalculatePortKnockHash(t *testing.T) {
	const expectedHex = "7fb7ec169ef70340600e4aab04908d269f11ac6a39704604a0412af38e88f383"
	var (
		addr   = netip.MustParseAddr("192.0.2.1")
		when   = time.Date(2023, 9, 25, 11, 12, 13, 0, time.UTC)
		secret = []byte("foo")
	)

	digest, err := calculatePortKnockHash(addr, when, secret)
	if err != nil {
		t.Fatalf("error: %v", err)
	}
	if got := hex.EncodeToString(digest); got != expectedHex {
		t.Errorf("got '%s', want '%s'", got, expectedHex)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment