Created December 29, 2019 16:20
Save stucchio/8a0c6c57cea7452eed8e7001877ae2fd to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time
from typing import Any, Optional

import config
from azure.core.exceptions import ResourceNotFoundError
from azure.identity import ClientSecretCredential
from azure.keyvault.secrets import SecretClient
class _MissingSecretError(KeyError, AttributeError):
    """Raised when attribute-style access cannot resolve a key.

    Derives from ``KeyError`` so existing ``except KeyError`` handlers around
    attribute access keep working. Also derives from ``AttributeError``, which
    is what ``getattr(obj, name, default)``, ``hasattr`` and the
    ``copy``/``pickle`` machinery expect ``__getattr__`` to raise.
    """


class AzureKeyVaultConfiguration(config.Configuration):
    """
    Configuration backed by an Azure Key Vault.

    The class takes Azure Key Vault service-principal credentials and behaves
    like a drop-in replacement for the regular configuration. Values are read
    with ``cfg["name"]``, ``cfg.name`` or ``cfg.get("name")``.

    Lookup details:
    - Each value is fetched from the vault on demand.
    - Underscores in keys are replaced by hyphens before the lookup.
    - A fetched value is served from a local cache for ``cache_expiration``
      seconds.
    """

    def __init__(self, az_client_id: str, az_client_secret: str,
                 az_tenant_id: str, az_vault_name: str,
                 cache_expiration: float = 5 * 60,
                 lowercase_keys: bool = False):
        """
        Constructor.

        :param az_client_id: client id used for ``ClientSecretCredential``.
        :param az_client_secret: client secret used for
            ``ClientSecretCredential``.
        :param az_tenant_id: tenant id used for ``ClientSecretCredential``.
        :param az_vault_name: vault name. The vault URL is
            ``https://<az_vault_name>.vault.azure.net/``.
        :param cache_expiration: number of seconds a fetched value is served
            from the local cache before the vault is queried again.
        :param lowercase_keys: accepted for backward compatibility but
            currently ignored. Keys are NOT lower-cased.
        """
        # NOTE(review): config.Configuration.__init__ is not called. Inherited
        # methods that depend on its state (e.g. as_dict, used by __eq__) may
        # not behave as expected. Confirm against the base class.
        credentials = ClientSecretCredential(
            client_id=az_client_id,
            client_secret=az_client_secret,
            tenant_id=az_tenant_id,
        )
        vault_url = "https://{az_vault_name}.vault.azure.net/".format(
            az_vault_name=az_vault_name)
        self.kvclient = SecretClient(vault_url=vault_url, credential=credentials)
        self.cache_expiration = cache_expiration
        # normalized key -> (value, time.monotonic() timestamp of the fetch)
        self.__cache = {}

    def __eq__(self, other: object) -> bool:  # type: ignore
        """Equality operator: compare the ``as_dict()`` views of both objects.

        Returns ``NotImplemented`` when ``other`` has no callable ``as_dict``.
        Python then falls back to its default comparison instead of raising.
        """
        other_as_dict = getattr(other, "as_dict", None)
        if not callable(other_as_dict):
            return NotImplemented
        return self.as_dict() == other_as_dict()

    def __get_secret(self, key: str) -> Optional[str]:
        """Fetch ``key`` from the vault, serving recent values from the cache.

        :param key: configuration key. Underscores are replaced by hyphens.
        :return: the secret's value, or ``None`` if the vault reports
            ``ResourceNotFoundError`` for it.
        """
        key = key.replace('_', '-')  # Normalize for Azure Keyvault
        # Monotonic clock: wall-clock adjustments must not extend or cut
        # short the cache lifetime.
        now = time.monotonic()
        cached = self.__cache.get(key)
        if cached is not None:
            value, fetched_at = cached
            if fetched_at + self.cache_expiration > now:
                return value
        try:
            secret = self.kvclient.get_secret(key)
        except ResourceNotFoundError:
            # The secret no longer exists: drop any stale cached copy.
            self.__cache.pop(key, None)
            return None
        self.__cache[key] = (secret.value, now)
        return secret.value

    def __getitem__(self, item: str) -> str:
        """Return the value for ``item``.

        :raises KeyError: if the vault has no such secret.
        """
        secret = self.__get_secret(item)
        if secret is None:
            raise KeyError(item)
        return secret

    def __getattr__(self, item: str) -> str:
        """Return the value for ``item`` via attribute access (``cfg.item``).

        Python only calls this for names that normal attribute lookup cannot
        find. Names starting with an underscore are never sent to the vault:
        - Dunder probes (``__deepcopy__``, ``__setstate__``, ...) made by
          ``copy``/``pickle``/``hasattr`` are rejected rather than looked up.
        - On an instance whose ``__init__`` has not run, the missing private
          cache attribute would otherwise route back here and recurse forever.

        Use ``cfg[...]`` or :meth:`get` for underscore-prefixed keys.

        :raises AttributeError: (an instance that is also a ``KeyError``) if
            the name is underscore-prefixed or the secret does not exist.
        """
        if item.startswith('_'):
            raise _MissingSecretError(item)
        secret = self.__get_secret(item)
        if secret is None:
            raise _MissingSecretError(item)
        return secret

    def get(self, key: str, default: Any = None) -> Any:
        """
        Get the configuration value corresponding to :attr:`key`.

        :param key: key to retrieve
        :param default: value returned when the key is missing from the vault
        :return: the value found or :attr:`default`
        """
        secret = self.__get_secret(key)
        return default if secret is None else secret
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.