Skip to content

Instantly share code, notes, and snippets.

Last active August 29, 2015 14:00
Show Gist options
  • Save gregswift/11401891 to your computer and use it in GitHub Desktop.
Save gregswift/11401891 to your computer and use it in GitHub Desktop.
Standalone implementation of an enhancement to Ansible's wait_for module that adds state=drained
import socket
import datetime
import time
import sys
import re
import binascii
# psutil is optional: record whether it could be imported so callers can
# report a clear error instead of failing with a NameError later on.
HAS_PSUTIL = False
try:
    import psutil
    # just because we can import it on Linux doesn't mean we will use it
    HAS_PSUTIL = True
except ImportError:
    pass
def load_platform_subclass(cls, *args, **kwargs):
    """
    Create an instance of the subclass of ``cls`` that best fits this host.

    Used by modules like User to have different implementations based on
    detected platform.  See User module for an example.

    Selection order:
      1. a direct subclass whose ``platform`` and ``distribution`` both match
         the detected values (only tried when a distribution was detected);
      2. a direct subclass whose ``platform`` matches and whose
         ``distribution`` is None;
      3. ``cls`` itself.

    When several subclasses match within one step, the last one returned by
    ``cls.__subclasses__()`` wins.

    ``*args`` and ``**kwargs`` are accepted for call compatibility but are not
    used here; the object is created with ``__new__`` only and this function
    does not call ``__init__``.

    NOTE(review): relies on module-level ``get_platform()`` and
    ``get_distribution()`` helpers that are expected to be defined elsewhere
    in this file.
    """
    this_platform = get_platform()
    distribution = get_distribution()
    subclass = None

    # get the most specific superclass for this platform
    if distribution is not None:
        for sc in cls.__subclasses__():
            if (sc.distribution is not None
                    and sc.distribution == distribution
                    and sc.platform == this_platform):
                subclass = sc

    # fall back to a platform-wide (distribution-agnostic) implementation
    if subclass is None:
        for sc in cls.__subclasses__():
            if sc.platform == this_platform and sc.distribution is None:
                subclass = sc

    # nothing platform-specific is available: use the generic class
    if subclass is None:
        subclass = cls

    return super(cls, subclass).__new__(subclass)
class TCPConnectionInfo(object):
    """
    This is a generic TCP Connection Info strategy class that relies
    on the psutil module, which is not ideal for targets, but necessary
    for cross platform support.

    A subclass may wish to override some or all of these methods.
      - _get_exclude_ips()
      - get_active_connections_count()

    All subclasses MUST define platform and distribution (which may be None).
    """
    platform = 'Generic'
    distribution = None

    # Address that means "any local address" for each family, in the same
    # textual form psutil reports local addresses in.
    # NOTE(review): the IPv4 entry was an empty string in the reviewed text;
    # '0.0.0.0' is the IPv4 counterpart of '::' -- confirm against the author's copy.
    match_all_ips = {
        socket.AF_INET: '0.0.0.0',
        socket.AF_INET6: '::',
    }

    # Connection states that count as "still active", keyed by the hex code
    # used in /proc/net/tcp; the values are the status names matched against
    # ``conn.status`` below.
    # NOTE(review): '01' (ESTABLISHED) was absent from the reviewed text; a
    # drain check that ignores established connections would report an idle
    # port while clients are still connected, so it is restored here.
    connection_states = {
        '01': 'ESTABLISHED',
        '02': 'SYN_SENT',
        '03': 'SYN_RECV',
        '04': 'FIN_WAIT1',
        '05': 'FIN_WAIT2',
        '06': 'TIME_WAIT',
    }

    def __new__(cls, *args, **kwargs):
        # Hand instance creation to the platform dispatcher so the most
        # specific subclass for this host is used.
        return load_platform_subclass(TCPConnectionInfo, args, kwargs)

    def __init__(self, module):
        """
        module: object exposing a ``params`` mapping with 'host', 'port' and
                'exclude_hosts' entries, plus a ``fail_json(msg=...)`` method.
        """
        self.module = module
        # NOTE(review): the reviewed text read params['name'] on this line;
        # .get() keeps that attribute without requiring the key.
        self.name = module.params.get('name')
        (self.family, self.ip) = _convert_host_to_ip(self.module.params['host'])
        self.port = int(self.module.params['port'])
        self.exclude_ips = self._get_exclude_ips()
        if not HAS_PSUTIL:
            module.fail_json(msg="psutil module required for wait_for")

    @staticmethod
    def _first_attr(obj, *names):
        """Return the first attribute of ``obj`` found among ``names``."""
        for name in names:
            if hasattr(obj, name):
                return getattr(obj, name)
        raise AttributeError("%r has none of: %s" % (obj, ", ".join(names)))

    def _get_exclude_ips(self):
        """
        Return the excluded hosts as a list of IP strings.

        'exclude_hosts' may be None, a comma separated string, or an iterable
        of host strings.  Addresses are kept in textual form because that is
        what they are compared against in get_active_connections_count().
        """
        exclude_hosts = self.module.params['exclude_hosts']
        if exclude_hosts is None:
            return []
        if hasattr(exclude_hosts, 'split'):
            exclude_hosts = exclude_hosts.split(',')
        hosts = [h.strip() for h in exclude_hosts]
        return [_convert_host_to_ip(h)[1] for h in hosts if h]

    def get_active_connections_count(self):
        """
        Count connections to the watched ip:port that are in one of
        ``connection_states`` and whose remote end is not excluded.
        """
        active_states = set(self.connection_states.values())
        match_all = self.match_all_ips[self.family]
        active_connections = 0
        for p in psutil.process_iter():
            try:
                # psutil renamed this method across releases; try the newest
                # spelling first and fall back to the older ones.
                list_connections = self._first_attr(
                    p, 'net_connections', 'connections', 'get_connections')
                connections = list_connections(kind='inet')
            except psutil.Error:
                # the process went away or cannot be inspected
                continue
            for conn in connections:
                if conn.status not in active_states:
                    continue
                (local_ip, local_port) = self._first_attr(
                    conn, 'laddr', 'local_address')
                if self.port == local_port and self.ip in [match_all, local_ip]:
                    (remote_ip, remote_port) = self._first_attr(
                        conn, 'raddr', 'remote_address')
                    if remote_ip not in self.exclude_ips:
                        active_connections += 1
        return active_connections
class LinuxTCPConnectionInfo(TCPConnectionInfo):
    """
    This is a TCP Connection Info evaluation strategy class
    that utilizes information from Linux's procfs. While less universal,
    does allow Linux targets to not require an additional library.
    """
    platform = 'Linux'
    distribution = None

    # procfs table to read for each address family
    source_file = {
        socket.AF_INET: '/proc/net/tcp',
        socket.AF_INET6: '/proc/net/tcp6'
    }
    # "any local address" in the hex notation used by those tables
    match_all_ips = {
        socket.AF_INET: '00000000',
        socket.AF_INET6: '00000000000000000000000000000000',
    }
    # column positions within a whitespace-split table row
    local_address_field = 1
    remote_address_field = 2
    connection_state_field = 3

    def __init__(self, module):
        """
        module: object exposing a ``params`` mapping with 'host', 'port' and
                'exclude_hosts' entries.
        """
        self.module = module
        # NOTE(review): the reviewed text read params['name'] on this line;
        # .get() keeps that attribute without requiring the key.
        self.name = module.params.get('name')
        (self.family, self.ip) = _convert_host_to_hex(module.params['host'])
        # ports are compared as the 4-digit upper-case hex strings found in procfs
        self.port = "%0.4X" % int(module.params['port'])
        self.exclude_ips = self._get_exclude_ips()

    def _get_exclude_ips(self):
        """
        Return the excluded hosts as a list of procfs-style hex strings.

        'exclude_hosts' may be None, a comma separated string, or an iterable
        of host strings.
        """
        exclude_hosts = self.module.params['exclude_hosts']
        if exclude_hosts is None:
            return []
        if hasattr(exclude_hosts, 'split'):
            exclude_hosts = exclude_hosts.split(',')
        hosts = [h.strip() for h in exclude_hosts]
        # _convert_host_to_hex returns (family, hex); only the hex part is
        # comparable with the remote address column.
        return [_convert_host_to_hex(h)[1] for h in hosts if h]

    def get_active_connections_count(self, family=None):
        """
        Count rows in the procfs TCP table whose local end is the watched
        ip:port, whose state is in ``connection_states`` and whose remote
        address is not excluded.

        family: address family selecting the table to read; defaults to the
                family the watched host resolved to.
        """
        if family is None:
            family = self.family
        try:
            with open(self.source_file[family]) as f:
                rows = f.readlines()
        except IOError:
            # NOTE(review): the handler body was missing from the reviewed
            # text; an unreadable table is treated as "no connections".
            return 0

        match_all = self.match_all_ips[family]
        active_connections = 0
        for row in rows:
            # split() collapses runs of spaces, so the header row and wide
            # columns are parsed consistently
            fields = row.split()
            if len(fields) <= self.connection_state_field:
                continue
            if fields[self.local_address_field] == 'local_address':
                # header row
                continue
            if fields[self.connection_state_field] not in self.connection_states:
                continue
            (local_ip, local_port) = fields[self.local_address_field].split(':')
            if self.port == local_port and self.ip in [match_all, local_ip]:
                (remote_ip, remote_port) = fields[self.remote_address_field].split(':')
                if remote_ip not in self.exclude_ips:
                    active_connections += 1
        return active_connections
def _convert_host_to_ip(host):
    """
    Perform forward DNS resolution on host, IP will give the same IP

    Args:
        host: String with either hostname, IPv4, or IPv6 address

    Returns:
        Tuple containing address family and IP

    Only the first result returned by the resolver is used.
    """
    # The port (80) is a placeholder required by the call; only the family
    # and address of the result are read.  The fifth argument is a protocol
    # number, so the protocol constant is used rather than the socket-option
    # level constant.
    addrinfo = socket.getaddrinfo(host, 80, 0, 0, socket.IPPROTO_TCP)[0]
    return (addrinfo[0], addrinfo[4][0])
def _convert_host_to_hex(host):
    """
    Convert the provided host to the format in /proc/net/tcp*

    /proc/net/tcp uses little-endian four byte hex for ipv4
    /proc/net/tcp6 uses little-endian per 4B word for ipv6

    Args:
        host: String with either hostname, IPv4, or IPv6 address

    Returns:
        Tuple containing address family and the little-endian converted host
    """
    (family, ip) = _convert_host_to_ip(host)
    packed = socket.inet_pton(family, ip)
    # hexlify returns bytes on Python 3; decode so the result is text and can
    # be sliced, joined and compared with strings read from procfs
    hexed = binascii.hexlify(packed).decode('ascii').upper()
    if family == socket.AF_INET:
        hexed = _little_endian_convert_32bit(hexed)
    elif family == socket.AF_INET6:
        # walk each 8 character (4B) set in the 128bit total
        hexed = "".join(
            [_little_endian_convert_32bit(hexed[x:x + 8]) for x in range(0, 32, 8)]
        )
    return (family, hexed)
def _little_endian_convert_32bit(block):
    """
    Convert to little-endian, effectively transposing
    the order of the four byte word
    12345678 -> 78563412

    Args:
        block: String containing a 4 byte hex representation

    Returns:
        String containing the little-endian converted block
    """
    # range starts at 6, and increments by -2 until it reaches -2
    # which lets us start at the end of the string block and work to the beginning
    return "".join([block[x:x + 2] for x in range(6, -2, -2)])
class _StandaloneModule(object):
    """Minimal stand-in for the module object TCPConnectionInfo is built from."""

    def __init__(self, params):
        self.params = params

    def fail_json(self, **kwargs):
        # There is no caller to report JSON to when running standalone.
        raise RuntimeError(kwargs.get('msg', 'wait_for_drain failed'))


def wait_for_drain(host, port, exclude_hosts=None, timeout=30):
    """
    Wait until no active connections to host:port remain, or until timeout.

    @host can be any resolveable DNS or an IP
    @port numerical representation only
    @exclude_hosts array of host entries, entries same as @host
                   (a comma separated string is accepted as well)
    @timeout how long to wait for the drain, in seconds

    returns count of active connections seen by the last check
    (0 when the drain completed)
    """
    if exclude_hosts is not None and not hasattr(exclude_hosts, 'split'):
        # the connection-info classes read this parameter as a comma
        # separated string
        exclude_hosts = ','.join(exclude_hosts)

    module = _StandaloneModule({
        # NOTE(review): 'name' is read by the connection-info constructors in
        # the reviewed text; its intended value is not visible here.
        'name': None,
        'host': host,
        'port': port,
        'exclude_hosts': exclude_hosts,
    })

    end = datetime.datetime.now() + datetime.timedelta(seconds=timeout)
    tcpconns = TCPConnectionInfo(module)
    while True:
        # check first, so that a zero timeout still reports a real count
        active_connections = tcpconns.get_active_connections_count()
        if active_connections == 0:
            break
        if datetime.datetime.now() >= end:
            print("Timeout when waiting for %s:%s to drain" % (host, port))
            break
        # pause between polls instead of spinning
        time.sleep(1)
    return active_connections
if __name__ == '__main__':
    # Manual smoke test: report the connections still open to the local SSH port.
    # NOTE(review): the host argument was an empty string in the reviewed
    # text; 'localhost' is used to agree with the printed message.
    print("Found {0} connections to localhost:22".format(wait_for_drain('localhost', 22)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment