Created
August 26, 2019 17:20
-
-
Save relwell/3eb9d68526ba8eadb41b3b1399e747a9 to your computer and use it in GitHub Desktop.
zappa async fix
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
THIS IS BASICALLY A MONKEY-PATCH OF BROKEN BEHAVIOR IN ZAPPA | |
Zappa sends bytes to boto instead of a string. | |
Zappa Async Tasks | |
Example: | |
``` | |
from zappa.async import task | |
@task(service='sns') | |
def my_async_func(*args, **kwargs): | |
dosomething() | |
``` | |
For SNS, you can also pass an `arn` argument to task() which will specify which SNS path to send it to. | |
Without `service='sns'`, the default service is 'lambda' which will call the method in an asynchronous | |
lambda call. | |
The following restrictions apply: | |
* function must have a clean import path -- i.e. no closures, lambdas, or methods. | |
* args and kwargs must be JSON-serializable. | |
* The JSON-serialized form must be within the size limits for Lambda (128K) or SNS (256K) events. | |
Discussion of this comes from: | |
https://github.com/Miserlou/Zappa/issues/61 | |
https://github.com/Miserlou/Zappa/issues/603 | |
https://github.com/Miserlou/Zappa/pull/694 | |
https://github.com/Miserlou/Zappa/pull/732 | |
https://github.com/Miserlou/Zappa/issues/840 | |
## Full lifetime of an asynchronous dispatch: | |
1. In a file called `foo.py`, there is the following code: | |
``` | |
from zappa.async import task | |
@task | |
def my_async_func(*args, **kwargs): | |
return sum(args) | |
``` | |
2. The decorator desugars to: | |
`my_async_func = task(my_async_func)` | |
3. Somewhere else, the code runs: | |
`res = my_async_func(1,2)` | |
really calls task's `_run_async(1,2)` | |
with `func` equal to the original `my_async_func` | |
If we are running in Lambda, this runs: | |
LambdaAsyncResponse().send('foo.my_async_func', (1,2), {}) | |
and returns the LambdaAsyncResponse instance to the local
context. That local context can then, e.g., check `res.sent`
to confirm it was dispatched correctly.
4. LambdaAsyncResponse.send invokes the currently running
AWS Lambda instance with the JSON message:
``` | |
{ "command": "zappa.async.route_lambda_task", | |
"task_path": "foo.my_async_func", | |
"args": [1,2], | |
"kwargs": {} | |
} | |
``` | |
5. The new lambda instance is invoked with the message above, | |
and Zappa runs its usual bootstrapping context, and inside | |
zappa.handler, the existence of the 'command' key in the message | |
dispatches the full message to zappa.async.route_lambda_task, which | |
in turn calls `run_message(message)` | |
6. `run_message` loads the task_path value to load the `func` from `foo.py`. | |
We should note that my_async_func is wrapped by @task in this new | |
context, as well. However, @task also decorated `my_async_func.sync()` | |
to run the original function synchronously. | |
`run_message` duck-types the method and finds the `.sync` attribute | |
and runs that instead -- thus we do not infinitely dispatch. | |
If `my_async_func` had code to dispatch other functions inside its | |
synchronous portions (or even call itself recursively), those *would* | |
be dispatched asynchronously, unless, of course, they were called | |
by: `my_async_func.sync(1,2)` in which case it would run synchronously | |
and in the current lambda function. | |
""" | |
import boto3 | |
import botocore | |
from functools import update_wrapper, wraps | |
import importlib | |
import inspect | |
import json | |
import os | |
import uuid | |
import time | |
from zappa.utilities import get_topic_name | |
try:
    from zappa_settings import ASYNC_RESPONSE_TABLE
except ImportError:
    # Either the settings module or the name itself is unavailable; response
    # capture is then treated as unconfigured throughout this module.
    ASYNC_RESPONSE_TABLE = None

# Declare these here so they're kept warm.
try:
    aws_session = boto3.Session()
    LAMBDA_CLIENT = aws_session.client("lambda")
    SNS_CLIENT = aws_session.client("sns")
    STS_CLIENT = aws_session.client("sts")
    DYNAMODB_CLIENT = aws_session.client("dynamodb")
except botocore.exceptions.NoRegionError:  # pragma: no cover
    # This can happen while testing on Travis, but it's taken care of
    # during class initialization (callers supply their own boto_session).
    # NOTE(review): the *_CLIENT names stay undefined when this branch runs.
    pass
## | |
# Response and Exception classes | |
## | |
class AsyncException(Exception):  # pragma: no cover
    """Raised for failures specific to async task dispatch (e.g. an oversized payload)."""
class LambdaAsyncResponse(object):
    """
    Base response dispatcher: asynchronously invokes a Lambda function.

    Can be used directly, or subclassed when the method used to send the
    message changes (subclasses override ``_send``).

    Attributes set by this class:
        client: boto client used to deliver the message.
        lambda_function_name: name of the Lambda function to invoke.
        aws_region: region passed in by the caller (stored, not used here).
        capture_response: whether the receiving side should store the result.
        response_id: UUID string when capturing, ``"MISCONFIGURED"`` when
            capture was requested without a response table, else ``None``.
        response: raw boto response of the last dispatch (``None`` before).
        sent: ``True`` once the invoke call was accepted (HTTP 202).
    """

    def __init__(
        self,
        lambda_function_name=None,
        aws_region=None,
        capture_response=False,
        **kwargs
    ):
        """
        Args:
            lambda_function_name: name of the Lambda function to invoke.
            aws_region: region associated with the dispatch.
            capture_response: request that the task's return value be stored.
            **kwargs: ``boto_session`` may supply the session used to build
                the client; other keys are ignored by this class.
        """
        boto_session = kwargs.get("boto_session")
        if boto_session:
            self.client = boto_session.client("lambda")
        else:  # pragma: no cover
            self.client = LAMBDA_CLIENT

        self.lambda_function_name = lambda_function_name
        self.aws_region = aws_region

        # Make the dispatch state readable even before send() has run.
        self.response = None
        self.sent = False

        if capture_response:
            if ASYNC_RESPONSE_TABLE is None:
                print(
                    "Warning! Attempted to capture a response without "
                    "async_response_table configured in settings (you won't "
                    "capture async responses)."
                )
                capture_response = False
                self.response_id = "MISCONFIGURED"
            else:
                self.response_id = str(uuid.uuid4())
        else:
            self.response_id = None

        self.capture_response = capture_response

    def send(self, task_path, args, kwargs):
        """
        Create the message object and pass it to the actual sender.

        Args:
            task_path: dotted import path of the task function.
            args: positional arguments for the task (JSON-serializable).
            kwargs: keyword arguments for the task (JSON-serializable).

        Returns:
            This instance, so callers can inspect ``sent`` / ``response``.
        """
        message = {
            "task_path": task_path,
            "capture_response": self.capture_response,
            "response_id": self.response_id,
            "args": args,
            "kwargs": kwargs,
        }
        self._send(message)
        return self

    def _send(self, message):
        """
        Given a message, directly invoke the lambda function for this task.

        Raises:
            AsyncException: if the serialized message exceeds 128000 bytes.
        """
        message["command"] = "zappa.async.route_lambda_task"
        payload = json.dumps(message)
        # The size limit applies to the encoded bytes, not the character count.
        if len(payload.encode("utf-8")) > 128000:  # pragma: no cover
            raise AsyncException("Payload too large for async Lambda call")
        # Send the JSON text itself. Calling str() on the encoded bytes would
        # produce the repr "b'...'" on Python 3, which is not valid JSON.
        self.response = self.client.invoke(
            FunctionName=self.lambda_function_name,
            InvocationType="Event",  # makes the call async
            Payload=payload,
        )
        self.sent = self.response.get("StatusCode", 0) == 202
class SnsAsyncResponse(LambdaAsyncResponse):
    """
    Dispatcher that publishes the task message to an SNS topic.

    The task path and its arguments are serialised to JSON and published
    to the topic identified by ``self.arn``.
    """

    def __init__(
        self,
        lambda_function_name=None,
        aws_region=None,
        capture_response=False,
        **kwargs
    ):
        """
        Args:
            lambda_function_name: used to derive the topic name when no
                explicit ``arn`` is given.
            aws_region: region placed into the derived topic ARN.
            capture_response: request that the task's return value be stored.
            **kwargs: optional ``boto_session`` (source of the SNS/STS
                clients) and ``arn`` (explicit topic ARN).
        """
        session = kwargs.get("boto_session")
        explicit_arn = kwargs.get("arn")

        self.lambda_function_name = lambda_function_name
        self.aws_region = aws_region

        if session:
            self.client = session.client("sns")
        else:  # pragma: no cover
            self.client = SNS_CLIENT

        if explicit_arn:
            self.arn = explicit_arn
        else:
            # No ARN supplied: build one from the caller's account id and
            # the topic name derived from the Lambda function name.
            sts_client = session.client("sts") if session else STS_CLIENT
            account_id = sts_client.get_caller_identity()["Account"]
            self.arn = "arn:aws:sns:{region}:{account}:{topic_name}".format(
                region=self.aws_region,
                account=account_id,
                topic_name=get_topic_name(self.lambda_function_name),
            )

        # Issue: https://github.com/Miserlou/Zappa/issues/1209
        # TODO: Refactor
        if not capture_response:
            self.response_id = None
        elif ASYNC_RESPONSE_TABLE is None:
            print(
                "Warning! Attempted to capture a response without "
                "async_response_table configured in settings (you won't "
                "capture async responses)."
            )
            capture_response = False
            self.response_id = "MISCONFIGURED"
        else:
            self.response_id = str(uuid.uuid4())

        self.capture_response = capture_response

    def _send(self, message):
        """
        Given a message, publish to this topic.

        ``self.sent`` ends up holding the ``MessageId`` from the publish
        response (``None`` when the key is absent).
        """
        message["command"] = "zappa.async.route_sns_task"
        serialized = json.dumps(message)
        self.response = self.client.publish(TargetArn=self.arn, Message=serialized)
        self.sent = self.response.get("MessageId")
## | |
# Async Routers
## | |
# Maps the `service` name accepted by run()/task() to its dispatcher class.
ASYNC_CLASSES = {
    "lambda": LambdaAsyncResponse,
    "sns": SnsAsyncResponse,
}
def route_lambda_task(event, context):
    """
    Entry point for tasks dispatched by direct Lambda invocation.

    The event passed in by zappa.handler already is the task message, so it
    is handed to `run_message` unchanged. `context` is unused.

    Returns:
        Whatever the task function returned.
    """
    return run_message(event)
def route_sns_task(event, context):
    """
    Entry point for tasks dispatched through SNS.

    Pulls the JSON body out of the first record of the event, decodes it
    into the task message and runs it. `context` is unused.

    Returns:
        Whatever the task function returned.
    """
    sns_body = event["Records"][0]["Sns"]["Message"]
    return run_message(json.loads(sns_body))
def run_message(message):
    """
    Execute the task described by a message dict.

    The message carries 'task_path', 'args' and 'kwargs' (plus the 'command'
    key used for routing in handler.py). When 'capture_response' is truthy,
    progress and the final result are recorded in the async response table
    under 'response_id'.

    Returns:
        The value returned by the task function.
    """
    wants_capture = message.get("capture_response", False)

    if wants_capture:
        # Record a placeholder first so pollers can see the task has started;
        # the "ttl" value is set 600 seconds from now.
        placeholder = {
            "id": {"S": str(message["response_id"])},
            "ttl": {"N": str(int(time.time() + 600))},
            "async_status": {"S": "in progress"},
            "async_response": {"S": str(json.dumps("N/A"))},
        }
        DYNAMODB_CLIENT.put_item(TableName=ASYNC_RESPONSE_TABLE, Item=placeholder)

    task_func = import_and_get_task(message["task_path"])
    # A @task-decorated function exposes the undecorated original as `.sync`;
    # calling that avoids dispatching the task asynchronously yet again.
    runner = getattr(task_func, "sync", task_func)
    result = runner(*message["args"], **message["kwargs"])

    if wants_capture:
        DYNAMODB_CLIENT.update_item(
            TableName=ASYNC_RESPONSE_TABLE,
            Key={"id": {"S": str(message["response_id"])}},
            UpdateExpression="SET async_response = :r, async_status = :s",
            ExpressionAttributeValues={
                ":r": {"S": str(json.dumps(result))},
                ":s": {"S": "complete"},
            },
        )

    return result
## | |
# Execution interfaces and classes | |
## | |
def run(
    func,
    args=None,
    kwargs=None,
    service="lambda",
    capture_response=False,
    remote_aws_lambda_function_name=None,
    remote_aws_region=None,
    **task_kwargs
):
    """
    Instead of decorating a function with @task, you can just run it directly.
    If you were going to do func(*args, **kwargs), then you will call this:

        import zappa.async.run
        zappa.async.run(func, args, kwargs)

    If you want to use SNS, then do:

        zappa.async.run(func, args, kwargs, service='sns')

    and other arguments are similar to @task.

    Args:
        func: the top-level function to dispatch.
        args: positional arguments for func; defaults to an empty list.
        kwargs: keyword arguments for func; defaults to an empty dict.
        service: key into ASYNC_CLASSES ('lambda' or 'sns').
        capture_response: request that the task's return value be stored.
        remote_aws_lambda_function_name: overrides AWS_LAMBDA_FUNCTION_NAME.
        remote_aws_region: overrides AWS_REGION.
        **task_kwargs: forwarded to the dispatcher class constructor.

    Returns:
        The dispatcher instance returned by its ``send`` method.

    Raises:
        KeyError: if `service` is not a key of ASYNC_CLASSES.
    """
    # Fresh containers per call rather than shared mutable defaults.
    args = [] if args is None else args
    kwargs = {} if kwargs is None else kwargs

    lambda_function_name = remote_aws_lambda_function_name or os.environ.get(
        "AWS_LAMBDA_FUNCTION_NAME"
    )
    aws_region = remote_aws_region or os.environ.get("AWS_REGION")

    task_path = get_func_task_path(func)
    return ASYNC_CLASSES[service](
        lambda_function_name=lambda_function_name,
        aws_region=aws_region,
        capture_response=capture_response,
        **task_kwargs
    ).send(task_path, args, kwargs)
# Handy: | |
# http://stackoverflow.com/questions/10294014/python-decorator-best-practice-using-a-class-vs-a-function | |
# However, this needs to pass inspect.getargspec() in handler.py which does not take classes | |
# Wrapper written to take optional arguments | |
# http://chase-seibert.github.io/blog/2013/12/17/python-decorator-optional-parameter.html | |
def task(*args, **kwargs):
    """Async task decorator so that calling the function dispatches it asynchronously.

    Usable bare (``@task``) or with options (``@task(service='sns')``).

    Args:
        func (function): the function to be wrapped
            Further requirements:
            func must be an independent top-level function.
                 i.e. not a class method or an anonymous function
        service (str): either 'lambda' or 'sns'
        remote_aws_lambda_function_name (str): the name of a remote lambda function to call with this task
        remote_aws_region (str): the name of a remote region to make lambda/sns calls against
        capture_response (bool): request that the task's return value be stored
        **kwargs: any other keyword (e.g. ``arn``) is forwarded to the
            dispatcher class constructor.

    Returns:
        A replacement function that dispatches func() to
        run asynchronously through the service in question
    """
    func = args[0] if len(args) == 1 and callable(args[0]) else None

    # Work on a copy so the caller's dict is left untouched.
    dispatcher_kwargs = dict(kwargs)
    service = dispatcher_kwargs.pop("service", "lambda")
    lambda_function_name_arg = dispatcher_kwargs.pop(
        "remote_aws_lambda_function_name", None
    )
    aws_region_arg = dispatcher_kwargs.pop("remote_aws_region", None)
    capture_response = dispatcher_kwargs.pop("capture_response", False)
    # These are always passed explicitly below; drop stray duplicates so the
    # dispatcher call cannot fail with "multiple values for keyword argument".
    for reserved in ("lambda_function_name", "aws_region"):
        dispatcher_kwargs.pop(reserved, None)

    def func_wrapper(func):
        task_path = get_func_task_path(func)

        @wraps(func)
        def _run_async(*args, **kwargs):
            """
            This is the wrapping async function that replaces the function
            that is decorated with @task.

            Args:
                These are just passed through to @task's func

            Assuming a valid service is passed to task() and it is run
            inside a Lambda process (i.e. AWS_LAMBDA_FUNCTION_NAME exists),
            it dispatches the function to be run through the service variable.
            Otherwise, it runs the task synchronously.

            Returns:
                In async mode, the dispatcher object, which carries the
                state of the dispatch. When outside of Lambda, the func
                passed to @task is run and its actual value is returned.
            """
            lambda_function_name = lambda_function_name_arg or os.environ.get(
                "AWS_LAMBDA_FUNCTION_NAME"
            )
            aws_region = aws_region_arg or os.environ.get("AWS_REGION")

            if (service in ASYNC_CLASSES) and lambda_function_name:
                # Extra decorator options (such as an explicit SNS `arn`)
                # reach the dispatcher instead of being silently dropped.
                return ASYNC_CLASSES[service](
                    lambda_function_name=lambda_function_name,
                    aws_region=aws_region,
                    capture_response=capture_response,
                    **dispatcher_kwargs
                ).send(task_path, args, kwargs)

            return func(*args, **kwargs)

        _run_async.service = service
        # Expose the undecorated function so it can be run synchronously.
        _run_async.sync = func
        return _run_async

    return func_wrapper(func) if func else func_wrapper
def task_sns(func):
    """
    Shorthand decorator: same as task(), with the service fixed to 'sns'.
    """
    sns_task = task(func, service="sns")
    return sns_task
## | |
# Utility Functions | |
## | |
def import_and_get_task(task_path):
    """
    Resolve a dotted path such as 'package.module.func' to the object it names.

    Everything before the last dot is imported as a module; the final
    component is looked up as an attribute of that module.
    """
    module_name, attr_name = task_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), attr_name)
def get_func_task_path(func):
    """
    Build the dotted '<module>.<function>' path for a function via inspection.
    """
    owning_module = inspect.getmodule(func)
    return "{module_path}.{func_name}".format(
        module_path=owning_module.__name__,
        func_name=func.__name__,
    )
def get_async_response(response_id):
    """
    Look up a captured task result in the async response table.

    Returns:
        None when no item exists for `response_id`; otherwise a dict with
        the stored 'status' string and the JSON-decoded 'response'.
    """
    lookup = DYNAMODB_CLIENT.get_item(
        TableName=ASYNC_RESPONSE_TABLE,
        Key={"id": {"S": str(response_id)}},
    )
    if "Item" in lookup:
        item = lookup["Item"]
        return {
            "status": item["async_status"]["S"],
            "response": json.loads(item["async_response"]["S"]),
        }
    return None
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment