Skip to content

Instantly share code, notes, and snippets.

@peter
Last active December 21, 2017 15:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save peter/503d47f239d592099d520cfb3331e6b0 to your computer and use it in GitHub Desktop.
Save peter/503d47f239d592099d520cfb3331e6b0 to your computer and use it in GitHub Desktop.
Python runtime type checks with decorator

Python runtime type checks with decorator

# This module provides pre/post conditions for Python functions via the @typeSpec decorator

from functools import reduce

def compact(d):
    """Return a copy of dict ``d`` without the entries whose value is None.

    Falsy values other than None (0, '', [], False) are kept.
    """
    # Identity test: ``!= None`` would also drop objects whose __eq__ claims
    # equality with None.
    return {k: v for k, v in d.items() if v is not None}

def TypeOf(example):
    """Build a checker requiring values to share the type of ``example``.

    The test is ``isinstance``, so subclass instances are accepted. The
    returned ``check(v)`` gives None on success, otherwise
    'must be of type <type>'.
    """
    wanted = type(example)

    def check(v):
        if isinstance(v, wanted):
            return None
        return 'must be of type {}'.format(wanted)

    return check

def Maybe(required_type):
    """Build a checker that accepts None or a value satisfying ``required_type``.

    ``required_type`` may be anything ``type_check`` accepts: a class, a
    checker/predicate callable, a list/dict spec, or an example value.

    Returns:
        A ``check(v)`` callable returning None when ``v`` is None or conforms,
        otherwise the error produced by ``type_check``.
    """
    def check(v):
        if v is None:
            return None
        # Delegate to type_check so a class such as ``int`` gets an isinstance
        # check instead of being called as a constructor (``int(v)``).
        return type_check(v, required_type)
    return check

def Enum(*allowed_values):
    """Build a checker accepting only the listed values.

    Returns:
        A ``check(v)`` callable returning None when ``v`` equals one of
        ``allowed_values``, otherwise 'must be one of <v1>, <v2>, ...'.
    """
    # str() each choice so non-string values (ints, None, ...) can be listed;
    # the message never changes, so build it once.
    choices = ', '.join(map(str, allowed_values))

    def check(v):
        # Tuple membership (==) rather than a set: an unhashable ``v`` then
        # yields the error message instead of raising TypeError.
        if v not in allowed_values:
            return 'must be one of {}'.format(choices)
        return None

    return check

def NoneCheck(v):
    """Require ``v`` to be ``None``.

    Returns None when it is, otherwise the error string 'must be None'.
    """
    # Identity, not ==: an object whose __eq__ always returns True is not None.
    if v is None:
        return None
    return 'must be None'

def List(list_spec):
    """Build a checker for list values described by ``list_spec``.

    ``list_spec`` forms:
      * ``[]`` - any list is accepted;
      * ``[spec]`` - every item must satisfy ``spec``;
      * ``[spec0, spec1, ...]`` - tuple-like: the list must have exactly that
        length and item ``i`` must satisfy ``spec_i``.

    Returns:
        A ``check(v)`` callable returning None on success, otherwise a message
        naming the first failing index (or the wrong length / non-list value).
    """
    def check(v):
        if not isinstance(v, list):
            return 'must be list'
        spec_len = len(list_spec)
        if spec_len == 0:
            return None
        if spec_len == 1:
            pairs = ((item, list_spec[0]) for item in v)
        else:
            if spec_len != len(v):
                return 'list has length {} but should be {}'.format(len(v), spec_len)
            pairs = zip(v, list_spec)
        for idx, (item, item_spec) in enumerate(pairs):
            problem = type_check(item, item_spec)
            if problem:
                return 'list index {} is invalid - {}'.format(idx, problem)
        return None
    return check

def Dict(dict_spec):
    """Build a checker for dict values described by ``dict_spec``.

    Every key in ``dict_spec`` is checked against its spec; a missing key is
    checked as None (so a ``Maybe(...)`` spec makes a key optional). Keys not
    listed in ``dict_spec`` are ignored.

    Returns:
        A ``check(v)`` callable returning None on success, 'must be dict' for
        non-dicts, or ``str()`` of a ``{key: error}`` dict of the failures.
    """
    def check(v):
        if isinstance(v, dict):
            failures = compact({key: type_check(v.get(key, None), spec)
                                for key, spec in dict_spec.items()})
            if failures:
                return str(failures)
            return None
        return 'must be dict'
    return check

def AnyOf(*allowed_types):
    """Build a checker that passes when ``v`` satisfies at least one spec.

    Each entry of ``allowed_types`` may be anything ``type_check`` accepts.
    Specs are tried in order and the first success short-circuits.

    Returns:
        A ``check(v)`` callable returning None on success, otherwise
        'must be one of these types: ...' listing every spec.
    """
    def check(v):
        if any(type_check(v, allowed_type) is None for allowed_type in allowed_types):
            return None
        return 'must be one of these types: {}'.format(', '.join(map(str, allowed_types)))
    return check

def AllOf(*required_types):
    """Build a checker that passes only when ``v`` satisfies every spec.

    Specs are checked in order, so earlier specs can guard later ones
    (e.g. a number check before a comparison-based check).

    Returns:
        A ``check(v)`` callable returning None on success, otherwise the error
        from the first failing spec.
    """
    def check(v):
        for required_type in required_types:
            error_message = type_check(v, required_type)
            if error_message is not None:
                return error_message
        return None
    return check

def type_check(v, required_type):
    """Check ``v`` against ``required_type`` and describe any mismatch.

    ``required_type`` may be:
      * a class - ``v`` must be an instance; if the class has a ``type_check``
        attribute it is then called with ``v`` as an extra check;
      * ``None`` - ``v`` must be None;
      * a list - a list spec (see ``List``);
      * a dict - a dict spec (see ``Dict``);
      * any other callable - a checker/predicate called with ``v``;
      * any other value - an example whose type ``v`` must share (see ``TypeOf``).

    Returns:
        None when ``v`` conforms, otherwise an error message (a string, or a
        custom checker's truthy non-True return value passed through).
    """
    def check_callable(v, callable_type):
        result = callable_type(v)
        # A result equal to False (False, 0, 0.0) gets a generic message; a
        # result equal to True, or any other falsy result such as None, means
        # the check passed; any other truthy result is the error message.
        error_message = 'invalid (returned False)' if result == False else result
        if result != True and error_message:
            return error_message
        return None

    # isinstance rather than ``type(...) == type`` so classes with a custom
    # metaclass (ABCs, enum.Enum subclasses) are isinstance-checked instead of
    # falling through to the callable branch and being called with ``v``.
    if isinstance(required_type, type):
        if not isinstance(v, required_type):
            return 'needs to be of type {}'.format(required_type)
        callable_type = getattr(required_type, 'type_check', None)
        return check_callable(v, callable_type) if callable_type else None
    if required_type is None:
        return check_callable(v, NoneCheck)
    if isinstance(required_type, list):
        return check_callable(v, List(required_type))
    if isinstance(required_type, dict):
        return check_callable(v, Dict(required_type))
    if callable(required_type):
        return check_callable(v, required_type)
    return check_callable(v, TypeOf(required_type))

def assert_type(v, required_type, message=''):
    """Raise ValueError when ``v`` does not satisfy ``required_type``.

    ``message`` is prepended verbatim to the error text from ``type_check``;
    on success nothing happens and None is returned.
    """
    problem = type_check(v, required_type)
    if not problem:
        return
    raise ValueError(message + problem)

class typeSpec(object):
    """Decorator that enforces runtime type specs on a function.

    Usage: ``@typeSpec(spec_0, spec_1, ..., return_spec)``. The last spec
    applies to the return value; the others apply, in order, to the
    function's positional parameters, whether the caller passes them
    positionally or by keyword. Each spec may be anything ``type_check``
    accepts. Any failed check raises ValueError.
    """

    def __init__(self, *signature):
        """Store the argument specs and the return spec.

        Raises:
            ValueError: if no specs are given (a return spec is required).
        """
        if not signature:
            raise ValueError('Signature is empty')
        self.arg_types = signature[:-1]
        self.return_type = signature[-1]

    def _check_arg(self, index, value):
        """Raise ValueError unless ``value`` satisfies the spec for parameter ``index``."""
        if index >= len(self.arg_types):
            # More arguments than specs: fail with an explicit message.
            raise ValueError('typeSpec: arg {} has no type in the signature'.format(index))
        error_message = type_check(value, self.arg_types[index])
        if error_message:
            # Format the value only on failure: str() of a large argument is costly.
            raise ValueError(
                'typeSpec: arg {} with value {} of type {} is invalid - '.format(index, value, type(value))
                + error_message)

    def __call__(self, f):
        """Wrap ``f`` so its arguments and return value are checked on every call."""
        import functools
        import inspect

        # Map each positional parameter that may also be passed by keyword to
        # its index, so keyword arguments are checked against the same spec.
        try:
            parameters = list(inspect.signature(f).parameters.values())
        except (TypeError, ValueError):
            # No introspectable signature: only positional args can be matched.
            parameters = []
        positional = [p for p in parameters
                      if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)]
        keyword_positions = {p.name: i for i, p in enumerate(positional)
                             if p.kind == p.POSITIONAL_OR_KEYWORD}

        @functools.wraps(f)
        def with_type_check(*args, **kwargs):
            for i, arg in enumerate(args):
                self._check_arg(i, arg)
            for name, arg in kwargs.items():
                # Keyword-only and **kwargs entries have no positional spec.
                if name in keyword_positions:
                    self._check_arg(keyword_positions[name], arg)
            result = f(*args, **kwargs)
            error_message = type_check(result, self.return_type)
            if error_message:
                raise ValueError(
                    'typeSpec: return type with value {} of type {} is invalid - '.format(result, type(result))
                    + error_message)
            return result
        return with_type_check

Example Usage

from test.type_spec import typeSpec, Maybe, AllOf, AnyOf, assert_type

def Positive(v):
  """Checker: None when ``v`` is greater than 0, else 'must be positive'."""
  return 'must be positive' if v <= 0 else None

def Number(v):
  """Checker: accept ints and floats (bool too, being an int subclass)."""
  if not isinstance(v, (int, float)):
    return 'Must be number (int or float)'
  return None

class MyNumber:
  """Example value type carrying its own validation hook.

  ``type_check`` below is the attribute the type checker looks up on a
  required class and calls with the value once the isinstance check passes.
  """
  def __init__(self, n):
    self.n = n

  def __add__(self, other):
    # Assumes ``other`` also exposes ``.n``.
    total = self.n + other.n
    return MyNumber(total)

  def type_check(my_number):
    # No ``self`` on purpose: called through the class as
    # ``MyNumber.type_check(instance)``; ``instance.type_check()`` also works.
    return Number(my_number.n)

class SafeNumber:
  """Example value type that validates ``n`` eagerly in its constructor."""
  def __init__(self, n):
    self.n = n
    # assert_type raises ValueError when n fails the Number checker.
    assert_type(n, Number)

  def __add__(self, other):
    # The result goes through __init__ again, so it is re-validated.
    summed = self.n + other.n
    return SafeNumber(summed)

# Plain classes as specs: both arguments and the result must be ints.
@typeSpec(int, int, int)
def add_ints(a, b):
  return a + b

# A bare predicate also works as a spec: a False result is reported as
# 'invalid (returned False)', so b must be > 0.
@typeSpec(Number, lambda n: n > 0, Number)
def div_numbers(a, b):
  return a / b

# Same contract as div_numbers built from named checkers; AllOf stops at the
# first failure, so Number runs before Positive compares b with 0.
@typeSpec(Number, AllOf(Number, Positive), Number)
def div_numbers2(a, b):
  return a / b

# Checker functions as specs (they return None on success, an error string otherwise).
@typeSpec(Number, Number, Number)
def add(a, b):
  return a + b

# Same int-or-float check as add, expressed as a union of classes.
@typeSpec(AnyOf(int, float), AnyOf(int, float), AnyOf(int, float))
def add2(a, b):
  return a + b

# Intentionally broken: returns a str, so a call fails the return-type check
# and raises ValueError.
@typeSpec(int, int, int)
def faulty_add_ints(a, b):
  return 'foobar'

# A class with a type_check hook: isinstance is checked first, then
# MyNumber.type_check(value) runs as an extra check.
@typeSpec(MyNumber, MyNumber, int)
def add_my_numbers(a, b):
  return (a + b).n
add_my_numbers(MyNumber(5), MyNumber(5)) # => 10

# The hook can also be called directly on a (bad) instance.
number = MyNumber('foobar')
MyNumber.type_check(number) # => 'Must be number (int or float)'

# SafeNumber validates in its constructor, so the decorator relies on the
# isinstance check alone.
@typeSpec(SafeNumber, SafeNumber, int)
def add_safe_numbers(a, b):
  return (a + b).n
add_safe_numbers(SafeNumber(5), SafeNumber(5)) # => 10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment