Skip to content

Instantly share code, notes, and snippets.

@mjsqu
Last active September 28, 2023 02:42
Show Gist options
  • Save mjsqu/6c2ac95cb62fa6bac52701a830e9f069 to your computer and use it in GitHub Desktop.
Save mjsqu/6c2ac95cb62fa6bac52701a830e9f069 to your computer and use it in GitHub Desktop.
An Airflow secrets backend for AWS Systems Manager Parameter Store that builds each connection from multiple parameters (untested)
from __future__ import annotations

import json
import re

from airflow.providers.amazon.aws.secrets.systems_manager import (
    SystemsManagerParameterStoreBackend,
)
class SSMMultiParameterBackend(SystemsManagerParameterStoreBackend):
    """
    Parameter Store backend that assembles one connection from many parameters.

    Instead of storing a whole connection in a single parameter, each piece of the
    connection is its own parameter directly under ``<connections_prefix>/<conn_id>``,
    for example::

        /airflow/connections/my_db/conn_type
        /airflow/connections/my_db/host
        /airflow/connections/my_db/password
        /airflow/connections/my_db/some_extra_key

    Parameters whose last path segment is listed in ``CONNECTION_FIELDS`` become
    top-level connection keys. Every other parameter under the path goes into the
    ``extra`` dict. Lookups for any other prefix (variables, config) are delegated
    unchanged to the parent class.
    """

    # Last path segments that map to top-level connection keys rather than to ``extra``.
    CONNECTION_FIELDS = (
        "conn_type",
        "description",
        "host",
        "login",
        "password",
        "schema",
        "port",
        "uri",
    )

    # GetParametersByPath returns at most 10 results per call; the old value of 1
    # cost one round-trip per parameter.
    _PAGE_SIZE = 10

    def _get_secret(
        self, path_prefix: str, secret_id: str, lookup_pattern: str | None
    ) -> str | None:
        """
        Get a secret value from Parameter Store.

        For connections, read every parameter directly under ``<path_prefix>/<secret_id>``
        and serialize them as JSON. Names in ``CONNECTION_FIELDS`` become top-level keys;
        all other parameters are collected into ``extra``. Any other ``path_prefix`` is
        handled by the parent class's single-parameter lookup.

        NOTE(review): storing ``uri`` alongside individual fields is presumably rejected
        by Airflow's ``Connection`` — confirm before relying on both.

        :param path_prefix: Prefix for the Path to get Secret (from connections_prefix)
        :param secret_id: Secret Key (from conn_id)
        :param lookup_pattern: If provided, `secret_id` must match this pattern to look up
            the secret in Systems Manager
        :return: The connection as a JSON string. Returns ``None`` when ``secret_id``
            does not match ``lookup_pattern`` or no parameters exist under the path,
            so the caller can fall back to another backend.
        """
        if path_prefix != self.connections_prefix:
            # Variables/config are single parameters; assembling them from children
            # would return connection-shaped JSON instead of the stored value.
            return super()._get_secret(path_prefix, secret_id, lookup_pattern)

        if lookup_pattern and not re.match(lookup_pattern, secret_id, re.IGNORECASE):
            return None

        ssm_path = self._ensure_leading_slash(self.build_path(path_prefix, secret_id))
        fields, extra = self._collect_parameters(ssm_path)
        if not fields and not extra:
            self.log.debug("No parameters found under %s.", ssm_path)
            return None

        # Log names only: values include decrypted SecureStrings such as the password.
        self.log.debug(
            "Built connection from %s with fields %s and extra keys %s.",
            ssm_path,
            sorted(fields),
            sorted(extra),
        )
        return json.dumps({**fields, "extra": extra})

    def _collect_parameters(
        self, ssm_path: str
    ) -> tuple[dict[str, str], dict[str, str]]:
        """
        List the parameters directly under ``ssm_path`` and split them by name.

        A single paginated listing replaces one ``get_parameter`` call per field. The
        listing already returns those children, so fetching them separately also made
        them appear twice (once as a field and once inside ``extra``).

        :param ssm_path: Parameter Store path (with leading slash) of one connection.
        :return: ``(fields, extra)``. ``fields`` holds values whose last path segment is
            in ``CONNECTION_FIELDS``; ``extra`` holds every other value. Both are keyed
            by the last path segment.
        """
        fields: dict[str, str] = {}
        extra: dict[str, str] = {}
        paginator = self.client.get_paginator("get_parameters_by_path")
        pages = paginator.paginate(
            Path=ssm_path,
            WithDecryption=True,
            PaginationConfig={"PageSize": self._PAGE_SIZE},
        )
        for page in pages:
            for parameter in page.get("Parameters", []):
                key = parameter["Name"].rpartition("/")[2]
                target = fields if key in self.CONNECTION_FIELDS else extra
                target[key] = parameter["Value"]
        return fields, extra
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment