Skip to content

Instantly share code, notes, and snippets.

@dpmccabe
Last active April 24, 2024 23:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save dpmccabe/1e935626e7efd759d55bd2f44ba07439 to your computer and use it in GitHub Desktop.
Save dpmccabe/1e935626e7efd759d55bd2f44ba07439 to your computer and use it in GitHub Desktop.
Post-deploy Zappa callback that updates the Lambda function's provisioned concurrency
import re
from time import sleep
from typing import Any, Dict
import boto3
from zappa.cli import ZappaCLI
# Desired ProvisionedConcurrentExecutions for the "current" alias.
PROVISIONED_CONCURRENCY: int = 5

# Lambda API client shared by every helper in this module.
# NOTE(review): created with boto3's default region/credentials rather than the
# Zappa stage's settings — confirm they match the deployed function's region.
client = boto3.client(service_name="lambda")
def post_callbacks(zappa_cli: ZappaCLI) -> None:
    """
    Route traffic through a provisioned-concurrency alias after a Zappa deploy.

    Runs three steps in order:

    1. create the "current" alias, or re-point it, at the newest published
       version;
    2. make the alias's provisioned concurrency match
       ``PROVISIONED_CONCURRENCY``;
    3. re-point the API Gateway integrations at the alias.

    Parameters
    ----------
    zappa_cli : ZappaCLI
        The Zappa CLI instance; only its ``lambda_name`` attribute is read.

    Returns
    -------
    None
    """
    name = zappa_cli.lambda_name
    check_concurrency(name, check_alias(name))
    set_api_gateway_alias(name)
def check_alias(lambda_name: str) -> str:
    """
    Ensure that an alias "current" exists and points to the newest published version.

    Parameters
    ----------
    lambda_name : str
        Name of the Lambda function.

    Returns
    -------
    str
        The ARN of the "current" alias.

    Raises
    ------
    ValueError
        If the function has no published versions (only ``$LATEST``).
    Exception
        If the function has more than one alias, or its only alias is not
        named "current".
    """
    # ListVersionsByFunction is paginated (at most 50 versions per call), so a
    # single call can miss the newest version once the function has many of
    # them. Walk every page before picking the maximum.
    version_pages = client.get_paginator("list_versions_by_function").paginate(
        FunctionName=lambda_name
    )
    versions = [
        version
        for page in version_pages
        for version in page["Versions"]
        if version["Version"] != "$LATEST"
    ]
    if not versions:
        raise ValueError(f"Function {lambda_name} has no published versions")
    # Version numbers are strings; compare numerically so "10" beats "9".
    latest_version = max(versions, key=lambda x: int(x["Version"]))

    alias_pages = client.get_paginator("list_aliases").paginate(
        FunctionName=lambda_name
    )
    aliases = [alias for page in alias_pages for alias in page["Aliases"]]

    if len(aliases) > 1:
        # this shouldn't happen, but would require fixing on AWS console
        raise Exception("Multiple aliases defined")
    if not aliases:
        return create_alias(lambda_name, latest_version)
    if aliases[0]["Name"] != "current":
        # this shouldn't happen, but would require fixing on AWS console
        raise Exception('Existing alias is not named "current"')
    # alias already exists, so make sure it points to the latest version
    return update_alias(lambda_name, latest_version)
def create_alias(lambda_name: str, latest_version: Dict[str, Any]) -> str:
    """
    Create an alias "current" pointing to the given function version.

    Parameters
    ----------
    lambda_name : str
        Name of the Lambda function.
    latest_version : Dict[str, Any]
        A version record as listed by ``list_versions_by_function``; only its
        ``"Version"`` key is read.

    Returns
    -------
    str
        The ARN of the new alias (``AliasArn`` from the response), not its name.
    """
    resp = client.create_alias(
        FunctionName=lambda_name,
        Name="current",
        FunctionVersion=latest_version["Version"],
    )
    print('Created function alias "current" for version', resp["FunctionVersion"])
    return resp["AliasArn"]
def update_alias(lambda_name: str, latest_version: Dict[str, Any]) -> str:
    """
    Re-point the existing "current" alias at the given function version.

    Parameters
    ----------
    lambda_name : str
        Name of the Lambda function.
    latest_version : Dict[str, Any]
        A version record as listed by ``list_versions_by_function``; only its
        ``"Version"`` key is read.

    Returns
    -------
    str
        The ARN of the alias (``AliasArn`` from the response), not its name.
    """
    resp = client.update_alias(
        FunctionName=lambda_name,
        Name="current",
        FunctionVersion=latest_version["Version"],
        # presumably clears any weighted routing so all alias traffic goes to
        # FunctionVersion — confirm against the UpdateAlias API reference
        RoutingConfig={},
    )
    print('Updated function alias "current" for version', resp["FunctionVersion"])
    return resp["AliasArn"]
def check_concurrency(lambda_name: str, alias_arn: str) -> None:
    """
    Ensure the provisioned concurrency config requests ``PROVISIONED_CONCURRENCY``.

    Behaviour depends on the existing config:

    - If none exists, one is created on the "current" alias.
    - If one exists with a different requested count, the loop waits while its
      status is ``IN_PROGRESS``, polling every 10 seconds.
    - Once the status is ``READY``, the config is updated.

    Parameters
    ----------
    lambda_name : str
        Name of the Lambda function.
    alias_arn : str
        ARN of the "current" alias; an existing config must belong to it.

    Returns
    -------
    None

    Raises
    ------
    Exception
        If more than one config exists, or the config's status is neither
        ``IN_PROGRESS`` nor ``READY`` while a change is needed.
    """
    # NOTE(review): there is no upper bound on how long this loop waits for an
    # IN_PROGRESS config.
    while True:
        # get existing concurrency configs
        resp = client.list_provisioned_concurrency_configs(FunctionName=lambda_name)
        configs = resp["ProvisionedConcurrencyConfigs"]
        if len(configs) > 1:
            # this shouldn't happen, but would require fixing on AWS console
            raise Exception("Multiple existing provisioned concurrency configs")
        if not configs:
            # need to create a config
            create_concurrency(lambda_name)
            break

        # get singular existing concurrency config
        prov = configs[0]
        if prov["RequestedProvisionedConcurrentExecutions"] == PROVISIONED_CONCURRENCY:
            # already have the right number of instances, although they might not
            # all be ready yet
            break

        # need to change number of instances, so wait until it's ready
        if prov["Status"] == "IN_PROGRESS":
            print("Waiting for concurrency config to be ready...")
            sleep(10)
        elif prov["Status"] == "READY":
            # Pass by keyword so the function name and alias ARN cannot be
            # transposed relative to update_concurrency's signature.
            update_concurrency(lambda_name=lambda_name, alias_arn=alias_arn, prov=prov)
            break
        else:
            raise Exception(f"Invalid concurrency config status: {prov['Status']}")
def create_concurrency(lambda_name: str) -> None:
    """
    Put a provisioned concurrency config on the function's "current" alias.

    The requested count is ``PROVISIONED_CONCURRENCY``. The count that AWS
    reports back as requested is printed.

    Parameters
    ----------
    lambda_name : str
        Name of the Lambda function.

    Returns
    -------
    None
    """
    requested = client.put_provisioned_concurrency_config(
        FunctionName=lambda_name,
        Qualifier="current",
        ProvisionedConcurrentExecutions=PROVISIONED_CONCURRENCY,
    )["RequestedProvisionedConcurrentExecutions"]
    print("Set provisioned concurrency to", requested)
def update_concurrency(lambda_name: str, alias_arn: str, prov: Dict[str, Any]) -> None:
    """
    Set an existing provisioned concurrency config to ``PROVISIONED_CONCURRENCY``.

    The existing config must belong to the "current" alias: its ``FunctionArn``
    has to equal ``alias_arn``. Otherwise nothing is changed and an exception is
    raised.

    Parameters
    ----------
    lambda_name : str
        Name of the Lambda function.
    alias_arn : str
        ARN of the "current" alias.
    prov : Dict[str, Any]
        The existing config record; only its ``"FunctionArn"`` key is read.

    Returns
    -------
    None

    Raises
    ------
    Exception
        If the config's ``FunctionArn`` does not match ``alias_arn``; this would
        need manual fixing on the AWS console.
    """
    owner_arn = prov["FunctionArn"]
    if owner_arn == alias_arn:
        settings = dict(
            FunctionName=lambda_name,
            Qualifier="current",
            ProvisionedConcurrentExecutions=PROVISIONED_CONCURRENCY,
        )
        result = client.put_provisioned_concurrency_config(**settings)
        requested = result["RequestedProvisionedConcurrentExecutions"]
        print("Set provisioned concurrency to", requested)
        return
    raise Exception("Existing concurrency config does not match alias ARN")
def set_api_gateway_alias(lambda_name: str) -> None:
    """
    Point every Lambda-backed API Gateway integration at the "current" alias.

    The REST API is found as the resource with logical ID ``Api`` in the
    CloudFormation stack named ``lambda_name``. Each integration URI that
    targets a Lambda function ARN is rewritten to carry the ``:current``
    qualifier, replacing any other qualifier already present.

    Integrations are left untouched when they:

    - have no URI (e.g. mocks);
    - do not target a Lambda function ARN;
    - already point at "current".

    Parameters
    ----------
    lambda_name : str
        Name of the Lambda function, also used as the CloudFormation stack name.

    Returns
    -------
    None
    """
    # get the API Gateway's ID
    cf_client = boto3.client("cloudformation")
    resp = cf_client.describe_stack_resource(
        StackName=lambda_name, LogicalResourceId="Api"
    )
    api_id = resp["StackResourceDetail"]["PhysicalResourceId"]

    # Matches ".../function:<name>[:<qualifier>]/invocations" at the end of a
    # Lambda integration URI. Group 1 keeps the unqualified function ARN tail and
    # group 2 the "/invocations" suffix, so any existing qualifier is replaced
    # rather than appended to.
    function_uri = re.compile(r"(:function:[^:/]+)(?::[^:/]+)?(/invocations)$")

    # get the API resources (these are just stubs, since the function routes
    # endpoints). GetResources is paginated (25 items per page by default), so
    # walk every page.
    apig_client = boto3.client("apigateway")
    pages = apig_client.get_paginator("get_resources").paginate(
        restApiId=api_id, embed=["methods"]
    )
    for page in pages:
        for item in page["items"]:
            resource_id = item["id"]
            # Resources without any methods omit "resourceMethods" entirely.
            # `method` is the HTTP method, e.g. ANY, OPTIONS, etc.
            for method, behavior in item.get("resourceMethods", {}).items():
                uri = behavior.get("methodIntegration", {}).get("uri")
                if uri is None:
                    # e.g. a mock integration: no traffic is sent anywhere
                    continue
                new_uri, n_matches = function_uri.subn(r"\1:current\2", uri)
                if n_matches == 0 or new_uri == uri:
                    # not a Lambda function URI, or already using "current"
                    continue
                # NOTE(review): this edits the API's configuration only; confirm
                # that a later API deployment publishes the change to the stage.
                patch_op = {"op": "replace", "path": "/uri", "value": new_uri}
                apig_client.update_integration(
                    restApiId=api_id,
                    resourceId=resource_id,
                    httpMethod=method,
                    patchOperations=[patch_op],
                )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment