Last active
October 26, 2019 23:26
-
-
Save jamesgregson/a227c10d8668b9b454c325acb3909526 to your computer and use it in GitHub Desktop.
Simple design-by-contract implementation using a decorator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools
import inspect
import types
def contract(expects=(), ensures=()):
    """Decorator factory that wraps a function in a ``Contract``.

    Args:
        expects: iterable of preconditions (or None for none). Each is called
            as ``cond(args)``, where ``args`` is a ``Contract.Args`` holding the
            call's bound arguments. A falsy result is a failure.
        ensures: iterable of postconditions (or None for none). Each is called
            as ``cond(args, result)`` with the function's return value.

    Returns:
        A decorator that turns a function into a ``Contract`` instance.
    """
    # Immutable tuple defaults instead of ``[]``: a list default would be one
    # object shared by every decorated function. Snapshotting also lets a
    # generator of conditions be checked on every call, not just the first.
    preconditions = () if expects is None else tuple(expects)
    postconditions = () if ensures is None else tuple(ensures)

    def func_wrapper(func):
        return Contract(function=func,
                        preconditions=preconditions,
                        postconditions=postconditions)

    return func_wrapper
class Contract:
    """Callable wrapper that checks pre- and postconditions around a function.

    Each precondition is called with a ``Contract.Args`` built from the call's
    bound arguments (defaults applied). Each postcondition is called with that
    same object and the function's return value. Failed conditions are printed
    and a ``Contract.Error`` is raised; its ``errors`` attribute lists the
    messages. Setting ``Contract.enabled = False`` bypasses all checking.
    """

    class Args:
        """Attribute-style view of a call's bound arguments."""

        def __init__(self, vars):
            # `vars` is an iterable of (name, value) pairs, such as
            # BoundArguments.arguments.items(). The parameter name shadows the
            # builtin but is kept so keyword callers keep working.
            for k, v in vars:
                setattr(self, k, v)

        def __repr__(self):
            fields = ', '.join('{}={!r}'.format(k, v)
                               for k, v in self.__dict__.items())
            return 'Contract.Args({})'.format(fields)

    class Error(Exception):
        """Raised when one or more contract conditions fail."""

        # Individual failure messages; set on each raised instance.
        errors = ()

    # Global switch: when False, wrapped functions are called directly.
    enabled = True

    def __init__(self, function, preconditions, postconditions):
        # Copy __name__/__doc__/etc. and set __wrapped__ first, so that the
        # copied __dict__ of a wrapped Contract cannot overwrite our own state.
        functools.update_wrapper(self, function)
        self._function = function
        # Tuples so generators are re-iterable across calls.
        self._preconditions = tuple(preconditions)
        self._postconditions = tuple(postconditions)
        # The signature is computed on the first checked call and then reused.
        self._signature = None

    def __get__(self, instance, owner=None):
        """Bind to *instance* so the contract also works on methods."""
        # staticmethod objects must not receive the instance.
        if instance is None or isinstance(self._function, staticmethod):
            return self
        return types.MethodType(self, instance)

    def __call__(self, *args, **kwargs):
        if not Contract.enabled:
            return self._function(*args, **kwargs)

        # Bind args/kwargs to parameter names, apply defaults, and expose them
        # as attributes for the condition callables.
        if self._signature is None:
            self._signature = inspect.signature(self._function)
        bound = self._signature.bind(*args, **kwargs)
        bound.apply_defaults()
        call_args = Contract.Args(bound.arguments.items())

        # Call-site location for error messages. Reading only the caller's
        # frame avoids inspect.stack(), which materializes every frame.
        location = Contract._caller_location(inspect.currentframe())

        failed = [cond for cond in self._preconditions if not cond(call_args)]
        Contract._enforce('Precondition', 'Preconditions not met.',
                          failed, location)

        result = self._function(*args, **kwargs)

        failed = [cond for cond in self._postconditions
                  if not cond(call_args, result)]
        Contract._enforce('Postcondition', 'Postconditions not met',
                          failed, location)
        return result

    @staticmethod
    def _caller_location(frame):
        """Return (filename, lineno) of the frame that called *frame*."""
        caller = getattr(frame, 'f_back', None)
        if caller is None:
            # inspect.currentframe() may return None on some interpreters.
            return '<unknown>', 0
        return caller.f_code.co_filename, caller.f_lineno

    @staticmethod
    def _describe(condition):
        """Return the text after the first ':' of *condition*'s source.

        For a one-line def or lambda this is its body. Falls back to repr()
        when the source cannot be retrieved (callable objects, builtins, REPL).
        """
        try:
            source = inspect.getsource(condition)
        except (OSError, TypeError):
            return repr(condition)
        _, sep, tail = source.partition(':')
        return tail.strip() if sep else source.strip()

    @staticmethod
    def _enforce(kind, summary, failed, location):
        """Print one message per failed condition and raise Contract.Error."""
        if not failed:
            return
        filename, lineno = location
        messages = ['{} error, {} line {}: {}'.format(
                        kind, filename, lineno, Contract._describe(cond))
                    for cond in failed]
        for msg in messages:
            print(msg)
        err = Contract.Error(summary)
        err.errors = messages
        raise err
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from contract import contract, Contract | |
# Example precondition: passes only when the wrapped call's `name` argument
# equals 'james'. `args` is the Contract.Args object built from the call's
# bound arguments. The body is deliberately left as-is: on failure Contract
# prints the text after the first ':' of this function's source, so adding a
# docstring or statements here would change that printed message.
def name_check( args ):
    return args.name == 'james'
# The lambda line below is echoed verbatim (after its first ':') in
# postcondition error messages, so its text is kept unchanged.
@contract(
    expects = [ name_check ],
    ensures = [ lambda args, result: result == 'hello '+args.name ])
def say_hello(name, **kwargs):
    """Print every extra keyword argument as 'key = value', then greet *name*.

    Returns the string 'hello <name>'.
    """
    rendered = ('{} = {}'.format(key, value) for key, value in kwargs.items())
    for line in rendered:
        print(line)
    greeting = 'hello {}'.format(name)
    return greeting
if __name__ == '__main__':
    # Checking off: name_check is skipped, so 'Barry' is greeted.
    print('Contracts disabled...')
    Contract.enabled = False
    unchecked = say_hello('Barry', values=5)
    print(unchecked)

    # Checking on: 'Barry' fails name_check and the Contract.Error is
    # reported; 'james' satisfies both the pre- and postcondition.
    print('Contracts enabled...')
    Contract.enabled = True
    demo_calls = [('Barry', {}), ('james', {'value': 6})]
    for who, extras in demo_calls:
        try:
            greeting = say_hello(who, **extras)
        except Contract.Error as err:
            print(err)
        else:
            print(greeting)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.