Skip to content

Instantly share code, notes, and snippets.

@pseeth
Last active July 27, 2020 22:02
Show Gist options
  • Save pseeth/d34168daa1b757e3cc39073a3394a051 to your computer and use it in GitHub Desktop.
Save pseeth/d34168daa1b757e3cc39073a3394a051 to your computer and use it in GitHub Desktop.
A simple way to access function keyword arguments from the command line. Handy for command line configuration of experiments.
# Example of binding function keyword arguments to an
# ArgumentParser, then using scopes to initialize function
# with different arguments.
#
# Example output:
# ❯ python example.py --autoclip.percentile=100 --train/autoclip.percentile=1 --val/autoclip.percentile=5
# Before scoping
# ARGS={}
# Called autoclip with percentile=10.0
#
# Scoping
# ARGS={'autoclip.percentile': 100.0}
# Called autoclip with percentile=100.0
#
# Scoping train
# ARGS={'autoclip.percentile': 1.0}
# Called autoclip with percentile=1.0
#
# Scoping val
# ARGS={'autoclip.percentile': 5.0}
# Called autoclip with percentile=5.0
#
# Back to defaults
# ARGS={}
# Called autoclip with percentile=10.0
#
# BUILDING BLOCKS TO MAKE THIS WORK
# ---------------------------------
# Need a context manager and a function decorator.
#
import argparse
import functools
import inspect
from contextlib import contextmanager
# Registry of (function, patterns) pairs recorded by bind_to_parser;
# build_parser walks it to generate the command line flags.
PARSE_FUNCS = []
# Arguments currently in effect, keyed by 'function_name.kwarg'.
# Replaced by the scope context manager and read by bound functions.
ARGS = {}
@contextmanager
def scope(parsed_args, pattern=''):
    """
    Context manager that installs parsed arguments as the global
    ARGS for the duration of a ``with`` block.

    Keys without a '/' are always kept. A key of the form
    'prefix/name' is kept only when its first '/'-separated segment
    equals ``pattern``; it is then stored under its last segment and
    overrides any unscoped entry with that name. All other scoped
    keys are dropped.

    Parameters
    ----------
    parsed_args : dict
        Mapping of argument name to value, e.g.
        ``vars(parser.parse_args())``. It is not modified.
    pattern : str, optional
        Name of the scope to activate, by default '' (which matches
        no 'prefix/name' key, so only unscoped keys are used).

    Yields
    ------
    None
        ARGS holds the scoped arguments while the block runs.
    """
    global ARGS

    scoped = {}
    overrides = {}
    for key, value in parsed_args.items():
        if '/' not in key:
            scoped[key] = value
            continue
        segments = key.split('/')
        if segments[0] == pattern:
            overrides[segments[-1]] = value
    # Applied last so a scoped value wins over the unscoped one.
    scoped.update(overrides)

    old_args = ARGS
    ARGS = scoped
    try:
        yield
    finally:
        # Restore even when the with-body raises; otherwise the scoped
        # arguments would stay installed for every later call.
        ARGS = old_args
def bind_to_parser(*patterns):
    """
    Build a decorator that lets a function take its keyword
    arguments from ARGS (managed by the scope context manager).

    The decorated function is recorded in PARSE_FUNCS together with
    ``patterns``. On every call, the returned wrapper looks in ARGS
    for entries named '<function name>.<parameter>' and passes the
    ones it finds as keyword arguments, replacing any value the
    caller supplied for the same keyword. Only parameters that have
    a default value are looked up.

    Parameters
    ----------
    *patterns : str
        Scope names stored alongside the function in PARSE_FUNCS.

    Returns
    -------
    callable
        Decorator that returns the wrapped function.
    """
    def decorator(func):
        PARSE_FUNCS.append((func, patterns))

        # The signature does not change between calls, so work out
        # once which parameters are bindable and the ARGS key each
        # of them is looked up under.
        prefix = func.__name__
        bindable = [
            (key, f'{prefix}.{key}')
            for key, val in inspect.signature(func).parameters.items()
            if val.default is not inspect.Parameter.empty
        ]

        # wraps keeps the original name and docstring on the wrapper.
        @functools.wraps(func)
        def cmd_func(*args, **kwargs):
            # ARGS is read at call time so the active scope decides
            # which values are used.
            for key, arg_name in bindable:
                if arg_name in ARGS:
                    kwargs[key] = ARGS[arg_name]
            return func(*args, **kwargs)
        return cmd_func
    return decorator
def build_parser():
    """
    Build an ArgumentParser with flags for every function in
    PARSE_FUNCS.

    Each parameter that has a default value gets a flag
    '--<function name>.<parameter>' whose default is the function's
    default, plus one flag '--<pattern>/<function name>.<parameter>'
    per pattern the function was registered with. Parameters without
    a default are not exposed.

    Values are converted with the parameter's annotation. When a
    parameter is not annotated, the type of its default value is
    used instead (or the raw string is kept when the default is
    None).

    Returns
    -------
    argparse.ArgumentParser
        Parser holding the generated flags in a single argument
        group.
    """
    p = argparse.ArgumentParser()
    group = p.add_argument_group(
        title="Generated arguments from functions",
        description='Specify arguments to functions.'
    )
    # Add kwargs from function to parser
    for func, patterns in PARSE_FUNCS:
        prefix = func.__name__
        for key, val in inspect.signature(func).parameters.items():
            default = val.default
            if default is inspect.Parameter.empty:
                continue

            arg_type = val.annotation
            if arg_type is inspect.Parameter.empty:
                # The 'empty' sentinel is not a usable converter and
                # would make argparse reject every value given.
                arg_type = str if default is None else type(default)

            group.add_argument(
                f'--{prefix}.{key}', type=arg_type, default=default
            )
            for pattern in patterns:
                # A scoped flag that was not given must not appear in
                # the parsed arguments; otherwise the function default
                # would override the unscoped value within the scope.
                group.add_argument(
                    f'--{pattern}/{prefix}.{key}', type=arg_type,
                    default=argparse.SUPPRESS
                )
    return p
# FUNCTIONS YOU WANT TO ACCESS FROM COMMAND LINE
# ----------------------------------------------
# Decorate the function with bind_to_parser,
# adding it to PARSE_FUNCS. The argument
# parser inspects each function in PARSE_FUNCS
# and adds it to the argument flags. This
# function's arguments are available at:
#
# python example.py --autoclip.percentile=N
#
# The function arguments must be annotated with
# their type. Only keyword arguments are included
# in the ArgumentParser.
#
# You can optionally define additional patterns to match
# for different scopes. This will use the arguments
# given on that pattern when the scope is set to that
# pattern. The argument is available on command line at
# --pattern/func.kwarg. For example:
#
# python example.py \
# --autoclip.percentile=100 \
# --train/autoclip.percentile=1 \
# --val/autoclip.percentile=5
#
# With the corresponding code:
#
# >>> with scope(args):
# >>> autoclip() # prints 100
# >>> with scope(args, 'train'):
# >>> autoclip() # prints 1
# >>> with scope(args, 'val'):
# >>> autoclip() # prints 5
#
@bind_to_parser('train', 'val')
def autoclip(percentile: float = 10.0):
    """
    Demo function that reports the percentile it was called with.

    Parameters
    ----------
    percentile : float, optional
        Value to report, by default 10.0
    """
    message = f'Called autoclip with percentile={percentile}'
    print(message)
if __name__ == '__main__':
    # Parsed flags as a plain dict keyed by flag name.
    args = vars(build_parser().parse_args())

    def show_call():
        # Print the arguments currently in effect, then call the
        # bound function so it reports the value it received.
        print(f'ARGS={ARGS}')
        autoclip()

    # No scope is active yet, so autoclip uses its default value.
    print("Before scoping")
    show_call()

    # Inside each scope autoclip uses the value parsed for it.
    for scope_name in ('', 'train', 'val'):
        print(f"\nScoping {scope_name}")
        with scope(args, scope_name):
            show_call()

    # Outside of any scope the initial ARGS are in effect again.
    print("\nBack to defaults")
    show_call()
@pseeth
Copy link
Author

pseeth commented Jul 27, 2020

And a quick example, with the above functions imported from utils.py.

from .utils import bind_to_parser, parse_args, scope
import nussl

@bind_to_parser()
def signal(
    window_length: int = 256,
    hop_length: int = 64,
    window_type: str = 'sqrt_hann',
    sample_rate: int = 8000,
):
    """
    Bundle the global audio settings: STFT parameters plus the
    sample rate.

    Parameters
    ----------
    window_length : int, optional
        STFT window length, by default 256
    hop_length : int, optional
        STFT hop length, by default 64
    window_type : str, optional
        STFT window type, by default 'sqrt_hann'
    sample_rate : int, optional
        Sampling rate, by default 8000

    Returns
    -------
    tuple
        ``(nussl.STFTParams, sample_rate)``.
    """
    stft_params = nussl.STFTParams(window_length, hop_length, window_type)
    return stft_params, sample_rate

@bind_to_parser()
def transform(
    stft_params: nussl.STFTParams,
    sample_rate: int,
    excerpt_length: float = 4.0,
):
    """
    Build the transforms applied to the training and validation
    datasets.

    This is a stub: it currently has no body and returns None.

    Parameters
    ----------
    stft_params : nussl.STFTParams
        STFT parameters (see: signal).
    sample_rate : int
        Sample rate of the audio signal (see: signal).
    excerpt_length : float, optional
        Excerpt length in seconds, by default 4.0
    """
    return None

if __name__ == "__main__":
    # Activate the unscoped command line arguments, then build and
    # show the signal settings they produce.
    args = parse_args()
    with scope(args):
        settings = signal()
        print(settings)

Generated CLI interface:

❯ python -m src.exp -h
usage: exp.py [-h] [--signal.window_length SIGNAL.WINDOW_LENGTH] [--signal.hop_length SIGNAL.HOP_LENGTH] [--signal.window_type SIGNAL.WINDOW_TYPE]
              [--signal.sample_rate SIGNAL.SAMPLE_RATE] [--transform.excerpt_length TRANSFORM.EXCERPT_LENGTH]

optional arguments:
  -h, --help            show this help message and exit

Generated arguments for function signal:

      Defines global AudioSignal parameters and
      builds STFTParams object.

      Parameters
      ----------
      window_length : int, optional
          Window length of STFT, by default 256
      hop_length : int, optional
          Hop length of STFT, by default 64
      window_type : str, optional
          Window type of STFT., by default 'sqrt_hann'
      sample_rate : int, optional
          Sampling rate, by default 8000

      Returns
      -------
      tuple
          Tuple of nussl.STFTParams and sample_rate.


  --signal.window_length SIGNAL.WINDOW_LENGTH
  --signal.hop_length SIGNAL.HOP_LENGTH
  --signal.window_type SIGNAL.WINDOW_TYPE
  --signal.sample_rate SIGNAL.SAMPLE_RATE

Generated arguments for function transform:

      Builds transforms that get applied to
      training and validation datasets.

      Parameters
      ----------
      stft_params : nussl.STFTParams
          Parameters of STFT (see: signal).
      sample_rate : int
          Sample rate of audio signal (see: signal).
      excerpt_length : float, optional
          Length of excerpt in seconds, by default 4.0


  --transform.excerpt_length TRANSFORM.EXCERPT_LENGTH

Usage:

❯ python -m src.exp --signal.window_length=512
(STFTParams(window_length=512, hop_length=64, window_type='sqrt_hann'), 8000)
❯ python -m src.exp # uses defaults
(STFTParams(window_length=256, hop_length=64, window_type='sqrt_hann'), 8000)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment