Skip to content

Instantly share code, notes, and snippets.

@robinedwards
Created January 27, 2020 09:51
Show Gist options
  • Save robinedwards/3f2ec4336e1ced084547d24d7e7ead3a to your computer and use it in GitHub Desktop.
Save robinedwards/3f2ec4336e1ced084547d24d7e7ead3a to your computer and use it in GitHub Desktop.
Selective retry operator
class SelectiveRetryPythonOperator(PythonOperator):
    """
    Allows only retrying certain types of exception by altering max_tries.

    When the python callable (or the pre-execute hook) raises, the exception
    is compared against ``retry_for``:

    * ``retry_for`` is empty, or the exception is not an instance of it:
      the task instance's ``max_tries`` is set to 0.
    * the exception is an instance of ``retry_for``: ``retry_handler`` is
      consulted and, if it returns ``DO_NOT_COUNT_RETRY``, the task
      instance's ``_try_number`` is decremented (never below 0).

    In every case the original exception is re-raised afterwards.
    """

    def __init__(self,
                 retry_for=None,
                 func_args=None,
                 func_kwargs=None,
                 provide_execution_date=False,
                 provide_xcom=None,
                 provide_context=False,
                 **kwargs):
        """
        :param retry_for: exception class, or iterable of exception classes,
            for which the task may be retried. Empty/None means no exception
            type is retried.
        :param func_args: positional arguments for the python callable;
            must be a tuple or list.
        :param func_kwargs: keyword arguments for the python callable. The
            mapping is copied, so the caller's object is never modified.
        :param provide_execution_date: stored on the instance; not read by
            this class itself (presumably consumed by subclasses/hooks).
        :param provide_xcom: stored on the instance; not read by this class
            itself (presumably consumed by subclasses/hooks).
        :param provide_context: when true, ``self.op_kwargs`` is merged into
            the keyword arguments handed to the python callable.
        :param kwargs: forwarded to the parent constructor, with
            ``provide_context`` forced to True.
        :raises TypeError: if ``func_args`` is neither a tuple nor a list.
        """
        # Note we require context for the retry mechanism but we don't want to
        # pollute all our functions by default, hence the separate attribute.
        # There is probably a cleaner solution to this.
        self.provide_context_to_callable = provide_context
        kwargs['provide_context'] = True

        # isinstance() only accepts a class or a tuple of classes, so coerce
        # any other iterable (e.g. a list) to a tuple up front; otherwise the
        # check in execute_callable would raise TypeError and hide the real
        # task failure.
        if not retry_for:
            self.retry_for = ()
        elif isinstance(retry_for, type):
            self.retry_for = retry_for
        else:
            self.retry_for = tuple(retry_for)

        self.func_args = func_args or ()
        # Explicit check instead of `assert`, which is stripped under -O.
        if not isinstance(self.func_args, (tuple, list)):
            raise TypeError(
                "func_args must be a tuple or list, got %s"
                % type(self.func_args).__name__)

        # Copy: execute_callable updates this dict in place with the task
        # context, which must not leak into the caller's own dict.
        self.func_kwargs = dict(func_kwargs) if func_kwargs else {}
        self.provide_execution_date = provide_execution_date
        self.provide_xcom = provide_xcom
        super(SelectiveRetryPythonOperator, self).__init__(**kwargs)

    def pre_execute_python_callable(self):
        """
        Hook to be executed before the python callable is called.

        Runs inside the same try block as the callable, so an exception
        raised here goes through the same retry handling. The default
        implementation does nothing.
        """

    def retry_handler(self, exc):
        """
        Hook to process an exception that has been raised during task execution.

        Only called for exceptions matching ``retry_for``. Return
        ``DO_NOT_COUNT_RETRY`` to have the current try not counted; any other
        return value leaves the try number untouched. The default
        implementation returns None.

        :param exc: the exception raised during execution.
        """

    @provide_session
    def execute_callable(self, session=None):
        """
        Call the python callable, adjusting retry bookkeeping on failure.

        :param session: database session used to merge the modified task
            instance (presumably injected by ``provide_session`` when not
            passed explicitly).
        :return: whatever the python callable returns.
        :raises Exception: re-raises any exception from the hook or callable.
        """
        task_instance = self.op_kwargs['task_instance']
        logger.debug("Calling %s() with %s",
                     self.python_callable.__name__,
                     pformat([self.func_args, self.func_kwargs]))

        # See note on constructor
        if self.provide_context_to_callable:
            self.func_kwargs.update(self.op_kwargs)

        try:
            self.pre_execute_python_callable()
            return self.python_callable(*self.func_args, **self.func_kwargs)
        except Exception as e:
            logger.info("Exception raised during python callable execution", exc_info=True)
            # NOTE(review): the merges below are not committed here; confirm
            # that the session wrapper / failure handling persists them.
            if not self.retry_for or not isinstance(e, self.retry_for):
                logger.info("Is not an expected exception won't retry the task")
                task_instance.max_tries = 0
                session.merge(task_instance)
            # allow the handler to not count certain exception types
            elif DO_NOT_COUNT_RETRY == self.retry_handler(e):
                task_instance._try_number = max(task_instance._try_number - 1, 0)
                logger.info("Not counting try(%s) for exception type %s", task_instance._try_number, e)
                session.merge(task_instance)
            raise
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment