# pox controller example
# GitHub gist by @pikulet, created December 20, 2019
# (gist id: e8de2664e01be4542b810af2e2de833f)
# Pox controller
# 1. Self-learning switch
# 2. Quality of Service (queues)
# 3. Firewall
import sys
import os
import time
from sets import Set
from pox.core import core
import pox.openflow.libopenflow_01 as of
import pox.openflow.discovery
import pox.openflow.spanning_tree
import pox.lib.packet as pkt
from pox.lib.revent import *
from pox.lib.util import dpid_to_str
from pox.lib.addresses import IPAddr, EthAddr
log = core.getLogger()

# ---------------------------------------------------------------- config ---
# Path of the policy file, opened once at start-up by Controller.readPolicies.
POLICY_FILE = "policy.in"
# Seconds between two purges of expired learned-MAC entries.
REFRESH_INTERVAL = 5
# Lifetime (seconds) of a learned MAC entry; also used as the hard_timeout of
# the forwarding/QoS flows the controller installs.
TTL = 30
# ---------------------------------------------------------------------------
class Controller(EventMixin):
    """POX component: learning switch with QoS queues and an inter-VLAN firewall.

    Learning switch: every PacketIn records ``src MAC -> (in port, time)`` in
    ``mac_to_port[dpid]``. Entries older than ``TTL`` seconds are ignored on
    lookup and purged every ``REFRESH_INTERVAL`` seconds.

    QoS: once the destination MAC's port is known, a flow for IPv4 traffic to
    that MAC/IP is installed. It enqueues on queue 1 when the destination IP
    is listed as premium in ``POLICY_FILE``, and on queue 0 otherwise.

    Firewall: for every ordered pair of hosts in different VLANs (as listed in
    ``POLICY_FILE``), a drop flow for TCP traffic to ``FIREWALL_TCP_PORT`` is
    pushed to each switch when it connects.
    """

    # TCP destination port blocked between hosts of different VLANs.
    FIREWALL_TCP_PORT = 4001

    def __init__(self):
        # Methods named _handle_<Event> are bound to the sources' events.
        self.listenTo(core.openflow)
        core.openflow_discovery.addListeners(self)
        # dpid -> {mac: (port, time learned)}
        self.mac_to_port = dict()
        # Firewall drop rules (VLAN_PRIORITY) must win over the forwarding
        # flows (PREMIUM_PRIORITY) whenever both match a packet.
        self.VLAN_PRIORITY = 1000
        self.PREMIUM_PRIORITY = 500
        # host IP string (as written in POLICY_FILE) -> VLAN id (1..n)
        self.ip_to_vlan = dict()
        # ofp_flow_mod drop rules sent to every switch on ConnectionUp.
        # A list (not a set) keeps install order deterministic and does not
        # depend on how flow_mod objects hash or compare.
        self.firewall_policies = []
        # IP strings of premium destinations (served from queue 1)
        self.premiums = set()
        self.readPolicies()

        def loopRefreshSwitchTable():
            # Purge now, then re-arm the timer so this repeats forever.
            self.refreshSwitchTable()
            core.callDelayed(REFRESH_INTERVAL, loopRefreshSwitchTable)
        loopRefreshSwitchTable()

    def readPolicies(self):
        """Load VLAN membership, firewall rules and premium hosts.

        POLICY_FILE layout, one item per line:
          * line 0: ``n m``, the number of VLANs and of premium hosts;
          * line 1: n integers, the host count of VLAN 1..n;
          * then one host IP per line (VLAN 1's hosts first, then VLAN 2's, ...);
          * then m premium host IPs, one per line.
        """
        log.debug("*** SETUP: Reading policy files")
        with open(POLICY_FILE, "r") as f:
            # Strip each line so stray whitespace never ends up inside an IP
            # string (it would break IPAddr() and the premium lookup).
            data = [line.strip() for line in f.read().splitlines()]
        n, m = [int(x) for x in data[0].split()]
        n_hosts_in_vlan = [int(x) for x in data[1].split()]
        ptr = 2
        # VLAN ids start at 1 so the default VLAN 0 is never used.
        for vlan_id in range(1, n + 1):
            for _ in range(n_hosts_in_vlan[vlan_id - 1]):
                self.ip_to_vlan[data[ptr]] = vlan_id
                ptr += 1
        # Block the firewalled port for every ordered pair of hosts that sit
        # in different VLANs.
        for host1, vlan1 in self.ip_to_vlan.items():
            for host2, vlan2 in self.ip_to_vlan.items():
                if vlan1 != vlan2:
                    self.firewall_policies.append(
                        self._firewallRule(host1, host2))
        # premium services
        for _ in range(m):
            self.premiums.add(data[ptr])
            ptr += 1
        log.debug("SETUP: premium hosts %s" % self.premiums)

    def _firewallRule(self, src_ip, dst_ip):
        """Return a flow_mod dropping TCP src_ip -> dst_ip:FIREWALL_TCP_PORT."""
        fw_msg = of.ofp_flow_mod()
        fw_msg.priority = self.VLAN_PRIORITY
        # link layer: IPv4
        fw_msg.match.dl_type = 0x800
        # network layer
        fw_msg.match.nw_proto = pkt.ipv4.TCP_PROTOCOL
        fw_msg.match.nw_src = IPAddr(src_ip)
        fw_msg.match.nw_dst = IPAddr(dst_ip)
        # transport layer
        fw_msg.match.tp_dst = self.FIREWALL_TCP_PORT
        # Deliberately no actions: in OpenFlow 1.0 an empty action list means
        # "drop". Outputting to OFPP_NONE is not a valid drop, and switches
        # such as Open vSwitch may reject the whole flow_mod because of it.
        return fw_msg

    def refreshSwitchTable(self):
        """Remove learned MAC entries that are at least TTL seconds old."""
        log.debug("*** REFRESH: Refreshing switch table")
        current_time = time.time()
        for switch_table in self.mac_to_port.values():
            # Iterate over a snapshot so entries can be deleted in the loop.
            for dest, (port, timestamp) in list(switch_table.items()):
                elapsed_time = current_time - timestamp
                if elapsed_time >= TTL:
                    log.debug("*** REFRESH: Found expired entry: %s at port %s, elapsed time: %s" % (dest, port, elapsed_time))
                    del switch_table[dest]

    def _handle_PacketIn(self, event):
        """Learn the sender's port, then flood or install a queued flow.

        The packet is flooded when the destination MAC is multicast, unknown
        on this switch, or learned at least TTL seconds ago. Otherwise a flow
        (hard_timeout TTL) is installed that enqueues matching IPv4 traffic
        on the learned port: queue 1 for premium destination IPs, queue 0
        for everything else.
        """
        packet = event.parsed
        if not packet.parsed:
            # Truncated/unparseable frame: its MAC fields cannot be trusted.
            log.warning("Ignoring incomplete packet at s%s" % event.dpid)
            return
        dpid = event.dpid
        src = packet.src
        dst = packet.dst
        inport = event.port
        log.debug("*** EVENT: Received packet at s%s" % dpid)

        def flood():
            # One-off packet_out; no flow table entry is installed.
            log.debug("--- Flooding packet")
            msg = of.ofp_packet_out()
            msg.data = event.ofp
            msg.actions.append(of.ofp_action_output(port=of.OFPP_FLOOD))
            event.connection.send(msg)

        def install_enqueue(outport, dstip, qid):
            # Install a flow for IPv4 traffic to (dst, dstip) that enqueues on
            # queue `qid` of `outport`. msg.data hands the triggering packet
            # to POX so it is forwarded along with the new flow.
            msg = of.ofp_flow_mod()
            msg.data = event.ofp
            msg.priority = self.PREMIUM_PRIORITY
            msg.match = of.ofp_match()
            # link layer
            msg.match.dl_dst = dst
            msg.match.dl_type = 0x800
            # network layer
            msg.match.nw_dst = dstip
            msg.hard_timeout = TTL
            msg.actions.append(of.ofp_action_enqueue(port=outport, queue_id=qid))
            event.connection.send(msg)

        current_time = time.time()
        # setdefault tolerates a PacketIn from a switch whose table was not
        # (yet) initialised by _handle_ConnectionUp.
        switch_table = self.mac_to_port.setdefault(dpid, dict())
        switch_table[src] = inport, current_time

        # Multicast/broadcast, or destination port not learned: flood.
        if dst.is_multicast or dst not in switch_table:
            flood()
            return

        outport, timestamp = switch_table[dst]
        if current_time - timestamp >= TTL:
            # Learned entry is stale (not purged yet): treat it as unknown.
            flood()
            return

        # Extract IPs (IPv4 header or ARP protocol addresses) for QoS.
        srcip = None
        dstip = None
        if packet.type == packet.IP_TYPE:
            ip_packet = packet.payload
            srcip = ip_packet.srcip
            dstip = ip_packet.dstip
        elif packet.type == packet.ARP_TYPE:
            arp_packet = packet.payload
            srcip = arp_packet.protosrc
            dstip = arp_packet.protodst
        log.debug("--- Source IP: %s Destination IP: %s" % (srcip, dstip))

        qid = 1 if str(dstip) in self.premiums else 0
        log.debug("--- Packet is sent to queue: %d" % qid)
        install_enqueue(outport, dstip, qid)

    def _handle_ConnectionUp(self, event):
        """Reset the switch's learned-MAC table and push all firewall rules."""
        dpid = event.dpid
        log.debug("SETUP: Switch %s has come up.", dpid)
        # initialise switch table
        self.mac_to_port[dpid] = dict()
        for fw_msg in self.firewall_policies:
            event.connection.send(fw_msg)
def launch():
    """POX entry point for this module.

    Starts the discovery and spanning-tree components (discovery first),
    then registers a Controller instance with POX's core.
    """
    for component in (pox.openflow.discovery, pox.openflow.spanning_tree):
        component.launch()
    core.registerNew(Controller)