DBT BQ Dataset Authorizer
"""
This script is intended for use to automatically permission the datasets in BigQuery.
Usage:
This is intended to be run as part of the deployment of datawarehouse
Args:
-t/--target : the target to use (ex. prod, staging, test) - tied to the gcp project you are permissioning
-d/--dbt_base_dir : the directory to be searched for schema.yml files defining the mapping between datasets
and the groups which need access to them.
-m/--manifest_file : the manifest file which specifies what DBT has generated, and therefore which views need
to be authorized
-p/--profile_file : The profile.yaml used by DBT to specify connection details to the target. This script will
use the same connection. This script expects authentication by service account keyfile.
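
Example (invocation sketch; the script filename here is assumed):
    python bq_dataset_authorizer.py --target staging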
"""
import argparse
import json
import logging
from collections import defaultdict
from contextlib import ContextDecorator
from functools import reduce
from os import path
from pathlib import Path

import jinja2
import yaml
from dbt.context.base import BaseContext
from google.cloud import bigquery
from google.oauth2 import service_account

DBT_DIR = Path(__file__).resolve().parent.parent
MANIFEST_FILE = path.join(DBT_DIR, Path("target/manifest.json"))
PROFILE_FILE = path.join(DBT_DIR, Path("profiles.yml"))


class BQDatasetAuthorizer(ContextDecorator):
    def __init__(self, target, profile_file):
        self.target = target
        self.profile_file = profile_file
        # this dict maps a dataset to a list of BigQuery AccessEntry objects defined in code.
        # We expect groupByEmail and view entries.
        # https://googleapis.dev/python/bigquery/latest/generated/google.cloud.bigquery.dataset.AccessEntry.html
        self.auth_dict = defaultdict(list)

    def __enter__(self):
        self._create_bq_client()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # the BigQuery client leaves sockets open, so close it explicitly
        self._close_bq_client()

    def get_auth_dict(self):
        return self.auth_dict

    # generic helper functions
    def _deep_dict_lookup(self, d, field):
        return reduce(
            lambda x, key: x.get(key, {}) if isinstance(x, dict) else {}, field.split("."), d
        )
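
    # A quick illustration (values hypothetical): given
    #     d = {"dw_v2": {"outputs": {"prod": {"project": "my-gcp-project"}}}}
    # _deep_dict_lookup(d, "dw_v2.outputs.prod.project") returns "my-gcp-project";
    # a missing key anywhere along the path yields {} instead of raising KeyError.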

    def _generate_access_entry_for_group(self, group_email):
        return bigquery.AccessEntry(
            role="READER", entity_type="groupByEmail", entity_id=group_email
        )

    def _generate_access_entry_for_view(self, view_dataset, view_tablename):
        dataset_ref = bigquery.dataset.DatasetReference.from_string(
            f"{self.project_id}.{view_dataset}"
        )
        view = bigquery.Table(dataset_ref.table(view_tablename))
        # BigQuery requires that view access entries have role=None
        return bigquery.AccessEntry(
            role=None, entity_type="view", entity_id=view.reference.to_api_repr()
        )
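
    # For reference, the two entry shapes produced above look roughly like this
    # (values hypothetical; TableReference.to_api_repr() yields a dict keyed by
    # projectId / datasetId / tableId):
    #     AccessEntry(role="READER", entity_type="groupByEmail",
    #                 entity_id="data-analysts@example.com")
    #     AccessEntry(role=None, entity_type="view",
    #                 entity_id={"projectId": "my-gcp-project",
    #                            "datasetId": "marts", "tableId": "orders_view"})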

    # parsing helper functions
    def _get_all_group_auth(self, dbt_base_dir):
        """Searches the dbt directory's schema.yml files for dataset access entries.

        Returns:
            List of dataset dicts with the accesses to be granted
        """
        group_auth = []
        for file in dbt_base_dir.glob("models/**/schema.yml"):
            with open(file) as f:
                d = yaml.safe_load(f)
                if d.get("datasets"):
                    group_auth.extend(d.get("datasets"))
        return group_auth
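
    # The schema.yml blocks this picks up look something like the following
    # (shape inferred from how add_google_groups_to_auth_dict consumes them;
    # names and emails hypothetical):
    #
    #     datasets:
    #       - name: marts
    #         access:
    #           - data-analysts@example.com
    #           - data-science@example.com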

    def _get_project(self, profiles, target):
        """Extract the project for the appropriate target in the dbt profiles.

        Warning: there is some hardcoding based on how the dbt profiles are set up.

        Args:
            profiles (dict): The dbt profiles dictionary
            target (str): The name of the dbt target

        Returns:
            The GCP project being targeted by that target
        """
        # hardcoded path based on how dbt profiles get set up
        project_key = f"dw_v2.outputs.{target}.project"
        return self._deep_dict_lookup(profiles, project_key)

    def _get_keyfile_path(self, profiles, target):
        """Extract the keyfile path for the appropriate target in the dbt profiles.

        Warning: there is some hardcoding based on how the dbt profiles are set up,
        and the keyfile value may itself be a jinja template.

        Args:
            profiles (dict): The dbt profiles dictionary
            target (str): The name of the dbt target

        Returns:
            The location of the keyfile for the specified target in profiles
        """
        # hardcoded path based on how dbt profiles get set up
        keyfile = f"dw_v2.outputs.{target}.keyfile"
        target_keyfile_str = self._deep_dict_lookup(profiles, keyfile)
        keyfile_template = jinja2.Environment(loader=jinja2.BaseLoader()).from_string(
            target_keyfile_str
        )
        # render the template with dbt's env_var macro so {{ env_var(...) }}
        # resolves exactly the way dbt resolves it
        return keyfile_template.render(env_var=BaseContext.env_var)
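
    # In profiles.yml the keyfile is typically a jinja template so the secret
    # path stays out of the repo; an illustrative snippet (env var name
    # hypothetical):
    #
    #     dw_v2:
    #       outputs:
    #         prod:
    #           project: my-gcp-project
    #           keyfile: "{{ env_var('BQ_KEYFILE_PATH') }}"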

    def _parse_manifest(self, manifest_file):
        if manifest_file is not None:
            with open(manifest_file, "rb") as f:
                manifest = json.load(f)
            node_dict = manifest["nodes"]
            source_dict = manifest["sources"]
            child_map_dict = manifest["child_map"]
        else:
            node_dict = {}
            source_dict = {}
            child_map_dict = {}
        return node_dict, source_dict, child_map_dict
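
    # In the dbt manifest, "nodes" and "sources" are keyed by unique_id, and
    # "child_map" maps each unique_id to its children's unique_ids, e.g.
    # (illustrative): {"model.dw.stg_orders": ["model.dw.orders_view"]}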

    def _parse_profiles(self):
        if self.profile_file is not None:
            with open(self.profile_file) as f:
                profiles = yaml.safe_load(f)
        else:
            profiles = {}
        return profiles

    # helpers for bq
    def _create_bq_client(self):
        profiles = self._parse_profiles()
        keyfile_path = self._get_keyfile_path(profiles, self.target)
        credentials = service_account.Credentials.from_service_account_file(
            keyfile_path, scopes=["https://www.googleapis.com/auth/cloud-platform"]
        )
        self.project_id = self._get_project(profiles, self.target)
        self.client = bigquery.Client(credentials=credentials, project=self.project_id)
        logging.info(f"Created BigQuery client for {self.project_id}")

    def _close_bq_client(self):
        self.client.close()

    def add_views_to_auth_dict(self, manifest_file):
        """Add BQ AccessEntries, for views that need to be authorized based on
        the dbt-generated manifest file, to the authorization dictionary.
        """
        node_dict, source_dict, child_map_dict = self._parse_manifest(manifest_file)

        def _get_field_from_node(node_name, field):
            node = node_dict.get(node_name)
            if node is not None:
                return self._deep_dict_lookup(node, field)
            source = source_dict.get(node_name)
            if source is not None:
                return self._deep_dict_lookup(source, field)
            return None

        def _node_is_view(node):
            return (
                _get_field_from_node(node, "resource_type") == "model"
                and _get_field_from_node(node, "config.materialized") == "view"
            )

        for parent_node in child_map_dict.keys():
            # tests are present in the child map and we don't want to include them
            if _get_field_from_node(parent_node, "resource_type") == "test":
                continue
            access_entries = []
            src_dataset = _get_field_from_node(parent_node, "schema")
            children_nodes = child_map_dict[parent_node]
            for child_node in children_nodes:
                dest_dataset = _get_field_from_node(child_node, "schema")
                # we only have to authorize views that live in a different dataset
                # TODO: handle ephemeral models
                if _node_is_view(child_node) and src_dataset != dest_dataset:
                    auth_view = self._generate_access_entry_for_view(
                        view_dataset=dest_dataset,
                        view_tablename=_get_field_from_node(child_node, "name"),
                    )
                    access_entries.append(auth_view)
            if len(access_entries) > 0:
                self.auth_dict[src_dataset].extend(access_entries)
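
    # After this runs, auth_dict looks something like (names hypothetical):
    #     {"staging": [AccessEntry(role=None, entity_type="view", entity_id={...})]}
    # i.e. each source dataset authorizes the downstream views that select from it.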

    def add_google_groups_to_auth_dict(self, dbt_dir):
        """Add BQ AccessEntries for google groups to the authorization dictionary."""
        group_auth = self._get_all_group_auth(dbt_dir)
        for dataset in group_auth:
            access_emails = dataset["access"]
            access_entries = [
                self._generate_access_entry_for_group(email) for email in access_emails
            ]
            self.auth_dict[dataset["name"]].extend(access_entries)

    def authorize_in_bq(self):
        """Update each BigQuery dataset that has new access entries in the
        authorization dict with those entries (i.e. permission the newly added
        groups and authorize the new views).

        1. Gets the current access entries for the dataset
        2. Identifies new access entries in the authorization dictionary
        3. Updates BigQuery to enable these accesses
        """
        for dataset, access_list in self.auth_dict.items():
            try:
                cur_dataset = self.client.get_dataset(dataset)
            except Exception as e:
                logging.error(f"could not fetch dataset {dataset}, exception: {e}")
                continue
            access_entries = cur_dataset.access_entries
            # BigQuery errors if the access entries contain duplicates
            new_access_entries = [
                access
                for x, access in enumerate(access_list)
                if access not in access_entries and access not in access_list[:x]
            ]
            if len(new_access_entries) == 0:
                continue
            logging.info(
                f"adding {len(new_access_entries)} new access entries for dataset: {dataset}"
            )
            access_entries.extend(new_access_entries)
            cur_dataset.access_entries = access_entries
            try:
                self.client.update_dataset(cur_dataset, ["access_entries"])
            except Exception as e:
                logging.error(f"issue updating dataset {dataset}, exception: {e}")
                continue


def parse_arguments():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(
        description="Appropriately permission all dbt outputs in BigQuery"
    )
    parser.add_argument("-t", "--target", type=str, required=True)
    parser.add_argument("-d", "--dbt_base_dir", type=Path, default=DBT_DIR, required=False)
    parser.add_argument("-m", "--manifest_file", type=Path, default=MANIFEST_FILE, required=False)
    parser.add_argument("-p", "--profile_file", type=Path, default=PROFILE_FILE, required=False)
    return parser.parse_args()


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    args = parse_arguments()
    with BQDatasetAuthorizer(target=args.target, profile_file=args.profile_file) as bq_auth:
        bq_auth.add_google_groups_to_auth_dict(args.dbt_base_dir)
        bq_auth.add_views_to_auth_dict(args.manifest_file)
        bq_auth.authorize_in_bq()