Created
September 2, 2019 10:54
-
-
Save ARMATURETechnologies/3fc9756204e5f5b4c4ca047a5e951659 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python3.6 | |
import argparse | |
import enum | |
import pathlib | |
import sys | |
import tempfile | |
from scapy.all import * | |
from typing import ( | |
Any, | |
Dict, | |
List, | |
Optional, | |
Union, | |
) | |
from netaddr import ( | |
IPAddress, | |
INET_PTON | |
) | |
# Maps the scapy layer class found directly above IP to the protocol number
# stored in a FlowSpec (these are the standard IP protocol numbers).
PROTO_TBL = {UDP: 17, TCP: 6, ICMP: 1}
@enum.unique
class TCPState(enum.Enum):
    """Phases of a TCP session as tracked while scanning a capture."""

    # A lone SYN has been seen.
    SYN_SENT = 1
    # The SYN has been answered with SYN+ACK.
    SYN_ACKED = 2
    # The handshake was completed by an ACK without SYN.
    OPENED = 3
    # One side has sent a FIN.
    FIN_SENT = 4
    # The other side has answered the FIN with its own FIN or a RST.
    FIN_SENT2 = 5
    # No session in progress; this is the state of a fresh session.
    CLOSED = 6
class FlowSpec:
    """Directional flow key: VLAN, source/destination IP, protocol and ports.

    Instances compare and hash by value, so they can be used as dictionary
    keys or set members.
    """

    @staticmethod
    def from_pkt(pkt: Packet) -> 'FlowSpec':
        """Build a FlowSpec from a scapy packet.

        Ports are filled in only for UDP and TCP; the VLAN id is taken from
        the Dot1Q layer when the packet has one.

        Raises:
            NotImplementedError: if the packet has no IP layer, or the layer
                directly above IP is not listed in PROTO_TBL.
        """
        if IP not in pkt:
            raise NotImplementedError('only packets with an IP layer are supported')

        ip_layer = pkt[IP]
        ip_src = IPAddress(ip_layer.src, flags=INET_PTON)
        ip_dst = IPAddress(ip_layer.dst, flags=INET_PTON)

        trans = type(ip_layer.payload)
        try:
            proto = PROTO_TBL[trans]
        except KeyError:
            # The KeyError is an implementation detail; do not chain it.
            raise NotImplementedError(
                'unsupported IP payload: {}'.format(trans.__name__)
            ) from None

        port_src = None
        port_dst = None
        if trans in (UDP, TCP):
            port_src = pkt[trans].sport
            port_dst = pkt[trans].dport

        vlan_id = pkt[Dot1Q].vlan if Dot1Q in pkt else None

        return FlowSpec(ip_src, ip_dst, proto, port_src, port_dst, vlan_id)

    def __init__(self,
                 ip_src: IPAddress, ip_dst: IPAddress,
                 protocol: int, port_src: Optional[int], port_dst: Optional[int],
                 vlan_id: Optional[int] = None):
        """Store the flow fields; port and VLAN values may be None."""
        self._ip_src = ip_src
        self._ip_dst = ip_dst
        self._proto = protocol
        self._port_src = port_src
        self._port_dst = port_dst
        self._vlan_id = vlan_id

    def _key(self) -> tuple:
        """Return the tuple that defines both equality and hashing."""
        return (
            self._ip_src, self._ip_dst, self._proto,
            self._port_src, self._port_dst, self._vlan_id,
        )

    def __eq__(self, other: Any) -> bool:
        """Return True only for another FlowSpec with identical fields."""
        if not isinstance(other, FlowSpec):
            return False
        return self._key() == other._key()

    @property
    def ip_src(self) -> IPAddress:
        return self._ip_src

    @property
    def ip_dst(self) -> IPAddress:
        return self._ip_dst

    @property
    def proto(self) -> int:
        return self._proto

    @property
    def port_src(self) -> Optional[int]:
        return self._port_src

    @property
    def port_dst(self) -> Optional[int]:
        return self._port_dst

    @property
    def vlan_id(self) -> Optional[int]:
        return self._vlan_id

    def __str__(self) -> str:
        # The closing '>' was missing from the original format string.
        return '<FlowSpec: vlan={}, ip_src={}, ip_dst={}, proto={}, port_src={}, port_dst={}>'.format(
            self.vlan_id, self.ip_src, self.ip_dst,
            self.proto, self.port_src, self.port_dst,
        )

    def __repr__(self) -> str:
        return '"{}"'.format(str(self))

    def __hash__(self) -> int:
        return hash(self._key())
class OrderedFlowSpec:
    """Direction-independent flow key.

    Both directions of a conversation map to the same OrderedFlowSpec: the
    endpoint with the lower IP address (or, for equal addresses, the lower
    port) is stored first.
    """

    @staticmethod
    def from_flow_spec(flow: FlowSpec) -> 'OrderedFlowSpec':
        """Canonicalise a directional FlowSpec.

        Port-less flows (ports set to None) are supported: a missing port
        sorts before any real port, so the comparison never fails.
        """
        if flow.ip_src < flow.ip_dst:
            src_first = True
        elif flow.ip_src > flow.ip_dst:
            src_first = False
        else:
            # Same address on both sides: fall back to the ports.  None is
            # ranked below every port number to avoid comparing None with
            # an int (or with None), which raises TypeError.
            src_rank = -1 if flow.port_src is None else flow.port_src
            dst_rank = -1 if flow.port_dst is None else flow.port_dst
            src_first = src_rank < dst_rank

        if src_first:
            fst_ip, snd_ip = flow.ip_src, flow.ip_dst
            fst_port, snd_port = flow.port_src, flow.port_dst
        else:
            fst_ip, snd_ip = flow.ip_dst, flow.ip_src
            fst_port, snd_port = flow.port_dst, flow.port_src

        return OrderedFlowSpec(
            fst_ip, snd_ip,
            flow.proto, fst_port, snd_port,
            flow.vlan_id,
        )

    def __init__(self, fst_ip: IPAddress, snd_ip: IPAddress,
                 proto: int, fst_port: Optional[int], snd_port: Optional[int],
                 vlan_id: Optional[int]):
        """Store the fields as given; no reordering is done here."""
        self._fst_ip = fst_ip
        self._snd_ip = snd_ip
        self._proto = proto
        self._fst_port = fst_port
        self._snd_port = snd_port
        self._vlan_id = vlan_id

    def _key(self) -> tuple:
        """Return the tuple that defines both equality and hashing."""
        return (
            self._fst_ip, self._snd_ip, self._proto,
            self._fst_port, self._snd_port, self._vlan_id,
        )

    def __eq__(self, other: Any) -> bool:
        """Return True only for another OrderedFlowSpec with identical fields."""
        if not isinstance(other, OrderedFlowSpec):
            return False
        return self._key() == other._key()

    @property
    def fst_ip(self) -> IPAddress:
        return self._fst_ip

    @property
    def snd_ip(self) -> IPAddress:
        return self._snd_ip

    @property
    def proto(self) -> int:
        return self._proto

    @property
    def fst_port(self) -> Optional[int]:
        return self._fst_port

    @property
    def snd_port(self) -> Optional[int]:
        return self._snd_port

    @property
    def vlan_id(self) -> Optional[int]:
        return self._vlan_id

    def __str__(self) -> str:
        # The closing '>' was missing from the original format string.
        return '<OrderedFlowSpec: vlan={}, ip1={}, ip2={}, proto={}, port1={}, port2={}>'.format(
            self.vlan_id, self.fst_ip, self.snd_ip,
            self.proto, self.fst_port, self.snd_port,
        )

    def __repr__(self) -> str:
        return '"{}"'.format(str(self))

    def __hash__(self) -> int:
        return hash(self._key())
class Destination:
    """A single flow endpoint: VLAN, IP address, protocol number and port."""

    def __init__(self, vlan_id: Optional[int], ip: IPAddress, proto: int, port: Optional[int]):
        """Store the endpoint fields; ``port`` is None for port-less protocols."""
        self._vlan_id = vlan_id
        self._ip = ip
        self._proto = proto
        self._port = port

    def __hash__(self) -> int:
        return hash((self._vlan_id, self._ip, self._proto, self._port))

    def __str__(self) -> str:
        return '<Destination vlan: {}, IP: {}, protocol: {}, port: {}>'.format(
            self._vlan_id,
            self._ip,
            self._proto,
            self._port,
        )

    def __repr__(self) -> str:
        return '"{}"'.format(str(self))

    @property
    def vlan_id(self) -> Optional[int]:
        return self._vlan_id

    @property
    def ip(self) -> IPAddress:
        return self._ip

    @property
    def proto(self) -> int:
        return self._proto

    @property
    def port(self) -> Optional[int]:
        return self._port

    def _matches_endpoint(self, ip: Any, port: Any) -> bool:
        """Return True if this destination has the given address and port."""
        return self._ip == ip and self._port == port

    def __eq__(self, other: Any) -> bool:
        """Compare with another Destination or with a flow spec.

        Two Destinations are equal when all four fields match.  A Destination
        equals a FlowSpec / OrderedFlowSpec when VLAN and protocol match and
        the destination is either of the flow's two endpoints.  Any other
        type yields NotImplemented so Python can fall back to its default
        comparison.

        NOTE(review): __hash__ only covers the Destination fields, so the
        cross-type equality is suited to direct comparisons rather than
        dictionary or set lookups.
        """
        if isinstance(other, Destination):
            return (
                self._vlan_id == other.vlan_id and
                self._ip == other.ip and
                self._proto == other.proto and
                # Was a bare `port`, which raised NameError.
                self._port == other.port
            )

        if isinstance(other, FlowSpec):
            endpoints = (
                (other.ip_src, other.port_src),
                (other.ip_dst, other.port_dst),
            )
        elif isinstance(other, OrderedFlowSpec):
            endpoints = (
                (other.fst_ip, other.fst_port),
                (other.snd_ip, other.snd_port),
            )
        else:
            # `raise NotImplemented()` is a TypeError: NotImplemented is a
            # sentinel to be returned, not an exception class.
            return NotImplemented

        return (
            self._vlan_id == other.vlan_id and
            self._proto == other.proto and
            any(self._matches_endpoint(ip, port) for ip, port in endpoints)
        )

    @classmethod
    def from_flowspec(cls, flow: FlowSpec) -> List['Destination']:
        """Return the [source, destination] endpoints of a FlowSpec."""
        return [
            cls(flow.vlan_id, flow.ip_src, flow.proto, flow.port_src),
            cls(flow.vlan_id, flow.ip_dst, flow.proto, flow.port_dst),
        ]

    @classmethod
    def from_ordered_flowspec(cls, flow: OrderedFlowSpec) -> List['Destination']:
        """Return the [first, second] endpoints of an OrderedFlowSpec."""
        return [
            cls(flow.vlan_id, flow.fst_ip, flow.proto, flow.fst_port),
            cls(flow.vlan_id, flow.snd_ip, flow.proto, flow.snd_port),
        ]
class SessionState:
    """Mutable bookkeeping for one tracked TCP session.

    Holds the current state, the indices of the packets attributed to the
    session so far, and the source address of the first FIN once one has
    been recorded.
    """

    def __init__(self):
        self._cur_state = TCPState.CLOSED
        self._pkt_nums: List[int] = []
        # Stays None until a FIN sender is recorded.
        self._fin_src: Optional[str] = None

    @property
    def cur_state(self) -> TCPState:
        """Current state of the session."""
        return self._cur_state

    @cur_state.setter
    def cur_state(self, value: TCPState) -> None:
        self._cur_state = value

    def append(self, value: int) -> None:
        """Record the index of a packet that belongs to this session."""
        self._pkt_nums.append(value)

    @property
    def pkt_nums(self) -> List[int]:
        """Indices of the packets recorded so far, in insertion order."""
        return self._pkt_nums

    @property
    def fin_src(self) -> Optional[str]:
        """Source address of the first FIN, or None if none was recorded."""
        return self._fin_src

    @fin_src.setter
    def fin_src(self, value: Optional[str]) -> None:
        self._fin_src = value
class MyPcapWriter(PcapWriter):
    """Writer that stores every batch of packets in its own numbered file.

    Each call to write() creates ``<filename>.<n>`` through wrpcap(), where
    ``n`` starts at 0 and grows by one per call.
    """

    def __init__(self, filename: str):
        # NOTE(review): the base-class initializer is not called here;
        # presumably because all output goes through wrpcap() - confirm.
        self._filename = filename
        self._counter = 0

    def write(self, pkt_lst: Union[Packet, List[Packet]]):
        """Write ``pkt_lst`` to the next numbered file and advance the counter."""
        chunk_path = '{}.{}'.format(self._filename, self._counter)
        wrpcap(chunk_path, pkt_lst)
        self._counter += 1
class PcapCleaner:
    """Two-pass filter that removes incomplete TCP sessions from a capture.

    Pass 1 feeds every packet through a per-flow TCP state machine and
    blacklists the packets that do not fit the expected handshake/teardown
    sequence.  Pass 2 re-reads the capture and writes the packets that are
    not blacklisted (and optionally the rejected ones) to numbered chunks.
    """

    # Pass 1 prints a progress dot every this many packets.
    _PROGRESS_EVERY = 1000
    # Pass 2 flushes its buffers to a new chunk every this many packets.
    _CHUNK_SIZE = 1000000

    def __init__(self, filerd: str, filewr: str, rejected: Optional[str]):
        """Open the capture twice (once per pass) and set up the writers.

        Args:
            filerd: path of the capture to read.
            filewr: base name of the chunks that receive the kept packets.
            rejected: base name of the chunks that receive the rejected
                packets, or None to discard them.
        """
        self._rd1 = PcapReader(filerd)
        self._rd2 = PcapReader(filerd)
        self._wr = MyPcapWriter(filewr)
        if rejected is not None:
            self._rejwr = MyPcapWriter(rejected)
        else:
            self._rejwr = None
        self._sessions: Dict[OrderedFlowSpec, SessionState] = {}
        # Used as a set of packet indices; the values are always True.
        self._blacklist: Dict[int, bool] = {}

    def close(self) -> None:
        """Close both capture readers."""
        for reader in (self._rd1, self._rd2):
            reader.close()

    def _notify_packet(self, i: int, pkt: Packet) -> bool:
        """Advance the state machine with packet number ``i``.

        Returns True if the packet is acceptable and False if it must be
        blacklisted.  Packets without a TCP layer, and TCP packets for which
        no FlowSpec can be built, are always accepted.
        """
        if TCP not in pkt:
            return True
        try:
            fs = FlowSpec.from_pkt(pkt)
        except NotImplementedError:
            # Not a flow we can track (e.g. no IPv4 layer): leave the packet
            # alone instead of aborting the whole scan.
            return True
        ofs = OrderedFlowSpec.from_flow_spec(fs)

        # Sessions are only registered once a SYN is seen, so an unknown
        # flow is equivalent to a closed one.
        session = self._sessions.get(ofs)
        state = TCPState.CLOSED if session is None else session.cur_state
        flags = pkt.sprintf('%TCP.flags%')

        if state is TCPState.CLOSED:
            if flags != 'S':
                return False  # mid-stream packet of an unknown session
            if session is None:
                session = SessionState()
                self._sessions[ofs] = session
            session.cur_state = TCPState.SYN_SENT
        elif state is TCPState.SYN_SENT:
            if flags != 'SA':
                return False
            session.cur_state = TCPState.SYN_ACKED
        elif state is TCPState.SYN_ACKED:
            if 'A' in flags and 'S' not in flags:
                session.cur_state = TCPState.OPENED
            elif 'F' in flags:
                session.cur_state = TCPState.FIN_SENT
                session.fin_src = pkt[IP].src
            else:
                return False
        elif state is TCPState.OPENED:
            if 'F' in flags:
                session.cur_state = TCPState.FIN_SENT
                session.fin_src = pkt[IP].src
            # Anything without FIN is ordinary traffic of the open session.
        elif state is TCPState.FIN_SENT:
            # Only the peer's FIN or RST moves the teardown forward.
            if not (('F' in flags or 'R' in flags) and session.fin_src == pkt[IP].dst):
                return False
            session.cur_state = TCPState.FIN_SENT2
        elif state is TCPState.FIN_SENT2:
            if session.fin_src == pkt[IP].src and flags in ('A', 'R'):
                # Session complete: forget it so its packets are never
                # blacklisted as belonging to an unterminated session.
                del self._sessions[ofs]
                return True
            return False
        else:
            return False

        session.append(i)
        return True

    def _scan1(self) -> None:
        """Pass 1: blacklist every packet rejected by the state machine."""
        for i, pkt in enumerate(self._rd1):
            if i % self._PROGRESS_EVERY == 0:
                print('.', end="", flush=True)
            if not self._notify_packet(i, pkt):
                self._blacklist[i] = True

    def _build_blacklist_from_first_scan(self) -> None:
        """Blacklist the packets of sessions that were started but not terminated."""
        for session in self._sessions.values():
            for pkt_num in session.pkt_nums:
                self._blacklist[pkt_num] = True

    def _scan2(self) -> None:
        """Pass 2: split the capture into kept and rejected chunks.

        Buffers are flushed every ``_CHUNK_SIZE`` input packets.  Nothing is
        flushed at packet 0 (the buffers are still empty there), so the
        first chunk, number 0, holds real packets.
        """
        kept = []
        rejected = []
        for i, pkt in enumerate(self._rd2):
            if i % self._CHUNK_SIZE == 0:
                print('.', end="", flush=True)
                if i:
                    # Both writers are flushed together so chunk numbers of
                    # kept and rejected files cover the same input range.
                    self._wr.write(kept)
                    if self._rejwr:
                        self._rejwr.write(rejected)
                    kept = []
                    rejected = []
            if i not in self._blacklist:
                kept.append(pkt)
            elif self._rejwr:
                rejected.append(pkt)
        if kept:
            self._wr.write(kept)
        if self._rejwr and rejected:
            self._rejwr.write(rejected)

    def filter(self, keep_opened: bool) -> None:
        """Run both passes, then close the capture readers.

        Args:
            keep_opened: when False, packets of sessions that were opened
                but not terminated are rejected as well.
        """
        try:
            print('Pass 1')
            self._scan1()
            if not keep_opened:
                self._build_blacklist_from_first_scan()
            print('\nPass 2')
            self._scan2()
            print('')
        finally:
            self.close()
def main(argv: Optional[List[str]] = None) -> int:
    """Parse the command line and run the capture cleaner.

    Args:
        argv: argument list without the program name; when None, argparse
            falls back to ``sys.argv[1:]``.

    Returns:
        0 on success, suitable for passing to ``sys.exit``.  Invalid
        arguments make argparse raise SystemExit.
    """
    parser = argparse.ArgumentParser(
        description='Remove incomplete TCP sessions from a pcap file.',
    )
    parser.add_argument('-i', '--input', dest='input', action='store', required=True, type=str,
                        help='capture file to read')
    parser.add_argument('-o', '--output', dest='output', action='store', required=True, type=str,
                        help='base name of the output chunks (<output>.<n>)')
    parser.add_argument('-k', '--keep-opened', dest='keep_opened', action='store_true',
                        required=False, default=False,
                        help='keep sessions that were opened but never terminated')
    parser.add_argument('-r', '--rejected', dest='rejected', action='store', required=False, type=str,
                        help='base name of the chunks receiving the rejected packets')
    args = parser.parse_args(argv)

    cleaner = PcapCleaner(args.input, args.output, args.rejected)
    cleaner.filter(args.keep_opened)
    return 0
if __name__ == '__main__':
    # Script entry point: hand the command-line arguments (without the
    # program name) to main() and use its result as the exit status.
    exit_status = main(sys.argv[1:])
    sys.exit(exit_status)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment