Skip to content

Instantly share code, notes, and snippets.

@rrredbeard
Last active July 28, 2022 00:04
Show Gist options
  • Save rrredbeard/f41b0c609fb9b3b0739845f393416ed6 to your computer and use it in GitHub Desktop.
Save rrredbeard/f41b0c609fb9b3b0739845f393416ed6 to your computer and use it in GitHub Desktop.
Dynamic inventory script for Ansible with Vagrant
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Dynamic inventory script for Ansible with Vagrant.

Looks for machine directories under ``.vagrant/machines`` next to this
script, asks each machine (through ``vagrant ssh``) for an address matching
the ``BOX_IP_PREFIX`` environment variable (default ``192.168.56``), records
the VirtualBox private key Vagrant generated for it, and prints the result
as a JSON inventory with a ``_meta.hostvars`` section on stdout.
"""
__author__ = "Alessandro Venditti"
__version__ = "0.1.0"
__license__ = "MIT"
import argparse
import json
import os
import shlex
import subprocess
import sys
from symtable import Function
from typing import Callable, Dict, List, Optional
# Address prefix used to pick each VM's IP out of ``hostname -I``;
# override it through the BOX_IP_PREFIX environment variable.
BOX_IP_PREFIX = os.getenv("BOX_IP_PREFIX", "192.168.56")

# Fail fast on unsupported interpreters. A plain ``assert`` would be stripped
# when Python runs with -O, silently skipping the check.
if sys.version_info < (3, 9):
    raise SystemExit("This inventory script requires Python 3.9 or newer.")
# UTILS
def get_script_path():
    """Return the directory holding this script, with symlinks resolved."""
    resolved_file = os.path.realpath(__file__)
    return os.path.dirname(resolved_file)
def assert_dir_exist(d) -> None:
    """Raise ``AssertionError`` unless *d* is an existing directory.

    An explicit ``raise`` is used instead of an ``assert`` statement so the
    check still runs when Python is started with ``-O``.
    """
    if not os.path.isdir(d):
        raise AssertionError(f"Expected directory: {d}")
def assert_file_exist(f) -> None:
    """Raise ``AssertionError`` unless *f* is an existing regular file.

    An explicit ``raise`` is used instead of an ``assert`` statement so the
    check still runs when Python is started with ``-O``.
    """
    if not os.path.isfile(f):
        raise AssertionError(f"Expected file: {f}")
# OBJECTS
class Counter(object):
    """Mutable integer counter that starts at zero.

    Supports both a pre-increment read (``increment_and_get``) and a
    post-increment read (``get_and_increment``), each with an optional step.
    """

    def __init__(self) -> None:
        self.__num = 0

    def increment_and_get(self, delta: int = 1) -> int:
        """Add *delta* to the counter, then return the new value."""
        self.__num = self.__num + delta
        return self.__num

    def get_and_increment(self, delta: int = 1) -> int:
        """Return the current value, then add *delta* to the counter."""
        previous = self.__num
        self.increment_and_get(delta)
        return previous
class VagrantNodeTunnel(object):
    """Run shell commands inside one Vagrant machine through ``vagrant ssh``.

    Every :meth:`exec` call starts a fresh ``vagrant ssh <name>`` process in
    this script's directory and feeds the commands to it on stdin.
    """

    def __init__(self, name: str) -> None:
        """Bind the tunnel to the Vagrant machine called *name* (non-empty)."""
        assert name
        self.__name: str = name

    def __str__(self):
        return f"<tunnel-to[{self.__name}]>"

    def exec(self, *cmds) -> str:
        """Run *cmds* in a single ``vagrant ssh`` session and return its stdout.

        Each command is sent on its own line. The captured stdout is returned
        with surrounding whitespace stripped.

        Raises:
            OSError: when the session exits with a non-zero status, or when
                ``vagrant`` cannot be started. ``errno`` holds the exit status
                (or OS errno), ``strerror`` the captured stderr or a generic
                message, and ``filename`` the machine name.
        """
        assert len(cmds) > 0
        # An argv list plus ``cwd=`` replaces the old ``cd ...; vagrant ssh ...``
        # shell string, so paths or names with spaces/metacharacters are safe.
        # ``input=`` makes ``run`` write stdin while draining stdout/stderr,
        # which avoids pipe-buffer deadlocks. No per-instance process handle
        # is kept, so a failed call no longer leaves the tunnel unusable.
        try:
            proc = subprocess.run(
                ["vagrant", "ssh", self.__name],
                cwd=get_script_path(),
                input="".join(f"{cmd}\n" for cmd in cmds),
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                universal_newlines=True,
            )
        except FileNotFoundError as e:
            # Report a missing ``vagrant`` binary against this machine, like
            # any other ssh failure, so callers such as ``log`` show the node.
            raise OSError(e.errno, e.strerror, self.__name) from e

        # Any non-zero status is a failure. This includes negative values
        # (killed by a signal), which the previous ``> 0`` test let through.
        if proc.returncode != 0:
            err = proc.stderr.strip() or (
                f"'vagrant ssh {self.__name}' exited with status {proc.returncode}"
            )
            raise OSError(proc.returncode, err, self.__name)
        return proc.stdout.strip()

    def log(self, *cmds) -> None:
        """Run each command in its own session and print ``[name]: output``.

        Stops after the first failing command, printing a FAIL line for it.
        """
        assert len(cmds) > 0
        for cmd in cmds:
            try:
                out = self.exec(cmd)
            except OSError as e:
                print(f"[{e.filename}]: FAIL '{cmd}' -> ({e.errno}) {e.strerror}")
                return
            print(f"[{self.__name}]: {out}")
class VagrantContextReader(object):
    """Build Ansible host entries for the machines Vagrant tracks on disk.

    Every sub-directory of ``<root_dir>/.vagrant/machines`` is treated as a
    machine. Its IP is fetched over ``vagrant ssh``, and the VirtualBox
    private key Vagrant stored for it is used as the SSH key.
    """

    # Vagrant's per-machine state directory, relative to ``root_dir``.
    __vagrant_dir = ".vagrant/machines"
    __private_key_name = "private_key"
    __provider = "virtualbox"

    def __init__(
        self, root_dir: Optional[str] = None, vars: Optional[Dict] = None
    ) -> None:
        """Scan *root_dir* and read every machine found there.

        Args:
            root_dir: directory containing ``.vagrant``. ``None`` (the
                default) means the directory of this script.
            vars: extra host variables layered over the built-in SSH
                defaults. ``None`` means no extras. (The name shadows the
                ``vars`` builtin but is kept for caller compatibility.)

        Raises:
            AssertionError: if the machines directory, a machine directory
                or its private key file is missing.
            OSError: if querying a machine over ``vagrant ssh`` fails.
        """
        # Resolve the default at call time instead of at class-definition
        # time, and avoid a shared mutable ``{}`` default.
        if root_dir is None:
            root_dir = get_script_path()
        self.__hosts: List = []
        self.__machines_dir: str = os.path.join(root_dir, self.__vagrant_dir)
        self.__default_vars: Dict = {
            "ansible_user": "vagrant",
            "ansible_port": 22,
            "ansible_connection": "ssh",
        }
        if vars is not None:
            self.__default_vars.update(vars)
        self.__init_hosts()

    def get_hosts(self) -> List:
        """Return a shallow copy of the ``{"name", "vars"}`` host entries."""
        return self.__hosts.copy()

    def __read_machine_info(self, name: str) -> Dict:
        """Build the ``{"name": "<name>.local", "vars": {...}}`` entry for *name*."""
        machine_dir = os.path.join(self.__machines_dir, name)
        assert_dir_exist(machine_dir)
        p_key_file = os.path.join(machine_dir, self.__provider, self.__private_key_name)
        # Validate the key before paying for an ssh round trip.
        assert_file_exist(p_key_file)

        # List every address from ``hostname -I`` one per line, and keep the
        # first one containing BOX_IP_PREFIX. ``-F`` makes the dots literal,
        # and the prefix is shell-quoted. If nothing matches, grep exits 1,
        # so ``exec`` raises OSError.
        node_ip = VagrantNodeTunnel(name).exec(
            "hostname -I | tr ' ' '\\n' | grep -F -m 1 -- "
            + shlex.quote(BOX_IP_PREFIX)
        )

        host_vars = self.__default_vars.copy()
        host_vars.update(
            {"ansible_host": node_ip, "ansible_ssh_private_key_file": p_key_file}
        )
        return {
            "name": f"{name}.local",
            "vars": host_vars,
        }

    def __init_hosts(self) -> None:
        """Populate the host list from the machines directory, in name order."""
        assert_dir_exist(self.__machines_dir)
        # os.listdir order is arbitrary. Sorting keeps host order stable, and
        # with it any order-based choice made later (e.g. role assignment).
        for entry in sorted(os.listdir(self.__machines_dir)):
            if os.path.isdir(os.path.join(self.__machines_dir, entry)):
                self.__hosts.append(self.__read_machine_info(entry))
class SimpleInventory(object):
    """In-memory Ansible dynamic inventory: ``_meta.hostvars`` plus groups.

    Every imported host is also placed in a ``vagrant`` group.
    """

    @classmethod
    def from_current(cls):
        """Build an inventory from the Vagrant machines next to this script."""
        return cls(hosts=VagrantContextReader().get_hosts())

    def __init__(self, hosts: List) -> None:
        """Import *hosts* and create the catch-all ``vagrant`` group.

        Args:
            hosts: non-empty list of ``{"name": str, "vars": dict}`` entries.
                ``vars`` is optional and is copied, not shared.
        """
        self.__inventory = {"_meta": {"hostvars": {}}}
        self.__metadata = self.__inventory["_meta"]
        self.__import_hosts(hosts, self.__metadata)
        self.add_group("vagrant", lambda _: True)

    def print(self, fn: Callable[[Dict], object] = (lambda i: i)) -> None:
        """Print ``fn(inventory)``. *fn* receives the live inventory dict."""
        print(fn(self.__inventory))

    def update_vars(
        self, fn: Callable[[str, Dict], Optional[Dict]] = (lambda h, v: {})
    ) -> "SimpleInventory":
        """Merge ``fn(host, vars)`` into each host's vars, in host order.

        A falsy result from *fn* leaves that host unchanged. Returns ``self``
        so calls can be chained.
        """
        hosts = self.__metadata["hostvars"]
        for host, value in hosts.items():
            updates = fn(host, value)
            if updates:
                value.update(updates)
        return self

    def add_group(
        self, name: str, fn: Callable[[str], bool] = (lambda h: False)
    ) -> "SimpleInventory":
        """Add group *name* with every host for which ``fn(host)`` is true.

        The group is omitted when no host matches. Returns ``self``.
        """
        assert name
        assert name != "_meta"
        host_group = [host for host in self.__metadata["hostvars"] if fn(host)]
        if host_group:
            self.__inventory[name] = host_group
        return self

    def __import_hosts(self, host_list, group) -> None:
        """Copy each host's vars into ``group["hostvars"]``, keyed by name."""
        assert len(host_list) > 0
        for host in host_list:
            assert "name" in host
            group["hostvars"][host["name"]] = host.get("vars", {}).copy()
# MAIN
def main(argv=None):
    """Print the Vagrant inventory as JSON, following Ansible's script protocol.

    ``--list`` (the default when no option is given) prints the whole
    inventory. ``--host NAME`` prints only that host's variables, or ``{}``
    for an unknown host. In inventory order, every fifth host (the 1st, 6th,
    ...) gets ``role: server`` and the rest get ``role: client``.

    Args:
        argv: argument list to parse. ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description="Ansible dynamic inventory for Vagrant machines."
    )
    mode = parser.add_mutually_exclusive_group()
    mode.add_argument(
        "--list", action="store_true", help="print the full inventory (default)"
    )
    mode.add_argument("--host", metavar="NAME", help="print one host's variables")
    args = parser.parse_args(argv)

    counter = Counter()
    ctx = VagrantContextReader(
        root_dir=get_script_path(),
        vars={
            "ansible_user": "vagrant",
            "ansible_port": 22,
            "ansible_connection": "ssh",
            "ansible_become": True,
        },
    )
    inventory = SimpleInventory(hosts=ctx.get_hosts()).update_vars(
        lambda h, _: {
            "role": "server" if counter.get_and_increment() % 5 == 0 else "client"
        }
    )

    def to_json(obj):
        return json.dumps(obj=obj, indent=4, sort_keys=True)

    if args.host is not None:
        return inventory.print(
            lambda i: to_json(i["_meta"]["hostvars"].get(args.host, {}))
        )
    return inventory.print(to_json)
if __name__ == "__main__":
    # Runs only when the file is executed directly, not when it is imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment