Skip to content

Instantly share code, notes, and snippets.

@jasonrm
Last active December 11, 2018 06:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jasonrm/00658f4671a5f9591579d9661f65d6fc to your computer and use it in GitHub Desktop.
Save jasonrm/00658f4671a5f9591579d9661f65d6fc to your computer and use it in GitHub Desktop.
Listen for "unreachable - need to frag" ICMP packets, update local route, and re-broadcast to all interfaces
#!/usr/bin/python2
# Based on: https://tools.ietf.org/html/rfc7690#section-3.2
# More info: https://blog.cloudflare.com/path-mtu-discovery-in-practice/
# Also: https://github.com/cloudflare/pmtud
# Dependencies: python2, netifaces, scapy
import hashlib
import hmac
import time
from collections import deque
from os import urandom
from subprocess import call

import netifaces
from scapy.all import Ether, ICMP, IP, IPerror, TCPerror, sendp, sniff, conf
class PacketHash:
    """Keyed hash of the destination address quoted inside an ICMP error.

    A random 16-byte secret is drawn once per instance, so digests are stable
    for the lifetime of the instance but differ between instances.
    """

    def __init__(self):
        self.hmac_secret = urandom(16)

    def _digest(self, dst):
        """Return the hex HMAC-MD5 of *dst* (text or bytes) under the secret."""
        # hmac only accepts bytes on Python 3, while the address may arrive
        # as text; normalise it here so both interpreters behave the same.
        if not isinstance(dst, bytes):
            dst = dst.encode("ascii")
        # Name the digest explicitly: MD5 was the implicit default on
        # Python 2, and digestmod is a required argument from Python 3.8 on.
        checksum = hmac.new(self.hmac_secret, digestmod=hashlib.md5)
        checksum.update(dst)
        return checksum.hexdigest()

    def __call__(self, pkt):
        """Return the hex digest for the quoted destination address of *pkt*.

        Only the ``dst`` field of the IPerror layer is hashed, so every ICMP
        error quoting the same original destination yields the same digest.
        """
        return self._digest(pkt[Ether][IPerror].dst)
class HashRateLimiter:
    """Sliding-window rate limiter keyed by a hexadecimal hash string.

    Each hash is folded into one of ``maxSlots`` buckets, and every bucket
    remembers the timestamps of its last ``maxRate`` accepted events. Hashes
    that fold into the same bucket share a single budget.

    maxRate  -- events accepted per bucket within one ``timeUnit``
    timeUnit -- window length in seconds
    maxSlots -- upper bound on the number of buckets kept in memory
    """

    def __init__(self, maxRate=1, timeUnit=1, maxSlots=8192):
        self.maxRate = maxRate
        self.timeUnit = timeUnit
        self.maxSlots = maxSlots
        self.queues = {}

    def __call__(self, packet_hash):
        """Return True if the event is rate limited, False if it is accepted.

        Accepted events are recorded in their bucket; rejected ones are not.
        """
        # A zero or negative budget can never admit anything. Without this
        # guard maxRate=0 builds an empty bounded deque and queue[0] raises
        # IndexError on the first call.
        if self.maxRate is not None and self.maxRate <= 0:
            return True

        slot = int(packet_hash, 16) % self.maxSlots
        queue = self.queues.get(slot)
        if queue is None:
            queue = self.queues[slot] = deque(maxlen=self.maxRate)

        now = time.time()
        # The bucket is exhausted only while its oldest remembered event still
        # lies inside the window; once that ages out, the append below evicts
        # it because the deque is bounded.
        if len(queue) == queue.maxlen and now - queue[0] <= self.timeUnit:
            return True

        queue.append(now)
        return False
def refresh_interfaces():
    """Re-scan the host's network interfaces, at most once per 300 seconds.

    Rebinds the module-level ``interfaces`` list to every interface whose
    name starts with 'en' and records the scan time in ``last_refresh``.
    Calls made within 300 seconds of the previous scan do nothing.
    """
    global last_refresh, interfaces

    # Still fresh: keep the cached list.
    if time.time() - last_refresh <= 300:
        return

    interfaces = [name for name in netifaces.interfaces() if name.startswith('en')]
    last_refresh = time.time()
    print("Found interfaces: %s" % ', '.join(interfaces))
def icmp_callback(pkt):
    """Handle one sniffed ICMP "fragmentation needed" packet.

    The packet is dropped if it lacks the layers the handlers index into or
    if its hash is currently rate limited. Otherwise the advertised MTU is
    recorded via update_route_mtu() and, when the IP TTL is still positive,
    the packet is re-broadcast on the known interfaces.
    """
    # The handlers below index pkt[Ether][IPerror]; a frame captured without
    # an Ethernet header, or an ICMP error too short to carry the quoted IP
    # header, would raise IndexError there and abort the capture.
    if Ether not in pkt or IPerror not in pkt:
        print("action=reject reason=malformed")
        return

    # Hash once and reuse it for both the limiter and the log line.
    digest = packet_hash(pkt)
    if rate_limited(digest):
        print("action=rate_limit packet_hash=%s" % (digest))
        return

    update_route_mtu(pkt)

    # broadcast() sends its copies with ttl 0, so a copy received from
    # another relay is learned from but never relayed again.
    if pkt[Ether][IP].ttl > 0:
        refresh_interfaces()
        broadcast(pkt)
def broadcast(pkt):
    """Re-send *pkt* as an Ethernet broadcast on every known interface.

    The packet is modified in place: its Ethernet source is cleared, its
    Ethernet destination becomes the broadcast address, and its IP TTL is
    set to 0. A send failure on one interface is logged and does not stop
    the remaining interfaces from being tried.
    """
    # Mangle Ether: clear the source so it reverts to scapy's default and
    # address the frame to every host on the segment.
    del pkt[Ether].src
    pkt[Ether].dst = 'ff:ff:ff:ff:ff:ff'

    # Mangle IP: ttl 0 marks the copy as already relayed (icmp_callback only
    # relays packets with ttl > 0). The stored checksum no longer matches,
    # so clear it and let scapy recompute it when the packet is built.
    pkt[Ether][IP].ttl = 0
    del pkt[Ether][IP].chksum

    # Loop-invariant: the ICMP header is the same for every interface.
    icmp = pkt[Ether][IP][ICMP]

    for iface in interfaces:
        try:
            sendp(pkt, iface=iface)
        except EnvironmentError as e:
            # EnvironmentError also covers socket.error on Python 2, where it
            # derives from IOError rather than OSError; on Python 3 it is an
            # alias of OSError. Report the failure instead of hiding it.
            print("action=broadcast_failed interface=%s error=%s" % (iface, e))
            continue
        print("action=broadcast interface=%s type=%s code=%s nexthopmtu=%d" % (iface, icmp.type, icmp.code, icmp.nexthopmtu))
def update_route_mtu(pkt):
    """Install a per-host route carrying the MTU advertised by *pkt*.

    The next-hop MTU is read from the ICMP header. Values outside
    68..16386 are rejected as "mtu_stupid"; values below 576 or at/above
    1500 are rejected as "mtu_bogus". For an accepted value, a /32 route to
    the quoted destination address is deleted and re-appended with that MTU
    on every known interface via the ``ip route`` command.
    """
    next_hop_mtu = pkt[Ether][IP][ICMP].nexthopmtu

    if not 68 <= next_hop_mtu <= 16386:
        print("action=reject reason=mtu_stupid nexthopmtu=%d packet_hash=%s" % (next_hop_mtu, packet_hash(pkt)))
        return

    # TODO: Check if interface has jumbo frames enabled
    if not 576 <= next_hop_mtu < 1500:
        print("action=reject reason=mtu_bogus nexthopmtu=%d packet_hash=%s" % (next_hop_mtu, packet_hash(pkt)))
        return

    host_route = "%s/32" % (pkt[Ether][IPerror].dst)
    mtu_arg = str(next_hop_mtu)

    # TODO: This route never expires as the expires option currently only works for IPv6...
    for iface in interfaces:
        # Delete first, then append, so the same route is not stacked twice.
        for verb in ("delete", "append"):
            call(["ip", "route", verb, host_route, "dev", iface, "mtu", mtu_arg])

    print("action=learn mtu=%s prefix=%s packet_hash=%s" % (next_hop_mtu, host_route, packet_hash(pkt)))
# Interface cache: names to act on, rebuilt by refresh_interfaces().
interfaces = []
# Time of the last interface scan; 0 forces a scan on the first call.
last_refresh = 0

# Set scapy's verbosity to 0.
conf.verb = 0

# Shared callables used by the packet handlers: the keyed hash of a packet,
# and a limiter admitting one event per hash bucket every 60 seconds.
packet_hash = PacketHash()
rate_limited = HashRateLimiter(maxRate=1, timeUnit=60)
def main():
    """Scan the interfaces once, then sniff and handle matching ICMP packets."""
    refresh_interfaces()
    print("Listening and re-broadcasting on: %s" % ', '.join(interfaces))

    # BPF expression: inbound ICMP with type 3 and code 4 only.
    capture_filter = "(icmp and icmp[0] == 3 and icmp[1] == 4) and inbound"
    # store=0: packets are handed to the callback and not accumulated.
    sniff(filter=capture_filter, prn=icmp_callback, store=0)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment