Created
December 20, 2019 07:37
-
-
Save pikulet/e8de2664e01be4542b810af2e2de833f to your computer and use it in GitHub Desktop.
pox controller example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# POX controller providing:
# 1. A self-learning (MAC-learning) switch
# 2. Quality of Service via per-destination output queues
# 3. A firewall (drop rules pushed to every switch)
import sys | |
import os | |
import time | |
from sets import Set | |
from pox.core import core | |
import pox.openflow.libopenflow_01 as of | |
import pox.openflow.discovery | |
import pox.openflow.spanning_tree | |
import pox.lib.packet as pkt | |
from pox.lib.revent import * | |
from pox.lib.util import dpid_to_str | |
from pox.lib.addresses import IPAddr, EthAddr | |
log = core.getLogger()

# ------------------------------ configuration ------------------------------
# Policy file (host groups, firewall inputs, premium hosts), read once when
# the controller component is created.
POLICY_FILE = "policy.in"

# Learning-switch timing, both in seconds:
#   REFRESH_INTERVAL - period between sweeps that purge stale MAC entries
#   TTL              - age at which a learned MAC entry counts as stale; also
#                      the hard timeout of the forwarding flows installed
REFRESH_INTERVAL = 5
TTL = 30
# ---------------------------------------------------------------------------
class Controller(EventMixin):
    """POX component combining a MAC-learning switch, QoS queues and a firewall.

    * Learning switch: each PacketIn records ``src MAC -> (in_port, time)``
      for the switch it arrived on. Entries older than ``TTL`` seconds are
      treated as unknown, and a timer purges them every ``REFRESH_INTERVAL``
      seconds. Multicast, unknown and stale destinations are flooded.
    * QoS: traffic to a known destination is enqueued on queue 1 when the
      destination IP is listed as premium in the policy file, else queue 0.
    * Firewall: for every ordered pair of hosts in different policy-file
      groups ("VLANs"), a permanent drop flow for IPv4/TCP traffic to
      destination port 4001 is sent to each switch when it connects.
    """

    def __init__(self):
        self.listenTo(core.openflow)
        core.openflow_discovery.addListeners(self)

        # dpid -> {MAC address: (port, time.time() when learned)}
        self.mac_to_port = dict()
        # Firewall drops must outrank the forwarding/QoS flows.
        self.VLAN_PRIORITY = 1000
        self.PREMIUM_PRIORITY = 500
        # host IP string -> group ("VLAN") id, numbered from 1
        self.ip_to_vlan = dict()
        # ofp_flow_mod drop rules sent to every switch on ConnectionUp.
        # A list rather than a set: POX message classes define __eq__, which
        # can make them unhashable on Python 3; a list also keeps the install
        # order deterministic.
        self.firewall_policies = []
        # IP strings whose traffic is enqueued on the premium queue (1)
        self.premiums = set()
        self.readPolicies()

        # Purge stale MAC entries now, then re-arm every REFRESH_INTERVAL s.
        def loopRefreshSwitchTable():
            self.refreshSwitchTable()
            core.callDelayed(REFRESH_INTERVAL, loopRefreshSwitchTable)
        loopRefreshSwitchTable()

    def readPolicies(self):
        """Load host groups, firewall rules and premium hosts from POLICY_FILE.

        Expected layout (blank lines and surrounding whitespace are ignored):

            n m               -> n groups ("VLANs"), m premium hosts
            c_1 c_2 ... c_n   -> number of hosts in each group
            <IP> ...          -> c_1 IPs for group 1, then c_2 for group 2, ...
            <IP> ...          -> m premium host IPs

        Raises:
            IOError/OSError if the file cannot be opened; ValueError or
            IndexError if the contents do not follow the layout above.
        """
        log.debug("*** SETUP: Reading policy files")
        with open(POLICY_FILE, "r") as f:
            # Strip whitespace/CRLF so IPs parse in IPAddr and compare equal
            # to str(IPAddr) in the premium check; drop empty lines.
            data = [line.strip() for line in f if line.strip()]
        if len(data) < 2:
            raise ValueError("policy file %s is missing its header lines"
                             % POLICY_FILE)

        n, m = [int(x) for x in data[0].split()]
        n_hosts_in_vlan = [int(x) for x in data[1].split()]
        if len(n_hosts_in_vlan) < n:
            raise ValueError("policy file lists %d group sizes, expected %d"
                             % (len(n_hosts_in_vlan), n))
        ptr = 2

        # Group ids start at 1 to avoid the default VLAN id 0.
        for vlan_id in range(1, n + 1):
            for _ in range(n_hosts_in_vlan[vlan_id - 1]):
                self.ip_to_vlan[data[ptr]] = vlan_id
                ptr += 1

        # Block every ordered pair of hosts that sit in different groups.
        hosts = list(self.ip_to_vlan)
        for host1 in hosts:
            for host2 in hosts:
                if self.ip_to_vlan[host1] != self.ip_to_vlan[host2]:
                    self.firewall_policies.append(
                        self._makeFirewallRule(host1, host2))

        # The remaining m lines are premium destination IPs.
        for _ in range(m):
            self.premiums.add(data[ptr])
            ptr += 1
        log.debug("SETUP: premium hosts %s" % self.premiums)

    def _makeFirewallRule(self, host1, host2):
        """Return a flow_mod dropping IPv4/TCP from host1 to host2, dst port 4001."""
        fw_msg = of.ofp_flow_mod()
        fw_msg.priority = self.VLAN_PRIORITY
        # link layer: IPv4
        fw_msg.match.dl_type = 0x800
        # network layer
        fw_msg.match.nw_proto = pkt.ipv4.TCP_PROTOCOL
        fw_msg.match.nw_src = IPAddr(host1)
        fw_msg.match.nw_dst = IPAddr(host2)
        # transport layer
        fw_msg.match.tp_dst = 4001
        # OpenFlow 1.0: a flow entry with an empty action list drops matching
        # packets. Outputting to OFPP_NONE is not a valid way to drop.
        return fw_msg

    def refreshSwitchTable(self):
        """Delete learned MAC entries that are at least TTL seconds old."""
        log.debug("*** REFRESH: Refreshing switch table")
        current_time = time.time()
        for switch_table in self.mac_to_port.values():
            # Iterate over a snapshot because entries are deleted in the loop.
            for dest, (port, timestamp) in list(switch_table.items()):
                elapsed_time = current_time - timestamp
                if elapsed_time >= TTL:
                    log.debug("*** REFRESH: Found expired entry: %s at port %s, elapsed time: %s" % (dest, port, elapsed_time))
                    del switch_table[dest]

    @staticmethod
    def _getIPs(packet):
        """Return (srcip, dstip) for IPv4 or ARP frames, else (None, None)."""
        if packet.type == packet.IP_TYPE:
            ip_packet = packet.payload
            return ip_packet.srcip, ip_packet.dstip
        if packet.type == packet.ARP_TYPE:
            arp_packet = packet.payload
            return arp_packet.protosrc, arp_packet.protodst
        return None, None

    def _flood(self, event):
        """Flood the packet in *event* (OFPP_FLOOD) without installing a flow."""
        log.debug("--- Flooding packet")
        msg = of.ofp_packet_out()
        msg.data = event.ofp
        msg.actions.append(of.ofp_action_output(port=of.OFPP_FLOOD))
        event.connection.send(msg)

    def _installEnqueue(self, event, dst, outport, dstip, qid):
        """Send the packet in *event* to *outport* on output queue *qid*.

        When *dstip* is known, this also installs a flow that enqueues later
        IPv4 packets to (dst MAC, dstip) the same way. The flow has priority
        PREMIUM_PRIORITY and a hard timeout of TTL.
        """
        action = of.ofp_action_enqueue(port=outport, queue_id=qid)
        if dstip is None:
            # Not IPv4/ARP: forward only this packet. A flow with nw_dst
            # wildcarded would catch *all* IPv4 traffic to dst and keep it on
            # this (non-premium) queue until the flow expired.
            msg = of.ofp_packet_out()
            msg.data = event.ofp
            msg.actions.append(action)
            event.connection.send(msg)
            return

        msg = of.ofp_flow_mod()
        # Attaching the triggering packet lets POX/the switch forward it with
        # the new flow's actions.
        msg.data = event.ofp
        msg.priority = self.PREMIUM_PRIORITY
        msg.match = of.ofp_match()
        # link layer
        msg.match.dl_dst = dst
        msg.match.dl_type = 0x800
        # network layer
        msg.match.nw_dst = dstip
        msg.hard_timeout = TTL
        msg.actions.append(action)
        event.connection.send(msg)

    def _handle_PacketIn(self, event):
        """Learn the sender's port, then flood or forward on a QoS queue."""
        packet = event.parsed
        dpid = event.dpid
        log.debug("*** EVENT: Received packet at s%s" % dpid)
        if not packet.parsed:
            # Truncated/unparseable frame: its addresses cannot be trusted.
            log.warning("Ignoring incomplete packet at s%s" % dpid)
            return

        src = packet.src
        dst = packet.dst
        inport = event.port
        current_time = time.time()

        # update switch table (setdefault avoids a KeyError for a dpid whose
        # ConnectionUp was not seen by this component)
        switch_table = self.mac_to_port.setdefault(dpid, dict())
        switch_table[src] = (inport, current_time)

        # Multicast, or destination not learned yet: flood.
        if dst.is_multicast or dst not in switch_table:
            self._flood(event)
            return

        outport, timestamp = switch_table[dst]
        if current_time - timestamp >= TTL:
            # Entry is stale but not yet purged by refreshSwitchTable: flood.
            self._flood(event)
            return

        srcip, dstip = self._getIPs(packet)
        log.debug("--- Source IP: %s Destination IP: %s" % (srcip, dstip))

        # Premium destinations use queue 1; everything else uses queue 0.
        qid = 1 if str(dstip) in self.premiums else 0
        log.debug("--- Packet is sent to queue: %d" % qid)
        self._installEnqueue(event, dst, outport, dstip, qid)

    def _handle_ConnectionUp(self, event):
        """Reset the switch's MAC table and push every firewall rule to it."""
        dpid = event.dpid
        log.debug("SETUP: Switch %s has come up.", dpid)
        # initialise switch table
        self.mac_to_port[dpid] = dict()
        for fw_msg in self.firewall_policies:
            event.connection.send(fw_msg)
def launch():
    """POX entry point: start link discovery and spanning tree, then Controller."""
    # Discovery must be launched first: Controller subscribes to
    # core.openflow_discovery in its constructor.
    prerequisites = (pox.openflow.discovery, pox.openflow.spanning_tree)
    for module in prerequisites:
        module.launch()
    core.registerNew(Controller)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment