Last active
March 7, 2020 00:06
-
-
Save thomasjpfan/8288750cf4a8d2d8dde495a0e0e4b542 to your computer and use it in GitHub Desktop.
Do not use this anywhere
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import types
import typing
from abc import ABC, abstractmethod
from functools import wraps
from inspect import signature
from numbers import Integral
from typing import Optional
def _isinstance(item, type_to_check):
    """Return whether ``item`` satisfies the annotation ``type_to_check``.

    Handles plain classes, ``None`` (meaning "the value is ``None``"),
    ``typing.Any``, unions (``typing.Union`` / ``typing.Optional`` and, on
    interpreters that provide it, ``X | Y``) and subscripted generics such
    as ``typing.List[int]``. For generics only the container class is
    checked; element types are not inspected.

    Args:
        item: The value to check.
        type_to_check: The annotation the value is checked against.

    Returns:
        ``True`` if ``item`` matches the annotation, ``False`` otherwise.
    """
    # ``-> None`` is spelled with the ``None`` object, which ``isinstance``
    # rejects as a second argument.
    if type_to_check is None:
        return item is None
    # ``typing.Any`` cannot be used with ``isinstance``; it accepts anything.
    if type_to_check is typing.Any:
        return True

    origin = typing.get_origin(type_to_check)
    if origin is None:
        # Not a subscripted generic or union: an ordinary class.
        return isinstance(item, type_to_check)

    # ``types.UnionType`` only exists on newer interpreters, hence getattr.
    if origin is typing.Union or origin is getattr(types, "UnionType", None):
        # Recurse so that members which are themselves generics
        # (e.g. ``Optional[List[int]]``) are reduced to their origin class.
        return any(
            _isinstance(item, member)
            for member in typing.get_args(type_to_check)
        )
    return isinstance(item, origin)
def _ensure_types(func, abstract_signature):
    """Wrap ``func`` so each call is checked against ``abstract_signature``.

    Args:
        func: The callable to wrap.
        abstract_signature: An ``inspect.Signature`` whose parameter and
            return annotations are enforced on ``func``.

    Returns:
        A wrapper around ``func`` (decorated with ``functools.wraps``).
        The wrapper raises ``TypeError("INPUT IS WRONG")`` when an argument
        does not match the annotation of the same-named parameter in
        ``abstract_signature``, and ``TypeError("OUTPUT IS WRONG")`` when
        the return value does not match the return annotation. Parameters
        that the abstract signature does not declare or does not annotate
        are not checked, and neither is an unannotated return value.
    """
    func_sig = signature(func)
    abstract_parameters = abstract_signature.parameters
    # Marker ``inspect`` uses for "no annotation given".
    unannotated = abstract_signature.empty

    @wraps(func)
    def wrapper(*arg, **kwargs):
        # check input
        # Binding maps positional *and* keyword arguments to parameter
        # names; parameters left at their default are simply absent.
        bound = func_sig.bind(*arg, **kwargs)
        for name, value in bound.arguments.items():
            declared = abstract_parameters.get(name)
            if declared is None or declared.annotation is unannotated:
                continue
            kind = func_sig.parameters[name].kind
            # For ``*args`` / ``**kwargs`` the annotation describes each
            # collected value, not the tuple/dict that holds them.
            if kind is declared.VAR_POSITIONAL:
                candidates = value
            elif kind is declared.VAR_KEYWORD:
                candidates = value.values()
            else:
                candidates = (value,)
            for candidate in candidates:
                if not _isinstance(candidate, declared.annotation):
                    raise TypeError("INPUT IS WRONG")

        output = func(*arg, **kwargs)

        # check output
        expected = abstract_signature.return_annotation
        if expected is not unannotated and not _isinstance(output, expected):
            raise TypeError("OUTPUT IS WRONG")
        return output

    return wrapper
class _AbstractCheck:
    """Mixin that type-checks implementations of abstract methods.

    When an instance is created, every method named in the
    ``__abstractmethods__`` of a class in the MRO is replaced *on the
    instance* by a wrapper from ``_ensure_types``. The wrapper validates
    calls against the annotations of the abstract declaration.
    """

    def __new__(cls, *args, **kwargs):
        """Create the instance and install the checking wrappers on it.

        Args:
            *args: Constructor arguments; forwarded only to a base class
                that defines its own ``__new__``.
            **kwargs: Same as ``*args``, for keyword arguments.

        Returns:
            The new instance with its abstract-method implementations
            wrapped.
        """
        # ``object.__new__`` raises TypeError when it receives extra
        # arguments from an overriding ``__new__``, so constructor
        # arguments are forwarded only to a cooperating base class.
        parent_new = super().__new__
        if parent_new is object.__new__:
            instance = parent_new(cls)
        else:
            instance = parent_new(cls, *args, **kwargs)

        # Pair each class in the MRO with the abstract method names it
        # reports, filtering out classes without abstract methods.
        abs_methods = [
            (klass, klass.__abstractmethods__)
            for klass in cls.mro()
            if getattr(klass, "__abstractmethods__", None)
        ]

        # method_set are frozensets of strings (method names)
        for abs_cls, method_set in abs_methods:
            for abs_method_str in method_set:
                abstract_method = getattr(abs_cls, abs_method_str)
                # Abstract properties and other non-callable members have
                # no call signature to enforce; checking this first also
                # avoids triggering a property getter on the instance
                # before ``__init__`` has run.
                if not callable(abstract_method):
                    continue
                child_method = getattr(instance, abs_method_str)
                if not callable(child_method):
                    continue
                abstract_sign = signature(abstract_method)
                check_method = _ensure_types(child_method, abstract_sign)
                setattr(instance, abs_method_str, check_method)
        return instance
class A(_AbstractCheck, ABC):
    """Example abstract base class built on ``_AbstractCheck``.

    Subclasses implement ``_hello``; its annotations here form the
    contract that ``_AbstractCheck`` checks implementations against.
    """

    def do_something(self, x):
        """Call ``_hello`` with ``x``; the result is discarded."""
        self._hello(x)

    @abstractmethod
    def _hello(self, x: dict) -> Optional[bool]:
        """Hook for subclasses: takes a ``dict``, returns ``bool`` or ``None``."""
        ...
class B(A):
    """Concrete subclass of ``A`` with a trivial ``_hello``."""

    def _hello(self, x):
        """Implement the abstract hook; ignores ``x`` and returns ``None``."""
        return None
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment