Skip to content

Instantly share code, notes, and snippets.

@chronos-tachyon
Last active July 8, 2022 01:16
Show Gist options
  • Save chronos-tachyon/6e1ea73bafd6bd249b18efbb909dcd6b to your computer and use it in GitHub Desktop.
Save chronos-tachyon/6e1ea73bafd6bd249b18efbb909dcd6b to your computer and use it in GitHub Desktop.
Script to synchronize WireGuard configs across a fleet from a central command-and-control host.
#!/usr/bin/env python3
#
# Written by Donald King <chronos@chronos-tachyon.net>
# Public Domain.
#
# ==== CC0 https://creativecommons.org/publicdomain/zero/1.0/ ====
# [To the extent possible under law, I have waived all copyright ]
# [and related or neighboring rights to this script. This work is]
# [published from the United States of America. ]
import argparse
import collections
import grp
import ipaddress
import json
import os
import pathlib
import subprocess
import sys
import tempfile
# Accepted values of the per-host "mode" field in hosts.json.
MODE_WG_QUICK = 'wg-quick'
MODE_NETWORKD = 'networkd'

# Ownership applied to generated files.  GID_NETWORKD is resolved at import
# time; grp.getgrnam raises KeyError if the systemd-network group is missing.
UID_ROOT = 0
GID_ROOT = 0
GID_NETWORKD = grp.getgrnam('systemd-network').gr_gid

# Install destinations (the same paths are used locally and on SSH targets).
NET_DIR = pathlib.Path('/etc/systemd/network')
WG_DIR = pathlib.Path('/etc/wireguard')

# Controller-side inputs: fleet description, rewrite rules, per-host keys.
HOSTS_FILE = WG_DIR.joinpath('hosts.json')
RULES_FILE = WG_DIR.joinpath('rules.json')
PRIVATE_DIR = WG_DIR.joinpath('private')
PUBLIC_DIR = WG_DIR.joinpath('public')
class Rule(collections.namedtuple('Rule', ['hosts', 'old_endpoint', 'new_endpoint'])):
    """An endpoint-rewrite rule read from rules.json.

    When generating configuration for a target host whose name is in
    ``hosts``, a peer endpoint equal to ``old_endpoint`` is replaced by
    ``new_endpoint``.
    """

    def apply(self, arg):
        """Return ``new_endpoint`` if *arg* equals ``old_endpoint``, else *arg* unchanged."""
        if arg == self.old_endpoint:
            return self.new_endpoint
        return arg
class Host(collections.namedtuple('Host', ['name', 'enabled', 'ipv4', 'ipv6', 'endpoint', 'mode', 'installer'])):
    """One WireGuard node, built from a row of hosts.json.

    Fields:
      name      -- host name; also the file name of its keys under
                   PUBLIC_DIR and PRIVATE_DIR
      enabled   -- whether the host is installed and emitted as an active peer
      ipv4/ipv6 -- tunnel addresses, without prefix length
      endpoint  -- address peers dial to reach this host, or None
      mode      -- MODE_WG_QUICK or MODE_NETWORKD
      installer -- key into INSTALLERS (None, 'local' or 'ssh')
    """

    def _read_key(self, directory):
        """Return the first line of ``directory / name``, stripped of whitespace.

        Raises ValueError if that line is empty: an empty key would otherwise
        be written into the generated configuration as a blank key.
        """
        keyfile = directory / self.name
        with open(keyfile) as fp:
            key = fp.readline().strip()
        if not key:
            raise ValueError('key file {} for host {!r} is empty'.format(keyfile, self.name))
        return key

    @property
    def public_key(self):
        """The host's public key, read from PUBLIC_DIR on every access."""
        return self._read_key(PUBLIC_DIR)

    @property
    def private_key(self):
        """The host's private key, read from PRIVATE_DIR on every access."""
        return self._read_key(PRIVATE_DIR)
def run(*argv):
    """Echo a command to stderr and, unless --dry-run was given, execute it.

    Each argument is converted with str(), so pathlib.Path values may be
    passed directly.  If the command exits with a non-zero status, this
    script exits with status 1.
    """
    cmd = [str(arg) for arg in argv]
    print('+ {!r}'.format(cmd), file=sys.stderr)
    if args.dry_run:
        return
    completed = subprocess.run(cmd)
    if completed.returncode != 0:
        sys.exit(1)
def _format_endpoint(address, port):
    """Return "ADDRESS:PORT" for a WireGuard Endpoint setting.

    A bare IPv6 literal is wrapped in square brackets ("[2001:db8::1]:51820").
    systemd.netdev(5) requires this form, and wg(8) accepts it.  Host names,
    IPv4 literals and already-bracketed addresses are used as-is.
    """
    try:
        ipaddress.IPv6Address(str(address))
    except ValueError:
        return '{}:{}'.format(address, port)
    return '[{}]:{}'.format(address, port)


def _write_config_file(path, content, gid, mode):
    """Write *content* to *path*, owned by root:*gid* with permission bits *mode*.

    Ownership and mode are set on the open descriptor before any data is
    written, so the key material is never stored under default permissions.
    """
    with open(path, 'wt', encoding='utf-8') as fp:
        os.fchown(fp.fileno(), UID_ROOT, gid)
        os.fchmod(fp.fileno(), mode)
        fp.write(content)


def generate_config_for_host(target, tmpdir):
    """Render WireGuard configuration for *target* into the directory *tmpdir*.

    Both formats are always produced, and the installer picks the one
    matching target.mode:

      wg0.conf    -- wg-quick config with the private key embedded
                     (root:root, 0600)
      wg0.netdev  -- systemd-networkd netdev that points at
                     /etc/wireguard/wg0.privkey (root:root, 0644)
      wg0.privkey -- the private key alone (root:systemd-network, 0640)

    Every host in the global ``hosts`` list becomes a [Peer] /
    [WireGuardPeer] section.  Disabled hosts and *target* itself are
    commented out rather than omitted.  A peer's endpoint is passed through
    every rule in the global ``rules`` list whose ``hosts`` contain
    target.name, in order.  If no endpoint remains, no Endpoint line is
    written.

    Uses os.fchown, so it normally has to run as root.
    """
    port = 51820  # ListenPort on every host, and the port dialled on every Endpoint
    privkey = target.private_key  # read once; used by wg0.conf and wg0.privkey

    # '# vi' 'm:...' is presumably split so vim does not read it as a
    # modeline for this script; the generated files get "# vim:set ...".
    wg_quick_content = (
        '# vi' 'm:set ft=conf:\n'
        '\n'
        '[Interface]\n'
        'ListenPort = {port}\n'
        'PrivateKey = {privkey}\n'
        'Address = {ipv4}/24, {ipv6}/64\n'
        'MTU = 1380\n'
    ).format(port=port, privkey=privkey, ipv4=target.ipv4, ipv6=target.ipv6)
    networkd_content = (
        '# vi' 'm:set ft=systemd:\n'
        '\n'
        '[NetDev]\n'
        'Name=wg0\n'
        'Kind=wireguard\n'
        'Description=WireGuard tunnel wg0\n'
        '\n'
        '[WireGuard]\n'
        'ListenPort={port}\n'
        'PrivateKeyFile=/etc/wireguard/wg0.privkey\n'
    ).format(port=port)

    for host in hosts:
        # Disabled hosts and the target itself stay in the file, commented out.
        enabled = host.enabled and (host is not target)
        fields = {
            'prefix': '' if enabled else '#',
            'name': host.name,
            'ipv4': host.ipv4,
            'ipv6': host.ipv6,
            'pubkey': host.public_key,
        }
        wg_quick_content += (
            '\n'
            '{prefix}# {name}\n'
            '{prefix}[Peer]\n'
            '{prefix}PublicKey = {pubkey}\n'
            '{prefix}AllowedIPs = {ipv4}/32, {ipv6}/128\n'
        ).format(**fields)
        networkd_content += (
            '\n'
            '{prefix}# {name}\n'
            '{prefix}[WireGuardPeer]\n'
            '{prefix}PublicKey={pubkey}\n'
            '{prefix}AllowedIPs={ipv4}/32\n'
            '{prefix}AllowedIPs={ipv6}/128\n'
        ).format(**fields)

        endpoint = host.endpoint
        if endpoint:
            # Rules are applied in order; each sees the previous rule's result.
            for rule in rules:
                if target.name in rule.hosts:
                    endpoint = rule.apply(endpoint)
        if not endpoint:
            # No endpoint, or a rule rewrote it to null/"": write no Endpoint
            # line instead of an invalid "None:51820".
            continue

        fields['ep'] = _format_endpoint(endpoint, port)
        wg_quick_content += (
            '{prefix}Endpoint = {ep}\n'
            '{prefix}PersistentKeepalive = 120\n'
        ).format(**fields)
        networkd_content += (
            '{prefix}Endpoint={ep}\n'
            '{prefix}PersistentKeepalive=120\n'
        ).format(**fields)

    _write_config_file(tmpdir / 'wg0.conf', wg_quick_content, GID_ROOT, 0o600)
    _write_config_file(tmpdir / 'wg0.netdev', networkd_content, GID_ROOT, 0o644)
    _write_config_file(tmpdir / 'wg0.privkey', privkey + '\n', GID_NETWORKD, 0o640)
def install_none(host):
    """Placeholder installer: announce that *host* is skipped and change nothing."""
    print('Skipping {}...'.format(host.name), end='\n\n')
def install_local(host):
    """Install *host*'s WireGuard config on this machine and restart its tunnel.

    Files are generated in a temporary directory and copied into place with
    ``cp -a``, which keeps the ownership and modes set by
    generate_config_for_host.  os.sync() is then called, and either
    wg-quick@wg0 is restarted (wg-quick mode) or networkd is reloaded and
    wg0 reconfigured (networkd mode).  External commands go through run(),
    so with --dry-run they are only printed.
    """
    print('Installing {} (local)...'.format(host.name))
    wg_quick = host.mode == MODE_WG_QUICK
    with tempfile.TemporaryDirectory() as scratch:
        scratch = pathlib.Path(scratch)
        generate_config_for_host(host, scratch)
        if wg_quick:
            copies = [(scratch / 'wg0.conf', WG_DIR / 'wg0.conf')]
        else:
            copies = [
                (scratch / 'wg0.netdev', NET_DIR / '20-wg0.netdev'),
                (scratch / 'wg0.privkey', WG_DIR / 'wg0.privkey'),
            ]
        for source, destination in copies:
            run('cp', '-a', source, destination)
    print('+ sync', file=sys.stderr)
    if not args.dry_run:
        os.sync()
    if wg_quick:
        run('systemctl', 'restart', 'wg-quick@wg0')
    else:
        run('networkctl', 'reload')
        run('networkctl', 'reconfigure', 'wg0')
    print()
def install_ssh(host):
    """Push *host*'s WireGuard config over rsync/SSH and restart its tunnel.

    The remote machine is addressed as "<name>.vpn".  In wg-quick mode,
    wg0.conf is copied to /etc/wireguard and wg-quick@wg0 is restarted.
    Otherwise the netdev file and private key are copied, and then networkd
    is reloaded and wg0 reconfigured.  External commands go through run(),
    so with --dry-run they are only printed.
    """
    print('Installing {} (SSH)...'.format(host.name))
    remote = host.name + '.vpn'
    with tempfile.TemporaryDirectory() as scratch:
        scratch = pathlib.Path(scratch)
        generate_config_for_host(host, scratch)
        if host.mode == MODE_WG_QUICK:
            transfers = [('wg0.conf', WG_DIR / 'wg0.conf')]
            remote_cmds = [('systemctl', 'restart', 'wg-quick@wg0')]
        else:
            transfers = [
                ('wg0.netdev', NET_DIR / '20-wg0.netdev'),
                ('wg0.privkey', WG_DIR / 'wg0.privkey'),
            ]
            remote_cmds = [
                ('networkctl', 'reload'),
                ('networkctl', 'reconfigure', 'wg0'),
            ]
        for filename, destination in transfers:
            run('rsync', '-a', scratch / filename, '{}:{}'.format(remote, destination))
        for cmd in remote_cmds:
            run('ssh', remote, *cmd)
    print()
# Dispatch table keyed by the "installer" field of hosts.json.  None (field
# absent or null) selects install_none, which only prints a skip notice.
INSTALLERS = {None: install_none, 'local': install_local, 'ssh': install_ssh}
# Command line is parsed at import time; run() and install_local() read
# args.dry_run, and the main loop reads args.only.
parser = argparse.ArgumentParser()
parser.add_argument('-n', '--dry-run', action='store_true', help='do not act')
parser.add_argument('--only', metavar='HOST', action='append', default=[],
                    help='list of hosts to sync; default is all hosts')
args = parser.parse_args()
# Load the fleet description.  hosts.json is a JSON list of objects with
# required "name", "ipv4" and "ipv6" keys and optional "enabled" (default
# true), "endpoint" (default null), "mode" (default "wg-quick") and
# "installer" (default null, meaning the host is skipped).
hosts = []
hosts_by_name = {}
with open(HOSTS_FILE, 'rt') as fp:
    for row in json.load(fp):
        host_name = row['name']
        if host_name in hosts_by_name:
            # Both rows would read the same key files and produce two peer
            # sections with the same public key.
            raise ValueError('duplicate host {!r} in {}'.format(host_name, HOSTS_FILE))
        host_mode = row.get('mode', MODE_WG_QUICK)
        if host_mode not in (MODE_WG_QUICK, MODE_NETWORKD):
            raise ValueError('unknown mode {!r} for host {!r}'.format(host_mode, host_name))
        host_installer = row.get('installer', None)
        if host_installer not in INSTALLERS:
            raise ValueError('unknown installer {!r} for host {!r}'.format(host_installer, host_name))
        host = Host(
            name=host_name,
            enabled=row.get('enabled', True),
            ipv4=row['ipv4'],
            ipv6=row['ipv6'],
            endpoint=row.get('endpoint', None),
            mode=host_mode,
            installer=host_installer,
        )
        hosts.append(host)
        hosts_by_name[host_name] = host
# Load endpoint-rewrite rules.  rules.json is a JSON list of objects with
# "hosts" (names of the target hosts the rule applies to), "old_endpoint"
# and "new_endpoint".
rules = []
with open(RULES_FILE, 'rt') as fp:
    for row in json.load(fp):
        rule_hosts = row['hosts']
        if isinstance(rule_hosts, str):
            # A bare string would turn `name in rule.hosts` into a substring
            # test ("web" matching "webserver"); treat it as a single name.
            rule_hosts = [rule_hosts]
        rules.append(Rule(frozenset(rule_hosts), row['old_endpoint'], row['new_endpoint']))
# Reject --only names absent from hosts.json; otherwise a typo would make
# the run silently do nothing for the intended host.
unknown_hosts = [name for name in args.only if name not in hosts_by_name]
if unknown_hosts:
    parser.error('unknown host(s) for --only: {}'.format(', '.join(unknown_hosts)))

# Install every enabled host (restricted to --only when given), dispatching
# on its configured installer.
for host in hosts:
    if host.enabled and (not args.only or host.name in args.only):
        INSTALLERS[host.installer](host)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment