Skip to content

Instantly share code, notes, and snippets.

@guyarad
Created September 25, 2016 06:06
Show Gist options
  • Save guyarad/c3f8709345b73d2f32bb2ffba322e30d to your computer and use it in GitHub Desktop.
Save guyarad/c3f8709345b73d2f32bb2ffba322e30d to your computer and use it in GitHub Desktop.
import functools
import threading

import wrapt
def for_all_methods(decorator):
    """
    CLASS DECORATOR.

    Based on http://stackoverflow.com/a/6307868/916568

    When applied to a class, automatically decorates every public callable
    attribute defined directly on that class (names that don't start with
    an underscore) using the given ``decorator``. Attributes inherited
    from base classes are not touched.

    ``staticmethod`` and ``classmethod`` attributes have their underlying
    function decorated and are then re-wrapped in the same descriptor
    type, so they keep working when called through the class or through
    an instance.

    IMPORTANT NOTE: this hasn't been tested to be production grade.
    Please use with caution and for debugging only.

    Args:
        decorator: a decorator to use for decorating all class methods.

    Returns:
        A class decorator. It modifies the class it receives in place
        (rebinding each public method to its decorated version) and
        returns that same class, not an instance.
    """
    def decorate(cls):
        # Work on a snapshot of the namespace: the loop rebinds attributes
        # on ``cls``, so it should not iterate the live mapping.
        for attr, raw in list(vars(cls).items()):
            if attr.startswith('_'):
                continue
            if isinstance(raw, (staticmethod, classmethod)):
                # getattr() would return the already-unwrapped (or bound)
                # callable; storing the decorated result as-is would turn
                # the attribute into a plain instance method. Decorate the
                # underlying function and restore the descriptor instead.
                descriptor_type = type(raw)
                setattr(cls, attr, descriptor_type(decorator(raw.__func__)))
            elif callable(getattr(cls, attr)):
                setattr(cls, attr, decorator(getattr(cls, attr)))
        return cls

    return decorate
def verify_thread_safety(wrapped):
    """
    METHOD DECORATOR.

    When applied to a method, records the thread that first calls it on a
    given instance and asserts that every later call on that instance
    comes from the same thread.

    The bookkeeping is stored on the instance itself, in the attributes
    ``called_from_tid`` and ``called_from_name``. When more than one method
    of a class is decorated, they all share these attributes, so the check
    verifies that all decorated methods of an instance are accessed from a
    single thread.

    Recommended to use in conjunction with the ``for_all_methods``
    decorator, which will automatically apply it to all public methods
    in a class.

    Args:
        wrapped: the method to wrap. It's expected to decorate a
            method (of a class), rather than a free function.

    Returns:
        The decorated method, additionally wrapped with
        ``wrapt.synchronized``.

    Raises:
        AssertionError: (at call time) when the method is called from a
            thread other than the one that made the first decorated call
            on the instance. The check is an ``assert`` statement, so it
            is skipped when Python runs with ``-O``.
    """
    @functools.wraps(wrapped)
    def decorate(self, *args, **kwargs):
        curr_thread = threading.current_thread()
        # The first decorated call on an instance pins it to the caller's
        # thread; every later call is compared against that record.
        if not hasattr(self, 'called_from_tid'):
            self.called_from_tid = curr_thread.ident
            self.called_from_name = curr_thread.name
        assert curr_thread.ident == self.called_from_tid, (
            "Method name is '{}'. First called from {}[{}]. "
            "Curr thread is {}[{}]".format(
                wrapped.__name__,
                self.called_from_name,
                self.called_from_tid,
                curr_thread.name,
                curr_thread.ident,
            )
        )
        return wrapped(self, *args, **kwargs)

    # Serialize the wrapper so two threads cannot both see the "first
    # call" state at once (see the wrapt docs for the locking scope).
    return wrapt.synchronized(decorate)
def check_public_method_calls(wrapped):
    """
    METHOD DECORATOR.

    Meant to be used in conjunction with the ``for_all_methods``
    decorator, so it will be applied to all the public methods of the
    class.

    When applied, verifies that a decorated (public) method is not called
    down-the-stack from another decorated method of the same instance.
    The state is tracked in the instance attribute ``method_called``,
    which is True while a decorated method is executing.

    The decorated method is wrapped with ``wrapt.synchronized``. If the
    instance also exposes a ``_lock`` attribute, that lock is held for the
    duration of the call as well; instances without ``_lock`` are
    supported and rely on ``wrapt.synchronized`` alone.

    Args:
        wrapped: the method to wrap. It's expected to decorate a
            method (of a class), rather than a free function.

    Returns:
        The decorated method.

    Raises:
        AssertionError: (at call time) when a decorated method is entered
            while another decorated method of the same instance is still
            running. The check is an ``assert`` statement, so it is
            skipped when Python runs with ``-O``.
    """
    def guarded_call(self, args, kwargs):
        # Run ``wrapped`` with the re-entrancy flag raised around it.
        if not hasattr(self, 'method_called'):
            self.method_called = False
        assert not self.method_called, "Method was called recursively: " + wrapped.__name__
        self.method_called = True
        try:
            return wrapped(self, *args, **kwargs)
        finally:
            # Always lower the flag, even when ``wrapped`` raises, so one
            # failing call does not poison every later call.
            self.method_called = False

    @functools.wraps(wrapped)
    def decorate(self, *args, **kwargs):
        # ``_lock`` is optional: the decorator never creates it, so
        # requiring it would break every class that doesn't define one.
        lock = getattr(self, '_lock', None)
        if lock is None:
            return guarded_call(self, args, kwargs)
        with lock:
            return guarded_call(self, args, kwargs)

    return wrapt.synchronized(decorate)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment