Created
October 7, 2020 23:57
-
-
Save zyd14/47f67d9181167f908e4d57ded4401ea5 to your computer and use it in GitHub Desktop.
Module which provides functionality for executing a list of tasks stored in a TaskMap object as remote lambda invocations
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from time import sleep | |
from zappa.asynchronous import task | |
from zappa import asynchronous | |
# For some reason zappa isn't picking up the table name from the zappa settings automatically, so it is set
# manually here for now.
asynchronous.ASYNC_RESPONSE_TABLE = 'phils_done_tasks' | |
# This wrapper, provided by the zappa package, enables this function to be run on a remote lambda invocation.
# The response will be stored in DynamoDB, which limits the response item size to 400 KB. If larger responses are
# expected, you can manually store the response in S3 and return the S3 reference instead of the actual data, then
# retrieve the data from S3 when you unpack the responses returned by the remote_manager.manage() function.
# Arguments to the function wrapped by @task must also be JSON-serializable and must fit within the 256 KB limit
# imposed by AWS on payloads for asynchronous lambda invocations. Again, if you want to feed more than 256 KB of data
# to the remote lambda invocation, send it an S3 reference to retrieve its data from.
# The last limitation is that the lambda function you are executing the code on must already be deployed and have
# all the libraries required by this task installed.
@task(remote_aws_lambda_function_name='remote-phil-dev', remote_aws_region='us-west-2', capture_response=True)
def longrunner(sleep_length: int = 5):
    """Sleep for ``sleep_length`` seconds, then report how long the wait was.

    Args:
        sleep_length: number of seconds to sleep before returning.

    Returns:
        A dict with a single ``'Message'`` key whose value states the sleep length.
    """
    sleep(sleep_length)
    message = f'it took {sleep_length} seconds to generate this'
    return {'Message': message}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from time import sleep | |
from typing import Any, Callable, Dict, List, Tuple, Union | |
from zappa.asynchronous import get_async_response | |
class TaskMap:
    """Object to hold tasks to be executed on a remote lambda invocation.

    Every registered callable maps to a list of ``(args, kwargs)`` pairs, one
    pair per queued invocation. Iterating yields one ``(task, (args, kwargs))``
    item per queued invocation, and ``len()`` reports that same count.
    """

    def __init__(self):
        # callable -> [(positional args, keyword args), ...], one entry per queued call
        self._task_map = {}  # type: Dict[Callable, List[Tuple[List[Any], Dict[str, Any]]]]

    def __iter__(self):
        """Yield ``(task, (args, kwargs))`` once for every queued invocation."""
        for task, arg_sets in self._task_map.items():
            for arg_set in arg_sets:
                yield task, arg_set

    def __len__(self):
        """Return the number of queued invocations.

        This counts argument sets rather than distinct callables so that it
        agrees with the number of items produced by iteration (adding the same
        function 75 times is 75 tasks, not 1).
        """
        return sum(len(arg_sets) for arg_sets in self._task_map.values())

    def add_task(self, task: Callable, args: Union[List[Any], None] = None,
                 kwargs: Union[Dict[str, Any], None] = None):
        """Queue one invocation of ``task``.

        Args:
            task: function to be executed remotely
            args: list of positional arguments to be fed to task function
            kwargs: dictionary of keyword arguments to be fed to task function
        """
        # Copy the containers so that later mutation of the caller's list/dict
        # cannot silently change an invocation that is already queued.
        queued_args = [] if args is None else list(args)
        queued_kwargs = {} if kwargs is None else dict(kwargs)
        self._task_map.setdefault(task, []).append((queued_args, queued_kwargs))
def manage(task_map: "TaskMap", callback: Callable = None) -> Union[List[dict], None]:
    """ Main method for executing tasks in a TaskMap as remote lambda functions, collecting their responses and
    returning them to the caller.

    The wait behaviour is configured through environment variables:
    ``INITIAL_WAIT_SECONDS`` (default 5) is slept once after launching the tasks, and ``MAX_TOTAL_WAIT``
    (default 900) bounds the total time spent waiting for responses.

    Args:
        task_map: a TaskMap object holding tasks to be executed asynchronously
        callback: an optional function to execute on the collected responses returned by the remote tasks

    Returns:
        None when there was nothing to run; otherwise the result of ``callback(responses)`` when a callback is
        given, or the list of collected responses.
    """
    if len(task_map) < 1:
        print('No tasks added')
        return None

    # Execute tasks in task map; each call is expected to return an object exposing a response_id
    response_ids = [t(*args[0], **args[1]).response_id for t, args in task_map]
    if not response_ids:
        # Nothing was launched, so there is no point in sleeping through the initial wait period
        print('No tasks created')
        return None

    # Initial wait period gives the lambdas a period of time to get going before looking for their results
    initial_wait = int(os.getenv('INITIAL_WAIT_SECONDS', 5))
    sleep(initial_wait)
    total_wait = initial_wait

    max_total_wait = int(os.getenv('MAX_TOTAL_WAIT', 900))
    if max_total_wait > 900:
        print('MAX WAIT set to more than 15 minutes (900 seconds) - remote lambda workers can only execute for a '
              'maximum of 15 minutes, so it is likely that they will start timing out after that time period. If '
              'expecting remote workers to execute for longer than 15 minutes, consider using a Fargate or Batch '
              'solution instead')

    # Capture the count now: the polling helpers remove collected IDs from response_ids in place
    num_tasks = len(response_ids)
    # Treat a missing result from profit() as "nothing collected" instead of crashing on len(None)
    response_datas = profit(response_ids, total_wait, max_total_wait, num_tasks) or []
    if len(response_datas) != num_tasks:
        print('Not all responses collected from tasks, return data will likely be incomplete')

    # Execute user-provided callback function on collected response data. Might be useful to perform some kind of
    # transformation or cataloging of response data
    if callback:
        return callback(response_datas)

    if check_for_errors(response_datas):
        print('Some tasks responded with "N/A", which often means the task threw an unhandled exception')
    return response_datas
def check_for_errors(response_datas: List[dict]):
    """Return the responses whose ``'response'`` field is the string ``'N/A'``.

    Args:
        response_datas: response dicts collected from the remote tasks.

    Returns:
        A new list with the matching entries, in their original order.
    """
    return [entry for entry in response_datas if entry.get('response') == 'N/A']
def profit(response_ids: List[str], total_wait: int, max_total_wait: int, num_tasks: int):
    """ Attempt to collect responses from tasks using response IDs returned when launching async lambda functions.

    Polls via try_getting_responses(), sleeping ``LOOP_WAIT_SECONDS`` (environment variable, default 15) between
    attempts, until every response is collected or the wait budget is used up.

    Args:
        response_ids: List of response IDs corresponding to an item in DynamoDB where the response from a remote task
            will be stored
        total_wait: The amount of time that the manager has already waited for tasks to be finished
        max_total_wait: The maximum time the manager will wait for all responses to be collected before timing out and
            returning what it has
        num_tasks: Number of tasks being executed

    Returns:
        The list of responses collected so far; always a list, possibly incomplete or empty.
    """
    loop_wait = int(os.getenv('LOOP_WAIT_SECONDS', 15))
    response_datas = []
    num_responses_collected = 0

    # While there are still response_ids to collect and time hasn't maxed out, keep trying to get response data from
    # DynamoDB
    while response_ids and total_wait < max_total_wait:
        # try_getting_responses takes exactly three arguments; it tracks the collected IDs internally
        response_datas, response_ids, num_responses_collected = try_getting_responses(
            response_ids, response_datas, num_responses_collected)
        if num_responses_collected == num_tasks:
            # All responses gathered, return them to caller
            return response_datas
        # Not all responses collected yet, sleep for a user-specified amount of time before trying again
        print("Didn't get all responses, going to sleep for a bit")
        sleep(loop_wait)
        total_wait += loop_wait

    if num_responses_collected == num_tasks:
        print('got all responses!')
    elif total_wait >= max_total_wait:
        print('Timed out, returning what responses were collected but data is likely to be incomplete')
    # Return on every path so callers never receive None
    return response_datas
def try_getting_responses(response_ids: List[str],
                          response_datas: List[dict],
                          num_responses_collected: int):
    """ Make one pass over the outstanding response IDs and collect every response that is ready.

    Each ID is looked up with get_async_response(). A lookup that returns None, or one that reports status
    'in progress' with response 'N/A', leaves the ID outstanding. Any other response is appended to
    response_datas, counted, and its ID is removed from response_ids so it is not checked again.

    Both lists are modified in place and also returned.

    Returns:
        Tuple of (response_datas, response_ids, num_responses_collected) after this pass.
    """
    finished_ids = []
    for response_id in response_ids:
        result = get_async_response(response_id)
        if result is None:
            continue

        still_running = result.get('status') == 'in progress' and result.get('response') == 'N/A'
        if still_running:
            # Lambda has not finished yet; leave its ID in place for the next pass
            print(f'lambda {response_id} still going, check back later')
            continue

        response_datas.append(result)
        finished_ids.append(response_id)
        num_responses_collected += 1

    # Drop the IDs gathered in this pass from the outstanding list
    for finished_id in finished_ids:
        try:
            response_ids.remove(finished_id)
        except ValueError:
            pass

    return response_datas, response_ids, num_responses_collected
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from random import randint

from func_to_execute_remote import longrunner
from remote_manager import TaskMap, manage
def run_trial():
    """Queue 75 ``longrunner`` tasks with random sleep lengths, run them through ``manage`` and print the results."""
    task_map = TaskMap()
    # Add 75 test tasks with varying sleep lengths (randint is inclusive on both ends: 1 to 25 seconds)
    for _ in range(75):
        task_map.add_task(longrunner, [randint(1, 25)])
    results = manage(task_map)
    print(results)


if __name__ == '__main__':
    run_trial()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment