PoC: Offloading the L3 FIB into NIC hardware through tc
#!/usr/bin/env python3 | |
import pyroute2 | |
import socket | |
from pyroute2.netlink import rtnl | |
import subprocess | |
import time | |
from operator import itemgetter | |
class TCL3Switch:
    """Mirror a VRF's IPv4 neighbours and gateway routes into tc flower filters.

    The filters are attached to a shared ingress block that is bound to every
    link enslaved to the VRF device.  Each filter matches a destination
    prefix, rewrites the destination MAC with ``pedit`` and redirects the
    packet to the egress link with ``mirred``.

    Args:
        l3mdev_ifname: Name of the VRF (l3mdev) master device.
        block: tc shared block index used for the ingress qdiscs and filters.
        chain: tc chain index the filters are installed into.

    Raises:
        ValueError: If the device does not exist or no VRF table id can be
            read from its link info.
    """

    def __init__(self, l3mdev_ifname, block=1, chain=0):
        self._ipr = pyroute2.IPRoute()
        self._l3mdev_ifindex = self._get_ifindex(l3mdev_ifname)
        self._vrf_table_id = self._get_vrf_table_id(self._l3mdev_ifindex)
        self._l3mdev_slaves = self._get_l3mdev_slaves(self._l3mdev_ifindex)
        self._block = block
        self._chain = chain

    def _get_ifindex(self, ifname):
        """Return the ifindex of *ifname*; raise ValueError if it is missing."""
        try:
            return self._ipr.link_lookup(ifname=ifname)[0]
        except IndexError as e:
            raise ValueError(f"ifname {ifname} is missing") from e

    def _get_vrf_table_id(self, ifindex):
        """Return the routing table id of the VRF device with index *ifindex*.

        Raises:
            ValueError: If the link carries no ``vrf`` link info with a table id.
        """
        link = self._ipr.get_links(ifindex)[0]
        for linkinfo in link.get_attrs("IFLA_LINKINFO"):
            if linkinfo.get_attr("IFLA_INFO_KIND") != "vrf":
                continue
            info_data = linkinfo.get_attr("IFLA_INFO_DATA")
            if info_data is None:
                continue
            table_id = info_data.get_attr("IFLA_VRF_TABLE")
            if table_id is not None:
                return table_id
        # The message must use the local parameter; the previous code named an
        # undefined variable and raised NameError instead of ValueError.
        raise ValueError(f"Failed to find VRF table id for ifindex {ifindex}")

    def _get_l3mdev_slaves(self, ifindex):
        """Return ``{ifindex: ifname}`` for every link whose master is *ifindex*."""
        slave_indexes = self._ipr.link_lookup(master=ifindex)
        if not slave_indexes:
            # Calling get_links() without indexes would dump every link on the
            # host, so an empty VRF must short-circuit to an empty map.
            return {}
        slaves = self._ipr.get_links(*slave_indexes)
        return {link["index"]: link.get_attr("IFLA_IFNAME") for link in slaves}

    def _build_neighbour_flows(self):
        """Yield one /32 redirect flow per reachable IPv4 neighbour on a slave link."""
        neighbours = self._ipr.get_neighbours(
            state=rtnl.ndmsg.NUD_REACHABLE, family=socket.AF_INET
        )
        for neigh in neighbours:
            if neigh["ifindex"] not in self._l3mdev_slaves:
                continue
            lladdr = neigh.get_attr("NDA_LLADDR")
            if lladdr is None:
                # Without a link-layer address there is nothing to rewrite to.
                continue
            yield dict(
                dst=neigh.get_attr("NDA_DST"),
                dst_len=32,
                action="redirect",
                redirect_mac=lladdr,
                redirect_ifindex=neigh["ifindex"],
            )

    def _build_route_flows(self):
        """Yield a redirect flow for each unicast gateway route in the VRF table.

        Routes whose gateway has no reachable neighbour entry on the output
        interface are skipped.
        """
        routes = self._ipr.get_routes(table=self._vrf_table_id, family=socket.AF_INET)
        for route in routes:
            if route["type"] != rtnl.rt_type["unicast"]:
                continue
            oif = route.get_attr("RTA_OIF")
            gateway = route.get_attr("RTA_GATEWAY")
            if not (gateway and oif):
                continue

            dst = route.get_attr("RTA_DST")
            dst_len = route["dst_len"]
            if dst is None and dst_len == 0:
                # default route
                dst = "0.0.0.0"

            gateway_neighs = self._ipr.get_neighbours(
                state=rtnl.ndmsg.NUD_REACHABLE, ifindex=oif, dst=gateway
            )
            if not gateway_neighs:
                # Gateway not resolved (yet); try again on the next refresh.
                continue
            lladdr = gateway_neighs[0].get_attr("NDA_LLADDR")
            if lladdr is None:
                continue
            yield dict(
                dst=dst,
                dst_len=dst_len,
                action="redirect",
                redirect_mac=lladdr,
                redirect_ifindex=oif,
            )

    def _build_flows(self):
        """Return all flows sorted by prefix length, longest first.

        Filters are later numbered in this order, so the longest prefix gets
        the lowest ``pref`` value.
        """
        flows = list(self._build_neighbour_flows()) + list(self._build_route_flows())
        flows.sort(key=itemgetter("dst_len"), reverse=True)
        return flows

    def _generate_flower_filters(self):
        """Yield the ``flower ...`` part of a tc filter command for each flow."""
        for flow in self._build_flows():
            command = f"flower dst_ip {flow['dst']}/{flow['dst_len']}"
            if flow["action"] == "redirect":
                ifname = self._l3mdev_slaves.get(flow["redirect_ifindex"])
                if ifname is None:
                    # Output interface is not enslaved to this VRF; there is no
                    # known device name to redirect to, so skip the flow
                    # instead of aborting the whole refresh with KeyError.
                    continue
                command += f" action pedit ex munge eth dst set {flow['redirect_mac']}"
                command += f" pipe mirred egress redirect dev {ifname}"
            yield command

    def set_ingress_qdisc(self):
        """Attach an ingress qdisc bound to the shared block on every slave link.

        The exit status of ``tc`` is not checked.
        """
        for ifname in self._l3mdev_slaves.values():
            # Argument list instead of a shell string: interface names are not
            # interpreted by a shell.
            command = [
                "tc", "qdisc", "add", "dev", ifname,
                "ingress_block", str(self._block), "ingress",
            ]
            subprocess.run(command, check=False)

    def install_filters(self, pref_start):
        """Install the current filters with consecutive prefs from *pref_start*.

        Returns:
            The number of filters that were generated.
        """
        filters = list(self._generate_flower_filters())
        for pref, flower_filter in enumerate(filters, start=pref_start):
            command = [
                "tc", "filter", "add", "block", str(self._block),
                "protocol", "ip", "chain", str(self._chain),
                "pref", str(pref),
                # Filter fragments are space-separated tokens built above.
                *flower_filter.split(),
            ]
            subprocess.run(command, check=False)
        return len(filters)

    def delete_filters(self, pref_start, num):
        """Delete *num* filters with prefs from *pref_start*, highest pref first."""
        for pref in range(pref_start + num - 1, pref_start - 1, -1):
            command = [
                "tc", "filter", "del", "block", str(self._block),
                "protocol", "ip", "chain", str(self._chain),
                "pref", str(pref),
            ]
            subprocess.run(command, check=False)

    def run(self, pref_offset=(1, 1001)):
        """Refresh the filters once per second, forever.

        Two pref ranges are used alternately: the new generation is installed
        in one range before the previous generation is deleted from the other.

        Args:
            pref_offset: Pair of starting pref values for the two ranges.
        """
        pref_index = 0
        pref_start = pref_offset[pref_index]
        num_old = 0
        while True:
            num_new = self.install_filters(pref_start)
            # Switch to the other range, which holds the previous generation.
            pref_index = (pref_index + 1) % 2
            pref_start = pref_offset[pref_index]
            self.delete_filters(pref_start, num_old)
            num_old = num_new
            time.sleep(1)
if __name__ == '__main__':
    # Script entry point: the only argument is the name of the VRF (l3mdev)
    # device whose routes and neighbours are turned into tc filters.
    import sys

    if len(sys.argv) < 2:
        # Exit with a usage message instead of an IndexError traceback.
        sys.exit(f"usage: {sys.argv[0]} L3MDEV_IFNAME")
    l3mdev_ifname = sys.argv[1]
    l3sw = TCL3Switch(l3mdev_ifname)
    l3sw.set_ingress_qdisc()
    # run() loops forever; stop the script with a signal (e.g. Ctrl-C).
    l3sw.run()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment