Skip to content

Instantly share code, notes, and snippets.

@gregswift
Last active August 29, 2015 14:00
Show Gist options
  • Save gregswift/11401891 to your computer and use it in GitHub Desktop.
Save gregswift/11401891 to your computer and use it in GitHub Desktop.
Standalone implementation of an enhancement to Ansible's wait_for module that adds state=drained
import socket
import datetime
import time
import sys
import re
import binascii
# psutil is an optional dependency; HAS_PSUTIL records whether the import
# succeeded so callers can check it before touching the module.
try:
    import psutil
except ImportError:
    HAS_PSUTIL = False
else:
    # just because we can import it on Linux doesn't mean we will use it
    HAS_PSUTIL = True
def load_platform_subclass(cls, *args, **kwargs):
    '''
    used by modules like User to have different implementations based on detected platform.

    Picks, among the direct subclasses of ``cls``, the one whose ``platform``
    (and, when one can be detected, ``distribution``) class attributes match
    the running system, and returns a new *uninitialised* instance of it.
    Falls back to ``cls`` itself when nothing matches.

    Args:
        cls: base class whose ``__subclasses__()`` are searched; every
            candidate must define ``platform`` and ``distribution``.
        *args, **kwargs: accepted for call compatibility; not used here.
    Returns:
        An instance created with ``__new__`` only (``__init__`` is not run
        by this function).
    '''
    # Local import: the platform module is not part of this file's imports.
    import platform

    this_platform = platform.system()

    # Distribution detection is only attempted on Linux.  The helpers used
    # here exist only on older interpreters; when neither is available the
    # distribution stays None and only platform-level matching is done.
    distribution = None
    if this_platform == 'Linux':
        detect = (getattr(platform, 'linux_distribution', None)
                  or getattr(platform, 'dist', None))
        if detect is not None:
            distribution = detect()[0].capitalize() or None

    subclass = None

    # get the most specific superclass for this platform
    if distribution is not None:
        for sc in cls.__subclasses__():
            if (sc.distribution is not None
                    and sc.distribution == distribution
                    and sc.platform == this_platform):
                subclass = sc

    # otherwise settle for a platform-wide (distribution-agnostic) subclass
    if subclass is None:
        for sc in cls.__subclasses__():
            if sc.platform == this_platform and sc.distribution is None:
                subclass = sc

    if subclass is None:
        subclass = cls

    # Skip cls.__new__ (which typically calls back into this function) and
    # allocate the chosen subclass directly.
    return super(cls, subclass).__new__(subclass)
class TCPConnectionInfo(object):
    """
    This is a generic TCP Connection Info strategy class that relies
    on the psutil module, which is not ideal for targets, but necessary
    for cross platform support.

    Instantiating this class goes through load_platform_subclass(), so the
    object actually returned may be an instance of a platform subclass.

    A subclass may wish to override some or all of these methods.
      - _get_exclude_ips()
      - get_active_connections_count()

    All subclasses MUST define platform and distribution (which may be None).
    """
    platform = 'Generic'
    distribution = None

    # Wildcard ("any local address") value for each address family, in the
    # textual form this class compares local addresses in.
    match_all_ips = {
        socket.AF_INET: '0.0.0.0',
        socket.AF_INET6: '::',
    }
    # Connection states that count as "active".  The values are compared
    # with conn.status here; the keys are what the Linux subclass looks up
    # in the state column of its /proc source file.
    connection_states = {
        '01': 'ESTABLISHED',
        '02': 'SYN_SENT',
        '03': 'SYN_RECV',
        '04': 'FIN_WAIT1',
        '05': 'FIN_WAIT2',
        '06': 'TIME_WAIT',
    }

    def __new__(cls, *args, **kwargs):
        return load_platform_subclass(TCPConnectionInfo, *args, **kwargs)

    def __init__(self, module):
        """
        Args:
            module: object exposing a ``params`` mapping with the keys
                'name', 'host', 'port' and 'exclude_hosts', and a
                ``fail_json(msg=...)`` method.
        """
        self.module = module
        self.name = module.params['name']
        (self.family, self.host) = _convert_host_to_ip(self.module.params['host'])
        self.port = int(self.module.params['port'])
        self.exclude_ips = self._get_exclude_ips()
        if not HAS_PSUTIL:
            module.fail_json(msg="psutil module required for wait_for")

    def _get_exclude_ips(self):
        """
        Resolve the 'exclude_hosts' parameter to a list of IP strings.

        Accepts a comma separated string or an iterable of hosts; a missing
        or empty value yields an empty list.
        """
        exclude_hosts = self.module.params['exclude_hosts']
        if not exclude_hosts:
            return []
        if hasattr(exclude_hosts, 'split'):
            exclude_hosts = exclude_hosts.split(',')
        # _convert_host_to_ip returns (family, ip); only the ip is compared
        # against remote addresses later on.
        return [_convert_host_to_ip(h.strip())[1] for h in exclude_hosts if h.strip()]

    @staticmethod
    def _list_connections(proc):
        """Return proc's inet connections using whichever accessor it offers."""
        # The accessor was renamed across psutil releases; try newest first.
        for accessor in ('net_connections', 'connections', 'get_connections'):
            method = getattr(proc, accessor, None)
            if method is not None:
                return method(kind='inet')
        return []

    def _is_active_connection(self, local_ip, local_port, remote_ip):
        """
        Tell whether one connection counts towards the drain total.

        It counts when it is on the watched port, the watched host is either
        the wildcard address or this connection's local address, and the
        remote end is not an excluded host.
        """
        if self.port != local_port:
            return False
        if self.host not in (self.match_all_ips[self.family], local_ip):
            return False
        return remote_ip not in self.exclude_ips

    def get_active_connections_count(self):
        """
        Count connections to the watched host:port that are in one of the
        states listed in connection_states and not from an excluded host.

        Returns:
            int number of matching connections across all processes.
        """
        active_states = set(self.connection_states.values())
        active_connections = 0
        for p in psutil.process_iter():
            try:
                connections = self._list_connections(p)
            except (psutil.AccessDenied, psutil.NoSuchProcess):
                # Process went away or may not be inspected: nothing to count.
                continue
            for conn in connections:
                if conn.status not in active_states:
                    continue
                # Address attributes were renamed between psutil releases.
                local = conn.laddr if hasattr(conn, 'laddr') else conn.local_address
                remote = conn.raddr if hasattr(conn, 'raddr') else conn.remote_address
                if not local:
                    continue
                remote_ip = remote[0] if remote else None
                if self._is_active_connection(local[0], local[1], remote_ip):
                    active_connections += 1
        return active_connections
class LinuxTCPConnectionInfo(TCPConnectionInfo):
    """
    This is a TCP Connection Info evaluation strategy class
    that utilizes information from Linux's procfs. While less universal,
    does allow Linux targets to not require an additional library.

    Host, port and excluded hosts are all held in the hexadecimal text form
    produced by _convert_host_to_hex() / "%0.4X" so they can be compared
    directly with the columns read from the source file.
    """
    platform = 'Linux'
    distribution = None

    # File to read for each address family.
    source_file = {
        socket.AF_INET: '/proc/net/tcp',
        socket.AF_INET6: '/proc/net/tcp6'
    }
    # Wildcard ("any local address") value per family, in hex text form.
    match_all_ips = {
        socket.AF_INET: '00000000',
        socket.AF_INET6: '00000000000000000000000000000000',
    }
    # Whitespace separated column indexes within one line of the source file.
    local_address_field = 1
    remote_address_field = 2
    connection_state_field = 3

    def __init__(self, module):
        """
        Args:
            module: object exposing a ``params`` mapping with the keys
                'name', 'host', 'port' and 'exclude_hosts'.
        """
        self.module = module
        self.name = module.params['name']
        (self.family, self.host) = _convert_host_to_hex(module.params['host'])
        # Four upper-case hex digits, e.g. 22 -> '0016'.
        self.port = "%0.4X" % int(module.params['port'])
        self.exclude_ips = self._get_exclude_ips()

    def _get_exclude_ips(self):
        """
        Resolve the 'exclude_hosts' parameter to a list of hex encoded IPs.

        Accepts a comma separated string or an iterable of hosts; a missing
        or empty value yields an empty list.
        """
        exclude_hosts = self.module.params['exclude_hosts']
        if not exclude_hosts:
            return []
        if hasattr(exclude_hosts, 'split'):
            exclude_hosts = exclude_hosts.split(',')
        # _convert_host_to_hex returns (family, hex); only the hex text is
        # comparable with the remote address column.
        return [_convert_host_to_hex(h.strip())[1] for h in exclude_hosts if h.strip()]

    def get_active_connections_count(self, family=None):
        """
        Count connections to the watched host:port that are in one of the
        states keyed in connection_states and not from an excluded host.

        Args:
            family: address family selecting the source file; defaults to
                the family the watched host resolved to, so the method can
                be called without arguments like the base class version.
        Returns:
            int number of matching connections; 0 when the source file
            cannot be opened.
        """
        if family is None:
            family = self.family
        wildcard = self.match_all_ips[self.family]
        active_connections = 0
        try:
            f = open(self.source_file[family])
        except IOError:
            # Source file missing or unreadable: report nothing active.
            return active_connections
        with f:
            for line in f:
                # split() without arguments tolerates the column padding.
                fields = line.split()
                if len(fields) <= self.connection_state_field:
                    continue
                # Skip the header row.
                if fields[self.local_address_field] == 'local_address':
                    continue
                if fields[self.connection_state_field] not in self.connection_states:
                    continue
                (local_ip, local_port) = fields[self.local_address_field].split(':')
                if self.port != local_port or self.host not in (wildcard, local_ip):
                    continue
                remote_ip = fields[self.remote_address_field].split(':')[0]
                if remote_ip not in self.exclude_ips:
                    active_connections += 1
        return active_connections
def _convert_host_to_ip(host):
"""
Perform forward DNS resolution on host, IP will give the same IP
Args:
host: String with either hostname, IPv4, or IPv6 address
Returns:
Tuple containing address family and IP
"""
addrinfo = socket.getaddrinfo(host, 80, 0, 0, socket.SOL_TCP)[0]
return (addrinfo[0], addrinfo[4][0])
def _convert_host_to_hex(host):
    """
    Convert the provided host to the format in /proc/net/tcp*

    /proc/net/tcp uses little-endian four byte hex for ipv4
    /proc/net/tcp6 uses little-endian per 4B word for ipv6

    Args:
        host: String with either hostname, IPv4, or IPv6 address
    Returns:
        Tuple containing address family and the little-endian converted host
    """
    (family, ip) = _convert_host_to_ip(host)
    # Upper-case the hex text (not the packed bytes) to match the source files.
    hexed = binascii.hexlify(socket.inet_pton(family, ip)).upper()
    if not isinstance(hexed, str):
        # hexlify returns bytes on Python 3; the rest of the code works on text.
        hexed = hexed.decode('ascii')
    if family == socket.AF_INET:
        hexed = _little_endian_convert_32bit(hexed)
    elif family == socket.AF_INET6:
        # walk each 8 character (4B) word of the 128 bit (32 hex digit) total
        hexed = "".join(
            _little_endian_convert_32bit(hexed[x:x + 8]) for x in range(0, 32, 8)
        )
    return (family, hexed)
def _little_endian_convert_32bit(block):
"""
Convert to little-endian, effectively transposing
the order of the four byte word
12345678 -> 78563412
Args:
block: String containing a 4 byte hex representation
Returns:
String containing the little-endian converted block
"""
# xrange starts at 6, and increments by -2 until it reaches -2
# which lets us start at the end of the string block and work to the begining
return "".join([ block[x:x+2] for x in xrange(6, -2, -2) ])
class _StandaloneModule(object):
    """Minimal stand-in offering the params / fail_json members TCPConnectionInfo uses."""

    def __init__(self, **params):
        self.params = params

    def fail_json(self, **kwargs):
        # There is no module runner to report to here, so raise instead.
        raise RuntimeError(kwargs.get('msg', 'wait_for_drain failed'))


def wait_for_drain(host, port, exclude_hosts=None, timeout=30):
    """
    Poll once per second until no active connections remain on host:port
    or the timeout expires.

    @host can be any resolveable DNS or an IP
    @port numerical representation only
    @exclude_hosts array of host entries, entries same as @host
    @timeout how long to wait for the drain
    returns count of active connections
    """
    # TCPConnectionInfo reads 'exclude_hosts' as a comma separated string.
    if exclude_hosts is None:
        exclude_hosts = ''
    elif not hasattr(exclude_hosts, 'split'):
        exclude_hosts = ','.join(exclude_hosts)

    module = _StandaloneModule(
        name='wait_for_drain',
        host=host,
        port=port,
        exclude_hosts=exclude_hosts,
    )
    deadline = datetime.datetime.now() + datetime.timedelta(seconds=timeout)
    tcpconns = TCPConnectionInfo(module)

    # Always look at least once so a result exists even when timeout <= 0.
    active_connections = tcpconns.get_active_connections_count()
    while active_connections and datetime.datetime.now() < deadline:
        time.sleep(1)
        active_connections = tcpconns.get_active_connections_count()

    if active_connections:
        print("Timeout when waiting for %s:%s to drain" % (host, port))
    return active_connections
# Script entry point: report how many connections remain on localhost:22.
if __name__ == '__main__':
    # Parenthesised single-argument print is valid on Python 2 and 3 alike.
    print("Found {0} connections to localhost:22".format(wait_for_drain('127.0.0.1', 22)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment