@zyd14
Created October 7, 2020 23:57
Module which provides functionality for executing a list of tasks stored in a TaskMap object as remote lambda invocations
# func_to_execute_remote.py
from time import sleep

from zappa.asynchronous import task
from zappa import asynchronous

# For some reason zappa isn't picking up the table name from the zappa settings automatically so I'm just setting
# it manually for now
asynchronous.ASYNC_RESPONSE_TABLE = 'phils_done_tasks'
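# When the settings do get picked up, capture_response uses the table configured in zappa_settings.json.
# A minimal sketch of the relevant keys (the table name matches this gist; the capacity values are
# illustrative, not from the gist):
#   "async_response_table": "phils_done_tasks",
#   "async_response_table_read_capacity": 1,
#   "async_response_table_write_capacity": 1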
# This wrapper provided by the zappa package enables this function to be run as a remote lambda invocation.
# The response will be stored in DynamoDB, with a 400 KB limit on response item size. If larger responses are
# expected, you can manually store the response in S3 and return the S3 reference instead of the actual data, then
# retrieve the data from S3 when you unpack the responses returned by the remote_manager.manage() function
# (a sketch of this S3 approach follows the function below).
# Arguments to the function wrapped by @task must also be JSON-serializable, and must be under the 256 KB limit
# imposed by AWS on payloads for asynchronous lambda invocations. Again, if you want to feed more than 256 KB of data
# to the remote lambda invocation, send it an S3 reference to retrieve its data from.
# The last limitation is that the lambda function this code executes on needs to have already been deployed, with
# all the libraries required by this task already installed.
@task(remote_aws_lambda_function_name='remote-phil-dev', remote_aws_region='us-west-2', capture_response=True)
def longrunner(sleep_length: int = 5):
    sleep(sleep_length)
    return {'Message': f'it took {sleep_length} seconds to generate this'}
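

# A sketch (not part of the original gist) of the S3 approach described in the comment above: store a
# large result in S3 and return only the reference, so the captured response item stays under the
# DynamoDB 400 KB limit. The bucket name, key scheme, and function name are hypothetical placeholders.
import json
import uuid

import boto3


@task(remote_aws_lambda_function_name='remote-phil-dev', remote_aws_region='us-west-2', capture_response=True)
def longrunner_large_result(sleep_length: int = 5):
    sleep(sleep_length)
    big_payload = {'Message': f'it took {sleep_length} seconds to generate this',
                   'data': list(range(100_000))}
    bucket = 'my-async-results-bucket'  # hypothetical bucket
    key = f'longrunner-results/{uuid.uuid4()}.json'
    # Write the full result to S3; only the small reference below goes back through DynamoDB
    boto3.client('s3').put_object(Bucket=bucket, Key=key, Body=json.dumps(big_payload).encode())
    # The caller fetches the object from S3 when it unpacks the responses returned by remote_manager.manage()
    return {'s3_bucket': bucket, 's3_key': key}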

# remote_manager.py
import os
from time import sleep
from typing import Any, Callable, Dict, List, Tuple, Union

from zappa.asynchronous import get_async_response

class TaskMap:
    """ Object to hold tasks to be executed on a remote lambda invocation"""

    def __init__(self):
        self._task_map = {}  # type: Dict[Callable, List[Tuple[List[Any], Dict[str, Any]]]]

    def __iter__(self):
        for t in self._task_map:
            for arg_set in self._task_map[t]:
                yield t, arg_set

    def __len__(self):
        return len(self._task_map.keys())

    def add_task(self, task: Callable, args: List[Any] = None, kwargs: Dict[str, Any] = None):
        """
        Args:
            task: function to be executed remotely
            args: list of positional arguments to be fed to task function
            kwargs: dictionary of keyword arguments to be fed to task function
        """
        if args is None:
            args = []
        if kwargs is None:
            kwargs = {}
        if task not in self._task_map:
            self._task_map.update({task: [(args, kwargs)]})
        else:
            self._task_map[task].append((args, kwargs))
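

# Example usage (a sketch, not part of the original gist; 'fetch_report' and its arguments are hypothetical):
#   tm = TaskMap()
#   tm.add_task(fetch_report, args=['2020-10-07'], kwargs={'fmt': 'csv'})
#   tm.add_task(fetch_report, args=['2020-10-08'])
#   for fn, (args, kwargs) in tm:  # iteration yields (function, (args, kwargs)) pairs
#       fn(*args, **kwargs)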


def manage(task_map: TaskMap, callback: Callable = None) -> Union[List[dict], None]:
    """ Main method for executing tasks in a TaskMap as remote lambda functions, collecting their responses and
    returning them to the caller.

    Args:
        task_map: a TaskMap object holding tasks to be executed asynchronously
        callback: an optional function to execute on the collected responses returned by the remote tasks
    """
    response_ids = []
    if len(task_map) < 1:
        print('No tasks added')
        return

    # Execute tasks in task map
    for t, args in task_map:
        response_ids.append(t(*args[0], **args[1]).response_id)

    total_wait = 0
    # Initial wait period gives the lambdas a period of time to get going before looking for their results
    initial_wait = int(os.getenv('INITIAL_WAIT_SECONDS', 5))
    sleep(initial_wait)
    total_wait += initial_wait

    max_total_wait = int(os.getenv('MAX_TOTAL_WAIT', 900))
    if max_total_wait > 900:
        print('MAX WAIT set to more than 15 minutes (900 seconds) - remote lambda workers can only execute for a '
              'maximum of 15 minutes, so it is likely that they will start timing out after that time period. If '
              'expecting remote workers to execute for longer than 15 minutes, consider using a Fargate or Batch '
              'solution instead')

    num_tasks = len(response_ids)
    if response_ids:
        response_datas = profit(response_ids, total_wait, max_total_wait, num_tasks)
        if len(response_datas) != num_tasks:
            print('Not all responses collected from tasks, return data will likely be incomplete')
        # Execute user-provided callback function on collected response data. Might be useful to perform some kind of
        # transformation or cataloging of response data
        if callback:
            return callback(response_datas)
        else:
            errors = check_for_errors(response_datas)
            if errors:
                print('Some tasks responded with "N/A", which often means the task threw an unhandled exception')
            return response_datas
    else:
        print('No tasks created')
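

# A sketch (not part of the original gist) of a callback that could be passed to manage(); it simply splits
# collected responses from the 'N/A' placeholders left for unfinished or failed tasks. The name and return
# shape are illustrative only.
def summarize_responses(response_datas: List[dict]) -> Dict[str, List[dict]]:
    summary = {'ok': [], 'failed': []}  # type: Dict[str, List[dict]]
    for r in response_datas:
        key = 'failed' if r.get('response') == 'N/A' else 'ok'
        summary[key].append(r)
    return summary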


def check_for_errors(response_datas: List[dict]):
    errors = []
    for r in response_datas:
        if r.get('response') == 'N/A':
            errors.append(r)
    return errors


def profit(response_ids: List[str], total_wait: int, max_total_wait: int, num_tasks: int):
    """ Attempt to collect responses from tasks using response IDs returned when launching async lambda functions.

    Args:
        response_ids: List of response IDs corresponding to an item in DynamoDB where the response from a remote task
            will be stored
        total_wait: The amount of time that the manager has already waited for tasks to be finished
        max_total_wait: The maximum time the manager will wait for all responses to be collected before timing out and
            returning what it has
        num_tasks: Number of tasks being executed

    Returns:
        List of response dicts collected so far (possibly incomplete if the wait timed out)
    """
    loop_wait = int(os.getenv('LOOP_WAIT_SECONDS', 15))
    response_datas = []
    num_responses_collected = 0
    # While there are still response_ids to collect and time hasn't maxed out, keep trying to get response data from
    # DynamoDB
    while len(response_ids) > 0 and total_wait < max_total_wait:
        response_datas, response_ids, num_responses_collected = try_getting_responses(response_ids,
                                                                                       response_datas,
                                                                                       num_responses_collected)
        # Not all responses collected yet, sleep for a user-specified amount of time before trying again
        if num_responses_collected != num_tasks:
            print("Didn't get all responses, going to sleep for a bit")
            sleep(loop_wait)
            total_wait += loop_wait
        else:
            # All responses gathered, return them to caller
            return response_datas

    if num_responses_collected == num_tasks:
        print('got all responses!')
        return response_datas
    elif total_wait >= max_total_wait:
        print('Timed out, returning what responses were collected but data is likely to be incomplete')
        return response_datas


def try_getting_responses(response_ids: List[str],
                          response_datas: List[dict],
                          num_responses_collected: int):
    """ Iterate over task response_ids, using each ID to look up an item in DynamoDB which will store the result of an
    individual lambda task when it has completed. If a response is found for a task, remove its ID from the list so
    it is not checked for again.
    """
    response_ids_collected = []
    for r in response_ids:
        response = get_async_response(r)
        if response is not None:
            # if lambda is still running then just print a message and don't do anything with that ID
            if response.get('status') == 'in progress' and response.get('response') == 'N/A':
                print(f'lambda {r} still going, check back later')
            else:
                response_datas.append(response)
                response_ids_collected.append(r)
                num_responses_collected += 1
    # Remove response_ids that have been collected from responses we're still trying to get
    for r in response_ids_collected:
        try:
            response_ids.remove(r)
        except ValueError:
            pass
    return response_datas, response_ids, num_responses_collected

# Driver script (filename not shown in the extract)
from random import randint

from func_to_execute_remote import longrunner
from remote_manager import TaskMap, manage


def run_trial():
    task_map = TaskMap()
    # Add 75 test tasks with varying sleep lengths
    for i in range(75):
        task_map.add_task(longrunner, [randint(1, 25)])
    results = manage(task_map)
    print(results)


if __name__ == '__main__':
    run_trial()