Last active
April 24, 2024 23:45
-
-
Save dpmccabe/1e935626e7efd759d55bd2f44ba07439 to your computer and use it in GitHub Desktop.
Post-Zappa callback that updates the provisioned concurrency of an AWS Lambda function
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import re
from time import sleep
from typing import Any, Dict, List

import boto3
from zappa.cli import ZappaCLI
# Number of provisioned-concurrency instances requested for the "current" alias.
PROVISIONED_CONCURRENCY = 5
# Module-level Lambda client shared by the helper functions below; it is created
# at import time, so importing this module requires boto3 to be configured.
client = boto3.client("lambda")
def post_callbacks(zappa_cli: ZappaCLI) -> None:
    """
    Post-Zappa callback: make sure a "current" alias points at the newest
    published version, give that alias provisioned concurrency, and route the
    API Gateway integrations through the alias.

    Parameters
    ----------
    zappa_cli : ZappaCLI
        The Zappa CLI object; only its ``lambda_name`` attribute is read.

    Returns
    -------
    None
    """
    function_name = zappa_cli.lambda_name
    # the alias must exist before concurrency can be provisioned on it
    check_concurrency(function_name, check_alias(function_name))
    set_api_gateway_alias(function_name)
def check_alias(lambda_name: str) -> str: | |
""" | |
Ensure that there is an alias "current" pointing to the latest function version | |
Parameters | |
---------- | |
lambda_name : str | |
Returns | |
------- | |
str | |
the alias name | |
""" | |
# get the most recent version number | |
resp = client.list_versions_by_function(FunctionName=lambda_name) | |
versions = resp["Versions"] | |
versions = [x for x in versions if x["Version"] != "$LATEST"] | |
latest_version = max(versions, key=lambda x: int(x["Version"])) | |
# get the list of aliases | |
resp = client.list_aliases(FunctionName=lambda_name) | |
if len(resp["Aliases"]) > 1: | |
# this shouldn't happen, but would require fixing on AWS console | |
raise Exception("Multiple aliases defined") | |
elif len(resp["Aliases"]) == 0: | |
alias_arn = create_alias(lambda_name, latest_version) | |
elif resp["Aliases"][0]["Name"] != "current": | |
# this shouldn't happen, but would require fixing on AWS console | |
raise Exception('Existing alias is not named "current"') | |
else: | |
# alias already exists, so make sure it points to latest version | |
alias_arn = update_alias(lambda_name, latest_version) | |
return alias_arn | |
def create_alias(lambda_name: str, latest_version: Dict[str, Any]) -> str:
    """
    Create the "current" alias and aim it at the given function version.

    Parameters
    ----------
    lambda_name : str
        Name of the Lambda function.
    latest_version : Dict[str, Any]
        A version record; only its "Version" key is read.

    Returns
    -------
    str
        ARN of the newly created alias.
    """
    request = {
        "FunctionName": lambda_name,
        "Name": "current",
        "FunctionVersion": latest_version["Version"],
    }
    alias = client.create_alias(**request)
    print('Created function alias "current" for version', alias["FunctionVersion"])
    return alias["AliasArn"]
def update_alias(lambda_name: str, latest_version: Dict[str, Any]) -> str:
    """
    Re-point the existing "current" alias at the given function version.

    Parameters
    ----------
    lambda_name : str
        Name of the Lambda function.
    latest_version : Dict[str, Any]
        A version record; only its "Version" key is read.

    Returns
    -------
    str
        ARN of the updated alias.
    """
    request = {
        "FunctionName": lambda_name,
        "Name": "current",
        "FunctionVersion": latest_version["Version"],
        # NOTE(review): presumably meant to drop any weighted routing to other
        # versions so all traffic goes to FunctionVersion -- verify.
        "RoutingConfig": {},
    }
    alias = client.update_alias(**request)
    print('Updated function alias "current" for version', alias["FunctionVersion"])
    return alias["AliasArn"]
def check_concurrency(lambda_name: str, alias_arn: str) -> None:
    """
    Ensure that the provisioned concurrency config requests
    PROVISIONED_CONCURRENCY instances.

    If no config exists, one is created. If one exists with a different
    requested count, this waits (polling every 10 seconds) while its status is
    IN_PROGRESS and updates it once it is READY.

    Parameters
    ----------
    lambda_name : str
        Name of the Lambda function.
    alias_arn : str
        ARN of the "current" alias; an existing config must belong to it.

    Returns
    -------
    None

    Raises
    ------
    Exception
        If more than one config exists, or the config has a status other than
        IN_PROGRESS or READY.
    """
    while True:
        # get existing concurrency configs
        resp = client.list_provisioned_concurrency_configs(FunctionName=lambda_name)
        configs = resp["ProvisionedConcurrencyConfigs"]
        if len(configs) > 1:
            # this shouldn't happen, but would require fixing on AWS console
            raise Exception("Multiple existing provisioned concurrency configs")
        if not configs:
            create_concurrency(lambda_name)
            return

        prov = configs[0]
        if prov["RequestedProvisionedConcurrentExecutions"] == PROVISIONED_CONCURRENCY:
            # already the right number of instances, although they might not
            # all be ready yet
            return

        # need to change the number of instances, so wait until it's ready
        if prov["Status"] == "IN_PROGRESS":
            print("Waiting for concurrency config to be ready...")
            sleep(10)
        elif prov["Status"] == "READY":
            # keyword arguments: both leading parameters are strings, and
            # passing them positionally in the wrong order makes the ARN check
            # in update_concurrency always fail
            update_concurrency(lambda_name=lambda_name, alias_arn=alias_arn, prov=prov)
            return
        else:
            raise Exception(f"Invalid concurrency config status: {prov['Status']}")
def create_concurrency(lambda_name: str) -> None:
    """
    Put a provisioned concurrency config of PROVISIONED_CONCURRENCY instances
    on the function's "current" alias and report the requested count.

    Parameters
    ----------
    lambda_name : str
        Name of the Lambda function.

    Returns
    -------
    None
    """
    result = client.put_provisioned_concurrency_config(
        FunctionName=lambda_name,
        Qualifier="current",
        ProvisionedConcurrentExecutions=PROVISIONED_CONCURRENCY,
    )
    requested = result["RequestedProvisionedConcurrentExecutions"]
    print("Set provisioned concurrency to", requested)
def update_concurrency(lambda_name: str, alias_arn: str, prov: Dict[str, Any]) -> None:
    """
    Change an existing provisioned concurrency config so that it requests
    PROVISIONED_CONCURRENCY instances.

    Parameters
    ----------
    lambda_name : str
        Name of the Lambda function.
    alias_arn : str
        ARN of the "current" alias the config is expected to belong to.
    prov : Dict[str, Any]
        The existing config record; its "FunctionArn" is checked against
        ``alias_arn``.

    Returns
    -------
    None

    Raises
    ------
    Exception
        If the config belongs to something other than ``alias_arn`` (requires
        fixing in the AWS console).
    """
    owner_arn = prov["FunctionArn"]
    if owner_arn != alias_arn:
        raise Exception("Existing concurrency config does not match alias ARN")
    # the same put request used for creation also overwrites an existing
    # config, so reuse it (identical call and identical output)
    create_concurrency(lambda_name)
def set_api_gateway_alias(lambda_name: str) -> None: | |
""" | |
Configure the API Gateway to use the "current" alias of the Lambda function | |
Parameters | |
---------- | |
lambda_name : str | |
Returns | |
------- | |
None | |
""" | |
# get the API Gateway's ID | |
cf_client = boto3.client("cloudformation") | |
resp = cf_client.describe_stack_resource( | |
StackName=lambda_name, LogicalResourceId="Api" | |
) | |
api_id = resp["StackResourceDetail"]["PhysicalResourceId"] | |
# get the API resources (these are just stubs, since the function routes endpoints) | |
apig_client = boto3.client("apigateway") | |
resp = apig_client.get_resources(restApiId=api_id, embed=["methods"]) | |
for item in resp["items"]: | |
resource_id = item["id"] | |
for method, behavior in item["resourceMethods"].items(): | |
method_int = behavior["methodIntegration"] # e.g. ANY, OPTIONS, etc. | |
if "uri" in method_int: | |
# this is not a mock, i.e. we're sending traffic somewhere, so fix the | |
# URI if necessary | |
s = method_int["uri"] | |
s = re.sub("/invocations$", "", s) | |
uri_parts = re.split(r":", s) | |
if uri_parts[-1] != "current": | |
# it's just sending traffic to the $LATEST alias by default | |
uri_parts.append("current") | |
new_uri = ":".join(uri_parts) + "/invocations" | |
patch_op = {"op": "replace", "path": "/uri", "value": new_uri} | |
apig_client.update_integration( | |
restApiId=api_id, | |
resourceId=resource_id, | |
httpMethod=method, | |
patchOperations=[patch_op], | |
) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment