@pseeth
Last active July 27, 2020 22:02
A simple way to access function keyword arguments from the command line. Handy for command line configuration of experiments.
# Example of binding function keyword arguments to an
# ArgumentParser, then using scopes to initialize function
# with different arguments.
#
# Example output:
# ❯ python example.py --autoclip.percentile=100 --train/autoclip.percentile=1 --val/autoclip.percentile=5
# Before scoping
# ARGS={}
# Called autoclip with percentile=10.0
#
# Scoping
# ARGS={'autoclip.percentile': 100.0}
# Called autoclip with percentile=100.0
#
# Scoping train
# ARGS={'autoclip.percentile': 1.0}
# Called autoclip with percentile=1.0
#
# Scoping val
# ARGS={'autoclip.percentile': 5.0}
# Called autoclip with percentile=5.0
#
# Back to defaults
# ARGS={}
# Called autoclip with percentile=10.0
#
# BUILDING BLOCKS TO MAKE THIS WORK
# ---------------------------------
# Need a context manager and a function decorator.
#
import inspect
from contextlib import contextmanager
import argparse
PARSE_FUNCS = []
ARGS = {}
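@contextmanager
def scope(parsed_args, pattern=''):
    """
    Context manager that makes parsed arguments
    visible to bound functions through the global
    ARGS. A key of the form 'pattern/func.kwarg'
    replaces 'func.kwarg' when its prefix equals
    `pattern`; other scoped keys are dropped.
    """
    parsed_args = parsed_args.copy()
    remove_keys = []
    matched = {}
    global ARGS
    old_args = ARGS
    for key in parsed_args:
        if '/' in key:
            if key.split('/')[0] == pattern:
                matched[key.split('/')[-1]] = parsed_args[key]
            remove_keys.append(key)
    parsed_args.update(matched)
    for key in remove_keys:
        parsed_args.pop(key)
    ARGS = parsed_args
    try:
        yield
    finally:
        # Restore the previous arguments even if the block raises.
        ARGS = old_args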
def bind_to_parser(*patterns):
    """
    Wrap the function so it looks in ARGS (managed
    by the scope context manager) for keyword
    arguments.
    """
    def decorator(func):
        PARSE_FUNCS.append((func, patterns))
        def cmd_func(*args, **kwargs):
            prefix = func.__name__
            sig = inspect.signature(func)
            cmd_kwargs = {}
            for key, val in sig.parameters.items():
                arg_type = val.annotation
                arg_val = val.default
                if arg_val is not inspect.Parameter.empty:
                    arg_name = f'{prefix}.{key}'
                    if arg_name in ARGS:
                        cmd_kwargs[key] = ARGS[arg_name]
            kwargs.update(cmd_kwargs)
            return func(*args, **kwargs)
        return cmd_func
    return decorator
def build_parser():
    p = argparse.ArgumentParser()
    f = p.add_argument_group(
        title="Generated arguments from functions",
        description='Specify arguments to functions.'
    )
    # Add kwargs from function to parser
    for func, patterns in PARSE_FUNCS:
        prefix = func.__name__
        sig = inspect.signature(func)
        for key, val in sig.parameters.items():
            arg_type = val.annotation
            arg_val = val.default
            if arg_val is not inspect.Parameter.empty:
                arg_names = []
                arg_names.append(f'--{prefix}.{key}')
                for pattern in patterns:
                    arg_names.append(f'--{pattern}/{prefix}.{key}')
                for arg_name in arg_names:
                    f.add_argument(arg_name, type=arg_type,
                        default=arg_val)
    return p
# FUNCTIONS YOU WANT TO ACCESS FROM COMMAND LINE
# ----------------------------------------------
# Decorate the function with bind_to_parser,
# adding it to PARSE_FUNCS. The argument
# parser inspects each function in PARSE_FUNCS
# and adds its keyword arguments as flags. This
# function's arguments are available at:
#
# python example.py --autoclip.percentile=N
#
# The function arguments must be annotated with
# their type. Only keyword arguments are included
# in the ArgumentParser.
#
# You can optionally define additional patterns to match
# for different scopes. This will use the arguments
# given on that pattern when the scope is set to that
# pattern. The argument is available on command line at
# --pattern/func.kwarg. For example:
#
# python example.py \
#     --autoclip.percentile=100 \
#     --train/autoclip.percentile=1 \
#     --val/autoclip.percentile=5
#
# With the corresponding code:
#
# >>> with scope(args):
# >>>     autoclip() # prints 100
# >>> with scope(args, 'train'):
# >>>     autoclip() # prints 1
# >>> with scope(args, 'val'):
# >>>     autoclip() # prints 5
#
@bind_to_parser('train', 'val')
def autoclip(percentile : float = 10.0):
    print(f'Called autoclip with percentile={percentile}')

if __name__ == '__main__':
    parser = build_parser()
    args = vars(parser.parse_args())
    # Uses default value
    print("Before scoping")
    print(f'ARGS={ARGS}')
    autoclip()
    # Uses value from parsed arguments
    for _scope in ['', 'train', 'val']:
        print(f"\nScoping {_scope}")
        with scope(args, _scope):
            print(f'ARGS={ARGS}')
            autoclip()
    print("\nBack to defaults")
    print(f'ARGS={ARGS}')
    autoclip()

pseeth commented Jul 27, 2020

Easy-to-use file to include in projects as needed, with
a few more bells and whistles (mainly reading in
docstrings as help text and grouping arguments based
on functions).

"""
Utilities for binding function arguments to 
the command line under different scopes as needed.
Read on to see how to bind a function.

HOW TO BIND A FUNCTION
----------------------
Decorate the function with bind_to_parser,
adding it to PARSE_FUNCS. The argument
parser inspects each function in PARSE_FUNCS
and adds its keyword arguments as flags. For example:

>>> @bind_to_parser('train', 'val')
... def autoclip(percentile : float = 10.0):
...     print(f'Called autoclip with percentile={percentile}')

This function's arguments are available at:

  python example.py --autoclip.percentile=N

The function arguments must be annotated with
their type. Only keyword arguments are included
in the ArgumentParser.

You can optionally define additional patterns to match
for different scopes. This will use the arguments
given on that pattern when the scope is set to that
pattern. The argument is available on command line at
--pattern/func.kwarg. The patterns used were 'train'
and 'val' so the additional arguments are also
available for binding:

  python example.py \
      --autoclip.percentile=100 \
      --train/autoclip.percentile=1 \
      --val/autoclip.percentile=5

Use with the corresponding code:
  >>> # above this, parse the args
  >>> with scope(args):
  >>>     autoclip() # prints 100
  >>> with scope(args, 'train'):
  >>>     autoclip() # prints 1
  >>> with scope(args, 'val'):
  >>>     autoclip() # prints 5
"""

import inspect
from contextlib import contextmanager
import argparse

PARSE_FUNCS = []
ARGS = {}

@contextmanager
def scope(parsed_args, pattern=''):
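    """
    Context manager that makes parsed arguments
    visible to bound functions through the global
    ARGS. A key of the form 'pattern/func.kwarg'
    replaces 'func.kwarg' when its prefix equals
    `pattern`; other scoped keys are dropped.
    """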
    parsed_args = parsed_args.copy()
    remove_keys = []
    matched = {}

    global ARGS
    old_args = ARGS
    for key in parsed_args:
        if '/' in key:
            if key.split('/')[0] == pattern:
                matched[key.split('/')[-1]] = parsed_args[key]
            remove_keys.append(key)
    
    parsed_args.update(matched)
    for key in remove_keys:
        parsed_args.pop(key)
    ARGS = parsed_args
    try:
        yield
    finally:
        # Restore the previous arguments even if the block raises.
        ARGS = old_args

def bind_to_parser(*patterns):
    """
    Wrap the function so it looks in ARGS (managed 
    by the scope context manager) for keyword 
    arguments.
    """

    def decorator(func):
        PARSE_FUNCS.append((func, patterns))
        def cmd_func(*args, **kwargs):
            prefix = func.__name__
            sig = inspect.signature(func)
            cmd_kwargs = {}

            for key, val in sig.parameters.items():
                arg_type = val.annotation
                arg_val = val.default
                if arg_val is not inspect.Parameter.empty:
                    arg_name = f'{prefix}.{key}'
                    if arg_name in ARGS:
                        cmd_kwargs[key] = ARGS[arg_name]

            kwargs.update(cmd_kwargs)
            return func(*args, **kwargs)
        return cmd_func
    
    return decorator

def parse_args():
    """
    Goes through all detected functions that are
    bound and adds them to the argument parser,
    along with their scopes. Then parses the
    command line and returns a dictionary.
    """
    p = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter
    )

    # Add kwargs from function to parser
    for func, patterns in PARSE_FUNCS:
        prefix = func.__name__

        # A function without a docstring has __doc__ set to None.
        desc = func.__doc__ or ''

        if patterns:
            desc += (
                f"""
    Additional scopes
    -----------------
    {', '.join(list(patterns))}
            """
            )

        f = p.add_argument_group(
            title=f"Generated arguments for function {prefix}",
            description=desc
        )

        sig = inspect.signature(func)
        for key, val in sig.parameters.items():
            arg_type = val.annotation
            arg_val = val.default

            if arg_val is not inspect.Parameter.empty:
                arg_names = []
                arg_names.append(f'--{prefix}.{key}')
                for pattern in patterns:
                    arg_names.append(f'--{pattern}/{prefix}.{key}')
                for arg_name in arg_names:
                    f.add_argument(arg_name, type=arg_type, 
                        default=arg_val)
    return vars(p.parse_args())
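
The key resolution that scope performs can be looked at in isolation. The sketch below pulls it into a pure function so the precedence rules are easy to check: a key prefixed with the active pattern replaces the unscoped key, and keys for other patterns are dropped. The name resolve_scope is a hypothetical helper used only for illustration; it is not part of utils.py.

```python
def resolve_scope(parsed_args, pattern=''):
    """Return the arguments visible inside scope(parsed_args, pattern).

    Hypothetical helper for illustration: it mirrors the loop in scope
    but returns a new dict instead of assigning to the global ARGS.
    """
    # Start from the unscoped keys only ('func.kwarg').
    resolved = {k: v for k, v in parsed_args.items() if '/' not in k}
    for key, value in parsed_args.items():
        if '/' in key and key.split('/')[0] == pattern:
            # 'train/autoclip.percentile' -> 'autoclip.percentile'
            resolved[key.split('/')[-1]] = value
    return resolved


parsed = {
    'autoclip.percentile': 100.0,
    'train/autoclip.percentile': 1.0,
    'val/autoclip.percentile': 5.0,
}
print(resolve_scope(parsed))           # {'autoclip.percentile': 100.0}
print(resolve_scope(parsed, 'train'))  # {'autoclip.percentile': 1.0}
print(resolve_scope(parsed, 'val'))    # {'autoclip.percentile': 5.0}
```

One consequence worth knowing: parse_args registers every scoped flag with the function default, so a scoped key is always present in the parsed dictionary. Inside scope(args, 'train') the train value therefore wins even when only the unscoped flag was given on the command line.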


pseeth commented Jul 27, 2020

And a quick example, with the above functions imported from utils.py.

from .utils import bind_to_parser, parse_args, scope
import nussl

@bind_to_parser()
def signal(
    window_length : int = 256,
    hop_length : int = 64,
    window_type : str = 'sqrt_hann',
    sample_rate: int = 8000
):
    """
    Defines global AudioSignal parameters and
    builds STFTParams object.

    Parameters
    ----------
    window_length : int, optional
        Window length of STFT, by default 256
    hop_length : int, optional
        Hop length of STFT, by default 64
    window_type : str, optional
        Window type of STFT., by default 'sqrt_hann'
    sample_rate : int, optional
        Sampling rate, by default 8000

    Returns
    -------
    tuple
        Tuple of nussl.STFTParams and sample_rate.
    """
    return (
        nussl.STFTParams(window_length, hop_length, window_type), 
        sample_rate
    )

@bind_to_parser()
def transform(
    stft_params : nussl.STFTParams, 
    sample_rate : int,
    excerpt_length : float = 4.0
):
    """
    Builds transforms that get applied to
    training and validation datasets.

    Parameters
    ----------
    stft_params : nussl.STFTParams
        Parameters of STFT (see: signal).
    sample_rate : int
        Sample rate of audio signal (see: signal).
    excerpt_length : float, optional
        Length of excerpt in seconds, by default 4.0
    """

    pass

if __name__ == "__main__":
    args = parse_args()
    with scope(args):
        print(signal())

Generated CLI interface:

❯ python -m src.exp -h
usage: exp.py [-h] [--signal.window_length SIGNAL.WINDOW_LENGTH] [--signal.hop_length SIGNAL.HOP_LENGTH] [--signal.window_type SIGNAL.WINDOW_TYPE]
              [--signal.sample_rate SIGNAL.SAMPLE_RATE] [--transform.excerpt_length TRANSFORM.EXCERPT_LENGTH]

optional arguments:
  -h, --help            show this help message and exit

Generated arguments for function signal:

      Defines global AudioSignal parameters and
      builds STFTParams object.

      Parameters
      ----------
      window_length : int, optional
          Window length of STFT, by default 256
      hop_length : int, optional
          Hop length of STFT, by default 64
      window_type : str, optional
          Window type of STFT., by default 'sqrt_hann'
      sample_rate : int, optional
          Sampling rate, by default 8000

      Returns
      -------
      tuple
          Tuple of nussl.STFTParams and sample_rate.


  --signal.window_length SIGNAL.WINDOW_LENGTH
  --signal.hop_length SIGNAL.HOP_LENGTH
  --signal.window_type SIGNAL.WINDOW_TYPE
  --signal.sample_rate SIGNAL.SAMPLE_RATE

Generated arguments for function transform:

      Builds transforms that get applied to
      training and validation datasets.

      Parameters
      ----------
      stft_params : nussl.STFTParams
          Parameters of STFT (see: signal).
      sample_rate : int
          Sample rate of audio signal (see: signal).
      excerpt_length : float, optional
          Length of excerpt in seconds, by default 4.0


  --transform.excerpt_length TRANSFORM.EXCERPT_LENGTH
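
The help text above lists only excerpt_length for transform because parse_args skips parameters that have no default. A minimal sketch of that filter follows; bindable_params and transform_like are hypothetical names, and the stand-in function drops the nussl annotation so the snippet runs without nussl installed.

```python
import inspect


def bindable_params(func):
    """Map each defaulted parameter of func to (annotation, default)."""
    found = {}
    for name, param in inspect.signature(func).parameters.items():
        # Parameters without a default are left for the caller to pass.
        if param.default is not inspect.Parameter.empty:
            found[name] = (param.annotation, param.default)
    return found


# Stand-in with the same shape as transform in the example above.
def transform_like(stft_params, sample_rate: int, excerpt_length: float = 4.0):
    pass


print(bindable_params(transform_like))
# {'excerpt_length': (<class 'float'>, 4.0)}
```

The annotation is what gets passed to argparse as type, so a defaulted parameter without an annotation would hand argparse inspect.Parameter.empty instead of a converter. That is the reason the bound arguments need type annotations.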

Usage:

❯ python -m src.exp --signal.window_length=512
STFTParams(window_length=512, hop_length=64, window_type='sqrt_hann')
❯ python -m src.exp # uses defaults
STFTParams(window_length=256, hop_length=64, window_type='sqrt_hann')
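
Dotted and slashed flag names work with vars() because argparse derives the dictionary key from the long option name and only replaces '-' with '_'; dots and slashes are kept. A small sketch, using the flag names from the autoclip example and an explicit argument list rather than sys.argv:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--autoclip.percentile', type=float, default=10.0)
parser.add_argument('--train/autoclip.percentile', type=float, default=10.0)

# Pass an explicit list instead of reading sys.argv.
parsed = vars(parser.parse_args(['--autoclip.percentile=100']))
print(parsed)
# {'autoclip.percentile': 100.0, 'train/autoclip.percentile': 10.0}
```

These keys are not valid Python identifiers, so they are only reachable through the dictionary (or getattr), which is why the code converts the namespace with vars() before handing it to scope.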
