Skip to content

Instantly share code, notes, and snippets.

@cahna
Last active July 28, 2022 00:00
Show Gist options
  • Save cahna/218fe8edd0b47089146825a4b27be6bf to your computer and use it in GitHub Desktop.
Save cahna/218fe8edd0b47089146825a4b27be6bf to your computer and use it in GitHub Desktop.
Custom CLI wrapper for common ansible tasks.
#!/usr/bin/env python
#
# Run `python ansible-cli-wrapper.py -h` for usage information.
#
import os
from datetime import datetime
from tempfile import NamedTemporaryFile
import boto
import click
from ansible.executor import playbook_executor
from ansible.inventory import Inventory
from ansible.parsing.dataloader import DataLoader
from ansible.plugins.callback import CallbackBase
from ansible.utils.display import Display
from ansible.vars import VariableManager
from boto.vpc import VPCConnection
# Version string reported by ``--version``.
__version__ = '0.0.1'

# Region used when neither -r/--aws-region nor $AWS_REGION is supplied.
DEFAULT_AWS_REGION = 'us-east-1'

# All paths are derived from the real (symlink-resolved) location of this script.
SCRIPT_ABS_PATH = os.path.realpath(__file__)
SCRIPT_DIR = os.path.dirname(SCRIPT_ABS_PATH)
# One directory above the script; the sub-commands' playbooks ('../*.yml') live here.
PLAYBOOK_ROOT = os.path.dirname(SCRIPT_DIR)
# Dynamic EC2 inventory script handed to Ansible's Inventory.
INVENTORY_FILE = os.path.join(PLAYBOOK_ROOT, 'inventory', 'ec2.py')

# Queue names accepted by ``populate_jobs``; iterated in this order for "all".
QUEUE_NAMES = ['queue', 'price', 'cancel', 'regeneration']
class Options(object):
    """
    Plain attribute bag standing in for the option object Ansible's CLI
    normally builds with OptParser.

    Every keyword argument is stored unchanged as an attribute of the same
    name; all of them default to ``None``.
    """
    def __init__(self, verbosity=None, inventory=None, listhosts=None, subset=None, module_paths=None, extra_vars=None,
                 forks=None, ask_vault_pass=None, vault_password_files=None, new_vault_password_file=None,
                 output_file=None, tags=None, skip_tags=None, one_line=None, tree=None, ask_sudo_pass=None, ask_su_pass=None,
                 sudo=None, sudo_user=None, become=None, become_method=None, become_user=None, become_ask_pass=None,
                 ask_pass=None, private_key_file=None, remote_user=None, connection=None, timeout=None, ssh_common_args=None,
                 sftp_extra_args=None, scp_extra_args=None, ssh_extra_args=None, poll_interval=None, seconds=None, check=None,
                 syntax=None, diff=None, force_handlers=None, flush_cache=None, listtasks=None, listtags=None, module_path=None):
        # Snapshot the arguments before any other local name is bound, so the
        # mapping holds exactly the parameters (plus ``self``, dropped below).
        supplied = dict(locals())
        del supplied['self']
        for option_name, option_value in supplied.items():
            setattr(self, option_name, option_value)
class Runner(object):
    """
    Configure and run one Ansible playbook through Ansible's Python API.

    Construction only prepares the PlaybookExecutor; call run() to execute it.

    :param hostnames: host pattern written to a temporary ``[run_hosts]``
        file (see the NOTE(review) in __init__).
    :param playbook: playbook path, relative to this script's directory.
    :param var_overrides: dict installed as the run's extra vars.
    :param private_key_file: SSH private key path, set on Options.
    :param connection: Ansible connection type, e.g. "ssh" or "local".
    :param become_pass: privilege-escalation password, or None.
    :param ask_vault_pass: copied onto Options.ask_vault_pass.
    :param verbosity: Ansible verbosity level; 0 is the default output.
    :param vault_pass: if truthy, handed to the DataLoader as vault password.
    """
    def __init__(self, hostnames, playbook, var_overrides,
                 private_key_file='~/.ssh/id_rsa',
                 connection='ssh',
                 become_pass=None,
                 ask_vault_pass=False,
                 verbosity=0,
                 vault_pass=None):
        self.run_data = var_overrides
        self.options = Options()
        self.options.private_key_file = private_key_file
        self.options.verbosity = verbosity
        self.options.connection = connection  # e.g. "smart", "ssh" or "local"
        self.options.become = False
        self.options.become_method = 'sudo'
        self.options.become_user = 'root'
        self.options.ask_vault_pass = ask_vault_pass

        # Set the global display verbosity.
        self.display = Display()
        self.display.verbosity = self.options.verbosity
        # The executor module appears to keep its own verbosity setting too.
        playbook_executor.verbosity = self.options.verbosity

        # Become password, needed when not logging in as root.
        passwords = {'become_pass': become_pass}

        # Reads YAML/JSON data files.
        self.loader = DataLoader()
        if vault_pass:
            self.loader.set_vault_password(vault_pass)

        # Collects variables from all sources; var_overrides become extra vars.
        self.variable_manager = VariableManager()
        self.variable_manager.extra_vars = self.run_data

        self.inventory = Inventory(loader=self.loader, variable_manager=self.variable_manager,
                                   host_list=INVENTORY_FILE)
        self.variable_manager.set_inventory(self.inventory)

        # Resolve the playbook against the script's real directory, the same
        # base INVENTORY_FILE is derived from. os.path.dirname(__file__) is ''
        # when the script is launched via a relative path, which turned
        # '../x.yml' into the root-anchored '/../x.yml'.
        playbook = os.path.join(SCRIPT_DIR, playbook)

        # Set up the executor; nothing runs until run() is called.
        self.pbex = playbook_executor.PlaybookExecutor(
            playbooks=[playbook],
            inventory=self.inventory,
            variable_manager=self.variable_manager,
            loader=self.loader,
            options=self.options,
            passwords=passwords)

        # Created last so that a failure in any step above cannot leave the
        # temporary file behind. Text mode ('w') accepts str on Python 2 and 3;
        # the default 'w+b' mode rejects str on Python 3.
        # NOTE(review): this file is never handed to Inventory (which reads
        # INVENTORY_FILE), so `hostnames` does not currently limit which hosts
        # the playbook targets -- confirm whether that is intended.
        with NamedTemporaryFile(mode='w', delete=False) as hosts_file:
            hosts_file.write("[run_hosts]\n%s\n" % hostnames)
        self.hosts = hosts_file

    def run(self):
        """
        Execute the playbook and return the executor's aggregate stats.

        The temporary hosts file is removed afterwards, even if the run raises.
        """
        # TODO: CallbackModule.record_logs was meant to receive a per-run
        # success flag derived from these stats; that hook is disabled.
        try:
            self.pbex.run()
            return self.pbex._tqm._stats
        finally:
            self._remove_hosts_file()

    def _remove_hosts_file(self):
        """Delete the temporary hosts file; safe to call more than once."""
        if os.path.exists(self.hosts.name):
            os.remove(self.hosts.name)
class PlayLogger:
    """Accumulate the text log of a single Ansible run.

    One instance is created per run. ``log`` holds the collected text and
    ``runtime`` starts at 0 until the owner sets it.
    """
    def __init__(self):
        self.log = ''
        self.runtime = 0

    def append(self, log_line):
        """Add ``log_line`` to the log, followed by a blank line."""
        self.log = ''.join((self.log, log_line, "\n\n"))

    def banner(self, msg):
        """Return ``msg`` on a new line followed by a row of asterisks.

        The asterisks pad the line toward 78 columns but never number fewer
        than three.
        """
        stars = "*" * max(78 - len(msg), 3)
        return "\n%s %s " % (msg, stars)
class CallbackModule(CallbackBase):
    """
    Ansible callback plugin that collects run output in a PlayLogger
    instead of printing it.

    Message formats mirror Ansible's default stdout callback:
    https://github.com/ansible/ansible/blob/v2.0.0.2-1/lib/ansible/plugins/callback/default.py
    """
    CALLBACK_VERSION = 2.0
    CALLBACK_TYPE = 'stored'
    CALLBACK_NAME = 'database'

    def __init__(self):
        super(CallbackModule, self).__init__()
        self.logger = PlayLogger()
        # The reported runtime is measured from plugin construction.
        self.start_time = datetime.now()

    @staticmethod
    def _format_host(result):
        """Return "host", or "host -> delegate" for a delegated task result."""
        delegated_vars = result._result.get('_ansible_delegated_vars', None)
        if delegated_vars:
            return "%s -> %s" % (result._host.get_name(), delegated_vars['ansible_host'])
        return result._host.get_name()

    def _log_exception(self, result):
        """Log the last line of the result's exception text, then drop it.

        Removing it keeps the traceback out of the dumped result logged next.
        """
        if 'exception' in result._result:
            error = result._result['exception'].strip().split('\n')[-1]
            self.logger.append(error)
            del result._result['exception']

    def v2_runner_on_failed(self, result, ignore_errors=False):
        # NOTE: per the original author, Ansible's default handler may already
        # have deleted the 'exception' key by the time this runs.
        self._log_exception(result)
        if result._task.loop and 'results' in result._result:
            # Per-item results go to the v2_playbook_item_on_* handlers.
            self._process_items(result)
        else:
            self.logger.append("fatal: [%s]: FAILED! => %s" % (
                self._format_host(result), self._dump_results(result._result)))

    def v2_runner_on_ok(self, result):
        self._clean_results(result._result, result._task.action)
        if result._task.action == 'include':
            return
        if result._task.loop and 'results' in result._result:
            self._process_items(result)
        else:
            status = 'changed' if result._result.get('changed', False) else 'ok'
            self.logger.append("%s: [%s]" % (status, self._format_host(result)))

    def v2_runner_on_skipped(self, result):
        if result._task.loop and 'results' in result._result:
            self._process_items(result)
        else:
            self.logger.append("skipping: [%s]" % result._host.get_name())

    def v2_runner_on_unreachable(self, result):
        self.logger.append("fatal: [%s]: UNREACHABLE! => %s" % (
            self._format_host(result), self._dump_results(result._result)))

    def v2_runner_on_no_hosts(self, task):
        self.logger.append("skipping: no hosts matched")

    def v2_playbook_on_task_start(self, task, is_conditional):
        self.logger.append("TASK [%s]" % task.get_name().strip())

    def v2_playbook_on_play_start(self, play):
        name = play.get_name().strip()
        self.logger.append(("PLAY [%s]" % name) if name else "PLAY")

    def v2_playbook_item_on_ok(self, result):
        if result._task.action == 'include':
            return
        status = 'changed' if result._result.get('changed', False) else 'ok'
        # The item is formatted inside a tuple so tuple-valued items don't
        # break the % operator.
        self.logger.append("%s: [%s] => (item=%s)" % (
            status, self._format_host(result), result._result['item']))

    def v2_playbook_item_on_failed(self, result):
        self._log_exception(result)
        self.logger.append("failed: [%s] => (item=%s) => %s" % (
            self._format_host(result), result._result['item'], self._dump_results(result._result)))

    def v2_playbook_item_on_skipped(self, result):
        msg = "skipping: [%s] => (item=%s) " % (result._host.get_name(), result._result['item'])
        self.logger.append(msg)

    def v2_playbook_on_stats(self, stats):
        run_time = datetime.now() - self.start_time
        # int(total_seconds()) keeps whole days; timedelta.seconds is only the
        # seconds component and would wrap back to 0 after 24 hours.
        self.logger.runtime = int(run_time.total_seconds())
        for h in sorted(stats.processed.keys()):
            t = stats.summarize(h)
            msg = "PLAY RECAP [%s] : %s %s %s %s %s" % (
                h,
                "ok: %s" % (t['ok'],),
                "changed: %s" % (t['changed'],),
                "unreachable: %s" % (t['unreachable'],),
                "skipped: %s" % (t['skipped'],),
                "failed: %s" % (t['failures'],),
            )
            self.logger.append(msg)

    def record_logs(self, user_id, success=False):
        """
        Persist the collected log for a run. Currently a no-op.

        Runner.run was meant to trigger this via
        ``send_callback('record_logs', ...)``, but that call is disabled.

        :param user_id: identifier of the user the run is recorded for.
        :param success: whether every host finished without failures or
            unreachable errors.
        :return: None
        """
        # TODO: store self.logger.log, self.logger.runtime and `success` once a
        # log storage backend (previously ``Logs().save_log``) is available.
        return None
class CliConfig:
    """
    Settings shared by every sub-command (stored as the click context object).

    Constructing it exports ``AWS_REGION`` into ``os.environ``. The vault
    password comes from the ``vault_pass`` argument, else ``$VAULT_PASS``.
    If both are empty, the user is prompted on first access and the answer
    is exported as ``$VAULT_PASS``.
    """
    def __init__(self, verbosity, aws_region, vault_pass=None):
        self._verbosity = verbosity
        self._aws_region = aws_region
        # Exported so other code in this process (and child processes) sees
        # the selected region.
        os.environ['AWS_REGION'] = aws_region
        self._vault_pass = vault_pass or os.environ.get('VAULT_PASS')

    @property
    def aws_region(self):
        return self._aws_region

    @property
    def verbosity(self):
        return self._verbosity

    @property
    def vault_pass(self):
        if self._vault_pass:
            return self._vault_pass
        entered = click.prompt('Enter ansible-vault password', type=str, hide_input=True)
        self._vault_pass = entered
        os.environ['VAULT_PASS'] = entered
        return entered
def list_environments(region_name):
    """
    Return the set of ``Environment`` tag values across all VPCs in a region.

    VPCs without an ``Environment`` tag contribute the placeholder ``'null'``.

    :param region_name: AWS region name, e.g. ``'us-east-1'``.
    :raises click.ClickException: if boto does not recognise ``region_name``.
    """
    # Import the submodule explicitly instead of relying on boto.vpc having
    # loaded it as a side effect.
    import boto.ec2
    region = boto.ec2.get_region(region_name)
    if region is None:
        # get_region() returns None for unknown names, and
        # VPCConnection(region=None) would silently use boto's default region.
        raise click.ClickException('Unknown AWS region: %s' % region_name)
    conn = VPCConnection(region=region)
    return {vpc.tags.get('Environment', 'null') for vpc in conn.get_all_vpcs()}
# Shared keyword arguments for the group and every command: accept -h as an
# alias for --help.
_default_click_kwargs = {'context_settings': {'help_option_names': ['-h', '--help']}}


@click.group(**_default_click_kwargs)
@click.version_option(version=__version__, message='%(prog)s : v%(version)s')
@click.option('-v', '--verbose', count=True, help="Increase verbosity of output produced by Ansible")
@click.option('-r', '--aws-region', envvar='AWS_REGION', default=DEFAULT_AWS_REGION, help="AWS Region to connect to (Default: %s)" % DEFAULT_AWS_REGION)
@click.pass_context
def cli(ctx, verbose, aws_region):
    """
    CLI wrapper for performing configuration management and automation
    tasks with Ansible. Use -h or --help after a command for detailed per-
    command usage information.
    """
    # Build the shared settings once; sub-commands receive them through
    # @click.pass_obj. `verbose` is the repeat count of -v.
    shared_config = CliConfig(verbosity=verbose, aws_region=aws_region)
    ctx.obj = shared_config
def _require_valid_env_tag(ctx, param, value):
    """
    click callback: accept only environment names that already appear as a
    VPC ``Environment`` tag in the selected region.

    The comparison is case-sensitive; the accepted name is returned
    lower-cased.

    :raises click.BadParameter: if no VPC carries that tag. The message names
        the rejected value and lists the known environments.
    """
    known = list_environments(ctx.obj.aws_region)
    if value not in known:
        raise click.BadParameter('Invalid environment name: %s (known: %s)'
                                 % (value, ', '.join(sorted(known)) or 'none'))
    return str(value).lower()
def _prompt_before_creating_new_env(ctx, param, value):
    """
    click callback: ask for confirmation before provisioning an environment
    name that no VPC in the selected region is tagged with yet.

    Answering "no" aborts the command. Without ``abort=True`` the answer was
    ignored and provisioning went ahead anyway.

    :return: the environment name, lower-cased.
    """
    if value not in list_environments(ctx.obj.aws_region):
        click.confirm("Do you really want to create a new AWS environment called '%s'?" % value,
                      abort=True)
    return str(value).lower()
def _require_valid_queue_name(ctx, param, value):
    """
    click callback for the variadic ``queue_names`` argument.

    :param value: the names given on the command line (possibly empty).
    :return: a copy of QUEUE_NAMES when no name or ``all`` (any case) is
        given, otherwise the given names as a list.
    :raises click.BadParameter: on a name not in QUEUE_NAMES.
    """
    if not value:
        return list(QUEUE_NAMES)
    for name in value:
        if name.lower() == 'all':
            return list(QUEUE_NAMES)
        if name not in QUEUE_NAMES:
            # Interpolate here: BadParameter's second positional argument is
            # the click Context, not a format argument.
            raise click.BadParameter('Invalid queue_name value: %s' % name)
    return list(value)
@cli.command(**_default_click_kwargs)
@click.argument('environment_name', callback=_prompt_before_creating_new_env)
@click.pass_obj
def provision(config, environment_name):
    """
    Provision an AWS VPC environment
    """
    # Runs the playbook against localhost over a local connection.
    overrides = {
        'ENVIRONMENT_NAME': environment_name,
        'AWS_REGION': config.aws_region,
    }
    runner = Runner(
        hostnames='localhost',
        playbook='../aws_provision_environment.yml',
        connection='local',
        verbosity=config.verbosity,
        vault_pass=config.vault_pass,
        var_overrides=overrides,
    )
    return runner.run()
@cli.command(**_default_click_kwargs)
@click.argument('environment_name', callback=_require_valid_env_tag)
@click.pass_obj
def configure(config, environment_name):
    """
    Configure AWS instances
    """
    # Ask for the sudo password before anything else (including the vault prompt).
    become_password = click.prompt('Enter your sudo password', type=str, hide_input=True)
    runner = Runner(
        hostnames='localhost',
        playbook='../configure_environment.yml',
        connection='ssh',
        become_pass=become_password,
        vault_pass=config.vault_pass,
        verbosity=config.verbosity,
        var_overrides=dict(
            ENVIRONMENT_NAME=environment_name,
            AWS_REGION=config.aws_region,
            ansible_ssh_pipelining=True,
        ),
    )
    return runner.run()
@cli.command(**_default_click_kwargs)
@click.argument('environment_name', callback=_require_valid_env_tag)
# resolve_path=True turns the artifact path into an absolute path. A relative
# path would otherwise be interpreted by the playbook, which may not use the
# caller's working directory as its base -- the absolute path removes the doubt.
@click.argument('deployment_artifact', type=click.Path(exists=True, resolve_path=True))
@click.pass_obj
def deploy_python(config, environment_name, deployment_artifact):
    """
    Deploy Python services
    """
    sudo_pass = click.prompt('Enter your sudo password', type=str, hide_input=True)
    return Runner(
        hostnames='localhost',
        playbook='../deploy_python.yml',
        connection='ssh',
        vault_pass=config.vault_pass,
        become_pass=sudo_pass,
        verbosity=config.verbosity,
        var_overrides={
            'ansible_ssh_pipelining': True,
            'ENVIRONMENT_NAME': environment_name,
            'AWS_REGION': config.aws_region,
            'DEPLOYMENT_ARTIFACT': deployment_artifact,
        }).run()
@cli.command(**_default_click_kwargs)
@click.argument('environment_name', callback=_require_valid_env_tag)
@click.argument('hosts', default='all')
@click.pass_obj
def app_status(config, environment_name, hosts):
    """
    Show # running of each process type
    """
    # Loopback targets are reached with a local connection; anything else via SSH.
    targets_loopback = hosts.strip().lower() in ('localhost', '127.0.0.1')
    runner = Runner(
        hostnames=hosts,
        playbook='../app_status.yml',
        connection='local' if targets_loopback else 'ssh',
        verbosity=config.verbosity,
        vault_pass=config.vault_pass,
        var_overrides={
            'ansible_ssh_pipelining': True,
            'ENVIRONMENT_NAME': environment_name,
            'AWS_REGION': config.aws_region,
        },
    )
    return runner.run()
@cli.command(**_default_click_kwargs)
@click.argument('environment_name', callback=_require_valid_env_tag)
@click.argument('queue_names', nargs=-1, callback=_require_valid_queue_name)
@click.pass_obj
def populate_jobs(config, environment_name, queue_names):
    """
    (Re)populate beanstalkd job queue(s)
    """
    sudo_pass = click.prompt('Enter your sudo password', type=str, hide_input=True)
    target_host = 'app1.%s' % environment_name
    outcomes = []
    # One playbook run per queue, in the order the callback returned them.
    for queue in queue_names:
        click.secho('Populating "%s" queue' % queue, fg='green')
        runner = Runner(
            hostnames=target_host,
            playbook='../populate_queue.yml',
            connection='ssh',
            verbosity=config.verbosity,
            vault_pass=config.vault_pass,
            become_pass=sudo_pass,
            var_overrides={
                'ansible_ssh_pipelining': True,
                'ENVIRONMENT_NAME': environment_name,
                'AWS_REGION': config.aws_region,
                'queue_name': queue,
            },
        )
        outcomes.append(runner.run())
    return outcomes
@cli.command(**_default_click_kwargs)
@click.argument('environment_name', callback=_require_valid_env_tag)
@click.argument('action', type=click.Choice(['enable', 'disable']))
@click.pass_obj
def toggle_alarms(config, environment_name, action):
    """
    Enable/disable CloudWatch alarms
    """
    # The chosen action reaches the playbook as the `cw_alarm_action` variable.
    overrides = dict(
        ENVIRONMENT_NAME=environment_name,
        AWS_REGION=config.aws_region,
        cw_alarm_action=action,
    )
    runner = Runner(
        hostnames='localhost',
        playbook='../cloudwatch_alarms_toggle.yml',
        connection='local',
        verbosity=config.verbosity,
        vault_pass=config.vault_pass,
        var_overrides=overrides,
    )
    return runner.run()
def main():
    """Entry point when this file is executed as a script."""
    cli()


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment