Skip to content

Instantly share code, notes, and snippets.

@jamesgregson
Last active October 26, 2019 23:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jamesgregson/a227c10d8668b9b454c325acb3909526 to your computer and use it in GitHub Desktop.
Save jamesgregson/a227c10d8668b9b454c325acb3909526 to your computer and use it in GitHub Desktop.
Simple design-by-contract implementation (precondition and postcondition checks) using a Python decorator
import functools
import inspect
def contract(expects=None, ensures=None):
    """Build a decorator that wraps a function in a ``Contract``.

    Args:
        expects: Iterable of precondition callables. Each is called with a
            single object exposing the call's bound arguments as attributes
            and must return a truthy value for the call to proceed.
            ``None`` (the default) means "no preconditions".
        ensures: Iterable of postcondition callables. Each is called with
            that same argument object plus the function's return value.
            ``None`` (the default) means "no postconditions".

    Returns:
        A decorator that takes a function and returns a ``Contract``
        instance wrapping it.
    """
    # Use None sentinels instead of ``[]`` defaults: a list default is
    # created once and would be shared by every decorated function that
    # omits the argument.
    preconditions = [] if expects is None else expects
    postconditions = [] if ensures is None else ensures

    def func_wrapper(func):
        return Contract(
            function=func,
            preconditions=preconditions,
            postconditions=postconditions,
        )

    return func_wrapper
class Contract:
    """Callable wrapper that checks pre- and postconditions around a function.

    Calling the instance binds the supplied arguments to the wrapped
    function's signature (applying defaults), evaluates every precondition,
    calls the function with the original arguments, then evaluates every
    postcondition. Each failed condition is reported on stdout together with
    the call site, after which ``Contract.Error`` is raised.

    Setting ``Contract.enabled = False`` turns all checks off; calls are then
    forwarded straight to the wrapped function.
    """

    class Args:
        """Attribute-style view of a call's bound arguments."""

        def __init__(self, vars):
            # ``vars`` is an iterable of (name, value) pairs; the parameter
            # name is kept as-is for backward compatibility.
            for k, v in vars:
                setattr(self, k, v)

    class Error(Exception):
        """Raised when at least one precondition or postcondition fails."""

    # Class-wide switch shared by every Contract instance.
    enabled = True

    def __init__(self, function, preconditions, postconditions):
        """Store the wrapped function and its condition lists.

        Args:
            function: The callable to guard.
            preconditions: Iterable of callables taking the ``Args`` object.
            postconditions: Iterable of callables taking the ``Args`` object
                and the function's return value.
        """
        # Copy __name__, __doc__, etc. so the wrapper looks like the function
        # it replaces. Done first so our own attributes below always win over
        # anything copied from ``function.__dict__``.
        functools.update_wrapper(self, function)
        self._function = function
        self._preconditions = preconditions
        self._postconditions = postconditions

    def __call__(self, *args, **kwargs):
        """Call the wrapped function, enforcing the contract when enabled.

        Returns:
            Whatever the wrapped function returns.

        Raises:
            Contract.Error: If any precondition or postcondition is falsy.
            TypeError: If the arguments do not match the function signature
                (raised by ``inspect.Signature.bind``).
        """
        if not Contract.enabled:
            return self._function(*args, **kwargs)

        # Bind args/kwargs to parameter names and fill in defaults so the
        # conditions can refer to every parameter by name.
        bound = inspect.signature(self._function).bind(*args, **kwargs)
        bound.apply_defaults()
        call_args = Contract.Args(bound.arguments.items())

        # Remember the caller's frame cheaply; it is only resolved to a
        # filename/line if a condition actually fails. (inspect.stack() would
        # load source context for every frame on every call.)
        current = inspect.currentframe()
        caller_frame = current.f_back if current is not None else None

        failed = [cond for cond in self._preconditions if not cond(call_args)]
        if failed:
            Contract._fail('Precondition', failed, caller_frame,
                           'Preconditions not met.')

        result = self._function(*args, **kwargs)

        failed = [cond for cond in self._postconditions
                  if not cond(call_args, result)]
        if failed:
            Contract._fail('Postcondition', failed, caller_frame,
                           'Postconditions not met')
        return result

    @staticmethod
    def _fail(kind, failed, caller_frame, summary):
        """Print one line per failed condition, then raise Contract.Error."""
        filename, lineno = Contract._locate(caller_frame)
        for condition in failed:
            print('{} error, {} line {}: {}'.format(
                kind, filename, lineno, Contract._describe(condition)))
        raise Contract.Error(summary)

    @staticmethod
    def _locate(frame):
        """Return (filename, line number) for the calling frame."""
        if frame is None:
            # inspect.currentframe() may return None on interpreters without
            # Python stack frame support.
            return '<unknown>', 0
        # context=0: only the location is needed, not the source lines.
        info = inspect.getframeinfo(frame, context=0)
        return info.filename, info.lineno

    @staticmethod
    def _describe(condition):
        """Return a human-readable description of a condition callable.

        Uses the condition's source text after the first ':' (the body of a
        ``lambda`` or ``def``). Falls back to the callable's name or repr when
        no source is available (builtins, partials, code built via exec).
        """
        try:
            source = inspect.getsource(condition)
        except (OSError, TypeError):
            return getattr(condition, '__qualname__', None) or repr(condition)
        _, colon, body = source.partition(':')
        return body.strip() if colon else source.strip()
from contract import contract, Contract
# Precondition used by say_hello: passes only when the bound ``name``
# argument equals 'james'. ``args`` exposes the call's arguments as
# attributes (it is the Contract.Args object built for each checked call).
# NOTE: when this check fails, Contract prints this function's source text
# (everything after the first ':'), so adding a docstring, inline comment or
# annotation here would change that diagnostic output.
def name_check( args ):
return args.name == 'james'
@contract(
    expects = [ name_check ],
    ensures = [ lambda args, result: result == 'hello '+args.name ])
def say_hello( name, **kwargs ):
    """Echo each extra keyword argument, then return a greeting for ``name``.

    The contract requires ``name_check`` to pass before the call and the
    result to equal 'hello ' followed by the name afterwards.
    """
    # Print every extra keyword as "key = value", in the order supplied.
    for key in kwargs:
        print('{} = {}'.format(key, kwargs[key]))
    greeting = 'hello {}'.format(name)
    return greeting
if __name__ == '__main__':
    # With checking switched off the failing precondition ('Barry' is not
    # 'james') is ignored and the call goes straight through.
    print('Contracts disabled...')
    Contract.enabled = False
    unchecked_greeting = say_hello('Barry', values=5)
    print(unchecked_greeting)

    # With checking on, the first call violates the precondition and the
    # second one satisfies the whole contract.
    print('Contracts enabled...')
    Contract.enabled = True
    demo_calls = (
        (('Barry',), {}),
        (('james',), {'value': 6}),
    )
    for positional, keywords in demo_calls:
        try:
            print(say_hello(*positional, **keywords))
        except Contract.Error as violation:
            print(violation)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment