Skip to content

Instantly share code, notes, and snippets.

@medecau
Created October 22, 2020 02:54
Show Gist options
  • Save medecau/06a2d61908449796fea5f284cafb28ea to your computer and use it in GitHub Desktop.
Save medecau/06a2d61908449796fea5f284cafb28ea to your computer and use it in GitHub Desktop.
enforce type hinting at runtime with a decorator
import inspect
from functools import wraps
from typing import get_type_hints
class InputError(Exception):
    """Raised when an argument's value does not satisfy its type hint."""
def make_error_message(arg_name, arg_type, expected_type):
    """Build the text reported when a value's type violates its hint.

    Args:
        arg_name: Name of the offending parameter (or ``"return"``).
        arg_type: The type the value actually has.
        expected_type: The type demanded by the annotation.

    Returns:
        A one-line, human-readable description of the mismatch.
    """
    template = "{} value has type {} - must be {}"
    return template.format(arg_name, arg_type, expected_type)
def _is_instance(value, expected_type):
    """Return True if *value* is an instance of *expected_type*.

    Hints that ``isinstance`` cannot evaluate (it raises ``TypeError`` for
    e.g. ``typing.Any`` or ``List[int]``) are treated as satisfied, because
    they cannot be enforced with a plain instance check.
    """
    try:
        return isinstance(value, expected_type)
    except TypeError:
        return True


def _equals_own_cast(value, expected_type):
    """Return True if ``expected_type(value)`` succeeds and equals *value*.

    This is the lenient path that lets, for example, ``3`` pass for a
    ``float`` hint. Any failure while converting or comparing means the value
    cannot stand in for the expected type, so it counts as a mismatch.
    """
    try:
        return bool(value == expected_type(value))
    except Exception:  # arbitrary constructors may raise anything
        return False


def insist(func):
    """Decorate *func* so its type hints are enforced on every call.

    Each explicitly passed argument is matched to its parameter by binding the
    call against the function's signature, so unannotated parameters,
    keyword arguments, ``*args`` and ``**kwargs`` are all paired with the
    right hint (or skipped when there is none). For ``*args`` / ``**kwargs``
    parameters the hint is applied to every collected value. Default values
    that the caller did not pass are not checked.

    A value is accepted when it is an instance of the hinted type, or when
    converting it to the hinted type yields an equal value.

    Args:
        func: The function whose annotations should be enforced.

    Returns:
        A wrapper with the same metadata as *func*.

    Raises:
        InputError: From the wrapper, when an argument violates its hint.
        Exception: From the wrapper, when the return value is not an
            instance of the ``return`` hint.
        TypeError: From the wrapper, when the call does not match the
            signature of *func*.
    """
    hints = get_type_hints(func)
    signature = inspect.signature(func)
    parameters = signature.parameters

    def check_input(arg_name, arg_val):
        # Parameters without an annotation are accepted as-is.
        if arg_name not in hints:
            return
        expected_type = hints[arg_name]
        if _is_instance(arg_val, expected_type):
            return
        if _equals_own_cast(arg_val, expected_type):
            return
        raise InputError(
            make_error_message(arg_name, type(arg_val), expected_type)
        )

    @wraps(func)
    def wrapped(*args, **kwargs):
        # Binding maps every passed value to its real parameter name, which
        # keeps hints aligned even when some parameters are unannotated.
        bound = signature.bind(*args, **kwargs)
        for arg_name, arg_val in bound.arguments.items():
            kind = parameters[arg_name].kind
            if kind is inspect.Parameter.VAR_POSITIONAL:
                values = arg_val
            elif kind is inspect.Parameter.VAR_KEYWORD:
                values = arg_val.values()
            else:
                values = (arg_val,)
            for value in values:
                check_input(arg_name, value)

        result = func(*args, **kwargs)

        if "return" in hints:
            expected_type = hints["return"]
            # The return value gets a strict instance check (no lenient cast).
            if not _is_instance(result, expected_type):
                raise Exception(
                    make_error_message("return", type(result), expected_type)
                )
        return result

    return wrapped
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment