Skip to content

Instantly share code, notes, and snippets.

@epcim
Forked from fpytloun/clexec.py
Created June 16, 2016 15:48
Show Gist options
  • Save epcim/fe95003aaf70d2c4d63744a741a7cbfc to your computer and use it in GitHub Desktop.
Save epcim/fe95003aaf70d2c4d63744a741a7cbfc to your computer and use it in GitHub Desktop.
Cluster execution tool
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Tool for commands execution over clusters
"""
import os, sys, logging, urllib
import argparse
import threading
import paramiko
import time
import socket
import base64
import signal
import getpass
# Root logger: only warnings and errors by default; main() lowers the level
# to DEBUG when --debug is given.
logging.basicConfig(level=logging.WARN, format='%(levelname)s: %(message)s')
lg = logging.getLogger()
# Connection address -> paramiko.SSHClient; filled by runRemote() and closed
# by sshCleanup().
sshPool = {}
# Connection addresses that could not be reached; reported by main() on exit.
sshFailed = []
# Hostnames for which runRemote() has returned (successfully or not).
sshDone = []
# All hosts selected for this run; assigned in main().
sshHosts = []
# Semaphore bounding how many workers run at once; created in main().
threadLimiter = None
# Set by workers when a remote command exits non-zero; main() exits with it
# when --exitcode is given.
exitcode = 0
def _read_hosts(path):
    """Return the non-empty host names listed one per line in *path* ('-' = stdin)."""
    if path == '-':
        lines = sys.stdin.readlines()
    else:
        try:
            with open(path, 'r') as handle:
                lines = handle.readlines()
        except (IOError, OSError):
            lg.error("Can't open file %s" % path)
            sys.exit(1)
    # Strip line endings/whitespace and drop blank lines so they do not turn
    # into empty host names.
    stripped = [line.strip() for line in lines]
    return [host for host in stripped if host]


def _wait_for_pool(pool):
    """Poll every 0.5s until no thread in *pool* is alive any more."""
    while True:
        alive = sum(1 for thread in pool if thread.is_alive())
        if not alive:
            return
        lg.debug("Waiting for %i threads" % alive)
        time.sleep(0.5)


def _stop_pool(pool):
    """Best-effort stop of every thread in *pool* that is still alive."""
    for thread in pool:
        if thread.is_alive():
            lg.debug("Killing thread %s" % thread.getName())
            try:
                # Private Python 2 API; there is no public way to stop a thread.
                thread._Thread__stop()
            except Exception as e:
                lg.error("Thread %s cannot be terminated: %s" % (thread.getName(), e))


def main():
    """
    main entrance

    Parses the command line, builds the host list (from -f or -m), then either
    runs the given command once on every host or, in interactive mode
    (command '-' or -I), reads commands from a prompt and runs each of them on
    every host until 'exit', 'quit' or EOF.

    Exits with 1 on usage errors or when any connection failed, and with the
    recorded remote exit status when --exitcode is given.
    """
    global threadLimiter, sshHosts

    # Catch SIGINFO if supported (and SIGUSR1 as the portable alternative)
    if hasattr(signal, 'SIGINFO'):
        signal.signal(signal.SIGINFO, siginfo_handler)
    if hasattr(signal, 'SIGUSR1'):
        signal.signal(signal.SIGUSR1, siginfo_handler)

    parser = argparse.ArgumentParser(description='Execute command on cluster', add_help=False)

    # Required
    group_req = parser.add_argument_group('Required arguments')
    group_req.add_argument('command', help="Command to be executed. Use -- after arguments accepting multiple values.", nargs='?')

    # Optional
    group_opt = parser.add_argument_group('Optional arguments')
    group_opt.add_argument('--system-ssh', dest='system_ssh', action='store_true', help="Call system SSH client instead of Pythonish (worse escaping, obsolete)")
    group_opt.add_argument('--serial', '--no-parallel', dest='serial', action='store_true', help="Execute commands on hosts one-by-one")
    group_opt.add_argument('-t', '--threads', dest='threads', type=int, default=120, help="Execute commands on hosts in x threads (default 120)")
    group_opt.add_argument('-f', '--file', dest='file', help="Read list of nodes from file (ignores cluster)")
    group_opt.add_argument('-m', '--machines', dest='machines', nargs='+', default=[], help="List of machines to operate on")
    group_opt.add_argument('-u', '--user', dest='sshUser', help="SSH user to connect with, defaults to the current user", default=getpass.getuser())
    group_opt.add_argument('-K', '--key-file', dest='sshKeyFile', help="SSH user key file to connect with")
    group_opt.add_argument('-d', '--domain', dest='domain', help="Location domain to use")
    group_opt.add_argument('-e', '--exitcode', dest='exitcode', action='store_true', help="Exit with non-zero exit code if command return non-zero exit code")

    # Output switchers
    group_out = parser.add_argument_group('Output switchers')
    group_out.add_argument('--debug', dest='debug', action='store_true')

    # Action switchers
    group_act = parser.add_argument_group('Action switchers')
    group_act.add_argument('-h', '--help', dest='help', action='store_true', help="Show this help")
    group_act.add_argument('-L', '--list-nodes', dest='list_nodes', action='store_true', help="List nodes where command would be executed")
    group_act.add_argument('-I', '--interactive', '--shell', dest='interactive', action='store_true', help="Run in interactive mode, same as if command is -")
    group_act.add_argument('-U', '--upload', dest='upload', help="Upload file to [command] on nodes")

    args = parser.parse_args()

    if args.debug:
        lg.setLevel(logging.DEBUG)

    if args.help:
        parser.print_help()
        sys.exit(0)

    if args.interactive:
        args.command = '-'

    if args.threads == 0:
        print(base64.b64decode('ICAgICAgIF8gICAgIF8KICAgICAgIFxgXCAvYC8KICAgICAgICBcIFYgLyAgICAgICAgICAgICAgIAogICAgICAgIC8uIC5cICAgICAgIAogICAgICAgPVwgVCAvPSAgICAgICAgICAgICAgICAgIAogICAgICAgIC8gXiBcICAgICAKICAgICAgIC9cXCAvL1wKICAgICBfX1wgIiAiIC9fXyAgICAgICAgICAgCiAgICAoX19fXy9eXF9fX18pCiAgWW91J3JlIGEgVGVhcG90IQo='))
        sys.exit(1)
    if args.threads < 0:
        # BoundedSemaphore would raise ValueError on a negative value.
        lg.error("Number of threads must be a positive number")
        sys.exit(1)

    # Set thread limiter. --threads always has a value (default 120), so
    # --serial has to be checked first or it would never take effect.
    if args.serial:
        threadLimiter = threading.BoundedSemaphore(1)
    else:
        threadLimiter = threading.BoundedSemaphore(args.threads)

    # Can't read from stdin for multiple options
    if args.file == '-' and args.command == '-':
        lg.error("Can't read nodes and command from stdin, try to use -m option instead of -f")
        sys.exit(1)

    lg.debug("Command: %s" % args.command)

    if args.file:
        machines = _read_hosts(args.file)
    elif args.machines:
        machines = args.machines
    else:
        lg.error("You need to submit list of hosts to connect to")
        sys.exit(1)

    sshHosts = machines

    # Interactive mode
    if args.command == '-':
        import readline
        readline.parse_and_bind('tab: complete')
        readline.parse_and_bind('set editing-mode vi')

        while True:
            try:
                args.command = raw_input("$> ")
            except (KeyboardInterrupt, SystemExit, EOFError):
                lg.debug("Interrupted")
                sshCleanup()
                print('')
                sys.exit(0)

            if args.command in ['exit', 'quit']:
                sshCleanup()
                sys.exit(0)

            if not args.command:
                # Empty line: just show the prompt again
                continue

            # Do the job and wait till all threads are done
            pool = run(machines, args)
            try:
                _wait_for_pool(pool)
            except (KeyboardInterrupt, SystemExit):
                lg.debug("Received keyboard interrupt. Cleaning threads and exitting.")
                _stop_pool(pool)
                sshCleanup()
                sys.exit(1)

    # Do the job (normal mode) and wait till all threads are done
    pool = run(machines, args)
    try:
        _wait_for_pool(pool)
    except (KeyboardInterrupt, SystemExit):
        lg.debug("Received keyboard interrupt. Cleaning threads and exitting.")
        _stop_pool(pool)
    finally:
        sshCleanup()

    if sshFailed:
        lg.error("Failed connections (%s/%s): %s" % (len(sshFailed), len(sshHosts), ','.join(sshFailed)))
        sys.exit(1)

    if args.exitcode:
        sys.exit(exitcode)
def wait_threads():
    """
    Wait until all active threads are done
    we usually don't want to use this, because
    it will also wait for infinite transport threads

    On KeyboardInterrupt/SystemExit/EOFError every other thread is stopped
    (best effort). Pooled SSH connections are closed in all cases.
    """
    try:
        while threading.active_count() > 1:
            lg.debug("Waiting for %i threads" % (threading.active_count() - 1))
            time.sleep(0.5)
    except (KeyboardInterrupt, SystemExit, EOFError):
        current = threading.current_thread()
        for thread in threading.enumerate():
            if thread is current:
                # Never try to stop the thread that is doing the cleanup.
                continue
            lg.debug("Killing thread %s" % thread.getName())
            try:
                # Private Python 2 API; there is no public way to stop a thread.
                thread._Thread__stop()
            except Exception as e:
                lg.error("Thread %s cannot be terminated: %s" % (thread.getName(), e))
    finally:
        sshCleanup()
def run(machines, args):
    """
    Start one worker thread per host, or list the hosts when args.list_nodes is set.

    machines -- list of host names, or a dict mapping host name to a node
                dict (keys 'hostname', 'ip_public', 'ip', 'instance_id')
    args     -- parsed command line options (argparse namespace)

    Returns the list of started threads (empty when only listing nodes).
    Exits with 1 when there is work to do but no command was given.
    """
    lg.debug("Hosts: %s" % machines)
    pool = []

    if isinstance(machines, list):
        nodes = {}
        hostnames = []
        for host in machines:
            if host not in nodes:
                # Keep the order given by the user; duplicates collapse into one
                hostnames.append(host)
            nodes[host] = {
                'hostname': host,
                'ip_public': host,
                'ip': host,
                'instance_id': None,
            }
    else:
        nodes = machines
        hostnames = list(machines)

    # The argument parser defines no 'ip' option, so it must not be assumed
    show_ip = getattr(args, 'ip', False)
    use_system_ssh = args.system_ssh and not args.upload

    if hostnames and not args.list_nodes:
        if not args.command:
            lg.error("'command' option have to be set")
            sys.exit(1)
        if use_system_ssh:
            # Warn once, not once per host
            lg.warning('Using system SSH is obsolete and may be buggy. Avoid using this option!')

    for hostname in hostnames:
        node = nodes[hostname]
        if args.domain:
            node['connect'] = "%s.%s" % (hostname, args.domain)
        else:
            node['connect'] = hostname

        if args.list_nodes:
            if not show_ip:
                print("{0:<25}{1}".format(hostname, node['instance_id']))
            else:
                print("{0:<25}{1}{2:>20}".format(hostname, node['instance_id'], node['connect']))
            continue

        # Every worker is named after its host so status output can identify it
        if args.upload:
            t = threading.Thread(name=hostname, target=uploadFile, args=(node, args.upload, args.command, args.sshUser, args.sshKeyFile))
        elif use_system_ssh:
            t = threading.Thread(name=hostname, target=runSSH, args=(node['connect'], args.command, args.sshUser, args.sshKeyFile))
        else:
            t = threading.Thread(name=hostname, target=runRemote, args=(node, args.command, args.sshUser, args.sshKeyFile))
        t.start()
        pool.append(t)

    return pool
def runSSH(name, command, user, keyFile=None):
    """
    execute command on <name> host using the system ssh client

    name    -- address to connect to
    command -- shell command to run remotely (run after sourcing /etc/profile)
    user    -- remote user name
    keyFile -- optional identity file passed to ssh

    Every output line is written to stdout prefixed with the host name. A
    non-zero exit status of ssh is stored in the global exitcode.
    """
    # Local import: only this obsolete code path needs it
    import subprocess

    global exitcode
    threadLimiter.acquire()
    try:
        lg.debug("Execute '%s' on '%s'" % (command, name))
        # stderr is merged into stdout on the remote side
        remote_command = 'source /etc/profile >/dev/null;%s 2>&1' % command

        argv = ['ssh', '-qAY',
                '-o', 'UserKnownHostsFile=/dev/null',
                '-o', 'StrictHostKeyChecking=no']
        if keyFile:
            argv.extend(['-o', 'IdentityFile=%s' % keyFile])
        # Passing an argument list (no local shell) hands the command to the
        # remote shell untouched: no local expansion, no broken quoting.
        argv.extend(['%s@%s' % (user, name), '--', remote_command])

        try:
            proc = subprocess.Popen(argv, stdout=subprocess.PIPE, universal_newlines=True)
        except OSError as e:
            lg.error("Can't execute ssh for %s: %s" % (name, e))
            exitcode = 1
            return

        try:
            # readline() based loop delivers lines as soon as they arrive
            for line in iter(proc.stdout.readline, ''):
                sys.stdout.write("%s: %s" % (name, line))
                sys.stdout.flush()
        finally:
            proc.stdout.close()
            status = proc.wait()

        lg.debug("Exit status: %s" % status)
        if status != 0:
            exitcode = status
    finally:
        threadLimiter.release()
def sshCleanup():
    """
    close every SSH client kept in the connection pool
    """
    # Iterate over a snapshot of the pool, as the original did
    for address, client in list(sshPool.items()):
        lg.debug("Closing connection to %s" % address)
        client.close()
def _get_ssh_client(node, user, keyFile=None):
    """Return a connected pooled SSHClient for node['connect'], or None on failure."""
    connect = node['connect']

    client = sshPool.get(connect)
    if client is None:
        client = paramiko.SSHClient()
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        sshPool[connect] = client
        try:
            client.load_system_host_keys()
        except paramiko.SSHException as e:
            lg.error("Can't load system known hosts: %s" % e)
            sshFailed.append(connect)
            return None
    else:
        # Reuse the pooled connection as long as its transport is alive
        trans = client.get_transport()
        if trans is not None and trans.is_active():
            return client

    try:
        client.connect(connect, username=user, timeout=5, key_filename=keyFile)
    except KeyboardInterrupt:
        lg.info("Interrupted")
        sys.exit(0)
    except socket.timeout:
        # Must come before socket.error, which is its base class
        lg.error("Timeout during connecting to %s (%s)" % (connect, node['hostname']))
        sshFailed.append(connect)
        return None
    except (socket.gaierror, socket.error) as e:
        lg.error("Can't connect to %s (%s): %s" % (connect, node['hostname'], e))
        sshFailed.append(connect)
        return None
    except paramiko.SSHException as e:
        lg.error("Can't connect to %s (%s) as user %s: %s" % (connect, node['hostname'], user, e))
        sshFailed.append(connect)
        return None
    return client


def runRemote(node, command, user, keyFile=None):
    """
    execute command on <node> host with Paramiko

    node    -- node dict; 'connect' is the address to connect to and
               'hostname' the name used in messages and output prefixes
    command -- shell command to run remotely (run after sourcing /etc/profile)
    user    -- remote user name
    keyFile -- optional private key file

    The connection is first probed with 'hostname' under a 5 second timeout,
    then the command is run without timeout and its output is written to
    stdout line by line, prefixed with the host name. A non-zero remote exit
    status is stored in the global exitcode.

    Returns True on success and False when the host could not be used.
    """
    global exitcode
    threadLimiter.acquire()
    try:
        lg.debug("Execute '%s' on '%s'" % (command, node))
        # Decide this before the command gets its prefix, otherwise the
        # comparison can never match.
        only_hostname = (command == 'hostname')
        command = 'source /etc/profile >/dev/null;%s' % command
        connect = node['connect']

        if connect in sshFailed:
            return False

        client = _get_ssh_client(node, user, keyFile)
        if client is None:
            return False

        trans = client.get_transport()
        if not trans:
            lg.error("Can't get transport for connection %s (%s)" % (connect, node['hostname']))
            sshFailed.append(connect)
            return False

        chan = trans.open_session()
        if not chan:
            lg.error("Connection to %s (%s) no longer active" % (connect, node['hostname']))
            sshFailed.append(connect)
            return False

        chan.get_pty()
        # Timeout 5 seconds for first command
        # to test connection
        chan.settimeout(5)
        try:
            output = chan.makefile()
            chan.exec_command('hostname')
            for line in output:
                line = line.replace("\r\n", "\n")
                if only_hostname:
                    # The probe already is the requested command
                    sys.stdout.write("%s: %s" % (connect, line))
                    sys.stdout.flush()
                else:
                    lg.debug("Connected to %s (hostname %s)" % (connect, line))
        except socket.timeout:
            lg.error("Timeout during communication with %s (%s)" % (connect, node['hostname']))
            return False
        finally:
            # Do not leave the probe channel open on the pooled transport
            chan.close()

        if only_hostname:
            return True

        # Channel without timeout for our command
        chan = trans.open_session()
        chan.settimeout(None)
        chan.get_pty()
        try:
            output = chan.makefile()
            chan.exec_command(command)
            for line in output:
                sys.stdout.write("%s: %s" % (node['hostname'], line.replace("\r\n", "\n")))
                sys.stdout.flush()
            # Read the status before closing, closing first may lose it (-1)
            status = chan.recv_exit_status()
        finally:
            # Cleanup
            chan.close()

        lg.debug("Exit status: %s" % status)
        if status != -1 and status != 0:
            exitcode = status
        return True
    finally:
        sshDone.append(node['hostname'])
        threadLimiter.release()
def uploadFile(node, localFile, remoteFile, user, keyFile=None):
    """
    upload file to remote host with Paramiko

    node       -- node dict; 'connect' is the address to connect to and
                  'hostname' the name used in messages
    localFile  -- path of the local file to send
    remoteFile -- destination path on the remote host
    user       -- remote user name
    keyFile    -- optional private key file

    Returns True on success, False otherwise. Failed hosts are added to
    sshFailed so that the main thread can report them and exit non-zero
    (sys.exit() inside a worker thread would only end that thread).
    """
    connect = node['connect']
    # Honour the global thread limit like the other workers do
    limiter = threadLimiter
    if limiter is not None:
        limiter.acquire()
    try:
        lg.debug("Upload '%s' on '%s:%s'" % (localFile, node['hostname'], remoteFile))
        ssh = paramiko.SSHClient()
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        ssh.load_system_host_keys()
        try:
            try:
                ssh.connect(connect, username=user, timeout=5, key_filename=keyFile)
            except Exception as e:
                # Best effort as before, but no longer swallows
                # KeyboardInterrupt/SystemExit
                lg.warning("Can't connect to %s (%s): %s" % (connect, node['hostname'], e))
                sshFailed.append(connect)
                return False

            try:
                ftp = ssh.open_sftp()
                try:
                    ftp.put(localFile, remoteFile)
                finally:
                    ftp.close()
            except (OSError, IOError, paramiko.SSHException) as e:
                lg.error("Can't upload '%s' to '%s:%s': %s" % (localFile, node['hostname'], remoteFile, e))
                sshFailed.append(connect)
                return False
        finally:
            ssh.close()
        return True
    finally:
        sshDone.append(node['hostname'])
        if limiter is not None:
            limiter.release()
def siginfo_handler(signum, frame):
    """
    signal handler printing a progress summary while workers are running

    Does nothing when only one thread is active.
    """
    if threading.activeCount() <= 1:
        return

    nodes_active = [
        worker.getName()
        for worker in threading.enumerate()
        if worker.getName() != 'MainThread'
    ]

    print("--")
    print("Done: %s/%s" % (len(sshDone), len(sshHosts)))
    print("SSH connections: %s" % len(sshPool))
    print("Connections failed: %s" % len(sshFailed))
    print("Threads count: %s" % threading.activeCount())
    print("Thread names: %s" % ','.join(nodes_active))
    print("--")
# Run the tool only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment