@nave91
Last active April 23, 2022 07:48
Airflow plugin to create a dbt operator
# MIT License
# Copyright (c) 2019 Bellhops Inc.
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import os

import networkx
from airflow.operators.bash_operator import BashOperator
from airflow.operators.subdag_operator import SubDagOperator
from airflow import utils as airflow_utils


class DbtModelOperator(SubDagOperator):
    """Sub-DAG that runs `dbt run` and `dbt test` for every model node in the dbt graph,
    preserving model-to-model dependencies."""

    @airflow_utils.apply_defaults
    def __init__(self, dag, task_id, start_date, schedule_interval, default_args, dbt_config,
                 dbt_directory, dbt_gpickle_file_path, **kwargs):
        self.start_date = start_date
        self.dag_schedule_interval = schedule_interval
        self.default_args = default_args
        self.dbt_tasks = {}
        self.dbt_config = dbt_config
        self.public_input_schema_name = self.dbt_config['public_input_schema_name']
        self.public_output_schema_name = self.dbt_config['public_output_schema_name']
        self.public_conn_id = self.dbt_config['public_conn_id']
        self.directory = dbt_directory
        self.gpickle_file_path = dbt_gpickle_file_path
        self.profile_name = self.dbt_config['profile']
        self.target_name = self.dbt_config['target']

        from airflow import DAG
        self.sub_dag_name = dag.dag_id + '.' + task_id
        self.subdag = DAG(
            self.sub_dag_name,
            start_date=self.start_date,
            schedule_interval=self.dag_schedule_interval,
            default_args=self.default_args
        )
        self.init_tasks()

        super(DbtModelOperator, self).__init__(
            dag=dag,
            subdag=self.subdag,
            task_id=task_id,
            trigger_rule='all_done'
        )
    @property
    def task_type(self):
        return 'SubDagOperator'

    def task_dbt_run_and_test(self, model_name):
        # Build a BashOperator that runs and then tests a single dbt model.
        dbt_file_name = model_name.split('.')[-1]
        param_profiles_dir = '--profiles-dir=.'
        param_profile = '--profile={profile}'.format(profile=self.profile_name)
        param_target = '--target={target}'.format(target=self.target_name)
        param_models = '--models {model}'.format(model=dbt_file_name)
        task = BashOperator(
            task_id=dbt_file_name,
            bash_command='''
                cd {{ params.directory }};
                . ../.venv/bin/activate
                dbt run {{ params.param_profiles_dir }} {{ params.param_profile }} {{ params.param_target }} {{ params.param_models }} && \
                dbt test {{ params.param_profiles_dir }} {{ params.param_profile }} {{ params.param_target }} {{ params.param_models }}
            ''',
            dag=self.subdag,
            params={
                "directory": self.directory,
                "param_profiles_dir": param_profiles_dir,
                "param_profile": param_profile,
                "param_target": param_target,
                "param_models": param_models
            }
        )
        # Attach the compiled SQL from target/ as task documentation, if it exists.
        file_name = model_name.split('.')[-1]
        target_file_path = self.get_dbt_file_path(file_name=file_name, prefix='target')
        if target_file_path:
            sql = open(target_file_path).read()
            task.doc = "-- Target File: " + target_file_path + "\n"
            task.doc += sql
        return task
    def task_dbt_models_from_graph(self):
        dbt_graph = networkx.read_gpickle(self.gpickle_file_path)
        for node_name in set(dbt_graph.nodes()):
            if node_name.split('.')[0] == 'model':
                dbt_task = self.task_dbt_run_and_test(node_name)
                self.dbt_tasks[node_name] = dbt_task
        for edge in dbt_graph.edges():
            source_prefix = edge[0].split('.')[0]
            sink_prefix = edge[1].split('.')[0]
            if source_prefix == 'model' and sink_prefix == 'model':
                self.dbt_tasks[edge[0]].set_downstream(self.dbt_tasks[edge[1]])

    def get_dbt_file_path(self, file_name, prefix=''):
        for root, dirs, files in os.walk(self.directory + "/" + prefix):
            for f in files:
                if f == (file_name + ".sql"):
                    target_file = os.path.join(root, f)
                    return target_file
        return None

    def init_tasks(self):
        if os.path.isfile(self.gpickle_file_path):
            self.task_dbt_models_from_graph()


class DbtOperator(SubDagOperator):
    """Sub-DAG that writes profiles.yml, runs dbt clean/deps/seed/compile, refreshes the
    cached graph.gpickle, and runs all models via a nested DbtModelOperator."""

    @airflow_utils.apply_defaults
    def __init__(self, dag, task_id, start_date, schedule_interval, default_args, dbt_project_name, dbt_config, **kwargs):
        self.start_date = start_date
        self.dag_schedule_interval = schedule_interval
        self.default_args = default_args
        self.dbt_config = dbt_config
        self.dbt_project_name = dbt_project_name
        self.status_config = self.dbt_config['status_config']
        self.dbt_profiles = self.dbt_config['dbt_profiles']
        self.directory = self.dbt_config['local_directory'] + '/' + self.dbt_project_name
        self.gpickle_file_path = self.directory + '/graph.gpickle'
        self.profile_name = self.dbt_config['profile']
        self.target_name = self.dbt_config['target']

        from airflow import DAG
        self.sub_dag_name = dag.dag_id + '.' + task_id
        self.subdag = DAG(
            self.sub_dag_name,
            start_date=self.start_date,
            schedule_interval=self.dag_schedule_interval,
            default_args=self.default_args
        )
        self.init_tasks()

        super(DbtOperator, self).__init__(
            dag=dag,
            subdag=self.subdag,
            task_id=task_id,
            trigger_rule='all_done'
        )

    @property
    def task_type(self):
        return 'SubDagOperator'
    def task_configure_profile(self):
        task = BashOperator(
            task_id='configure_profile',
            bash_command='''
                echo "{{ params.profiles }}" > {{ params.directory }}/profiles.yml
            ''',
            dag=self.subdag,
            params={
                "directory": self.directory,
                "profiles": self.dbt_profiles
            }
        )
        return task

    def task_dbt_command(self, dbt_command):
        param_dbt_command = dbt_command
        param_profiles_dir = '--profiles-dir=.'
        param_profile = '--profile={profile}'.format(profile=self.profile_name)
        param_target = '--target={target}'.format(target=self.target_name)
        task = BashOperator(
            task_id=dbt_command,
            bash_command='''
                cd {{ params.directory }};
                . ../.venv/bin/activate
                dbt {{ params.param_dbt_command }} {{ params.param_profiles_dir }} {{ params.param_profile }} {{ params.param_target }}
            ''',
            dag=self.subdag,
            params={
                "directory": self.directory,
                "param_profiles_dir": param_profiles_dir,
                "param_profile": param_profile,
                "param_target": param_target,
                "param_dbt_command": param_dbt_command
            }
        )
        return task
    def task_init_dbt_commands(self, upstream_dependencies):
        # Chain clean -> deps -> seed -> compile and hang the chain off the upstream task(s).
        dbt_commands = ['clean', 'deps', 'seed', 'compile']
        first_dbt_command = dbt_commands[0]
        last_dbt_command = dbt_commands[-1]
        dbt_tasks = {}
        last_task = None
        for dbt_command in dbt_commands:
            task = self.task_dbt_command(dbt_command)
            dbt_tasks[dbt_command] = task
            if last_task:
                task.set_upstream(last_task)
            last_task = task
        dbt_tasks[first_dbt_command].set_upstream(upstream_dependencies)
        return dbt_tasks[last_dbt_command]

    def task_copy_graph(self):
        # Changes to the dbt project change the graph. To display a graph before the
        # first run, a graph file must already exist; project_name/graph.gpickle serves
        # that initialization purpose. This task refreshes the file whenever dbt models
        # change in later runs.
        task = BashOperator(
            task_id='copy_graph_out_of_target',
            bash_command='''
                cd {{ params.directory }};
                cp target/graph.gpickle .
            ''',
            dag=self.subdag,
            params={
                "directory": self.directory,
            }
        )
        return task
    def init_tasks(self):
        # Task order: configure_profile -> clean/deps/seed/compile -> copy_graph -> model sub-DAG.
        configure_profile = self.task_configure_profile()
        init_dbt_commands = self.task_init_dbt_commands(configure_profile)
        copy_graph = self.task_copy_graph()
        copy_graph.set_upstream(init_dbt_commands)
        postgres_dbt_models_task = DbtModelOperator(
            dag=self.subdag,
            task_id='dbt_postgres_models',
            start_date=self.start_date,
            schedule_interval=self.dag_schedule_interval,
            default_args=self.default_args,
            dbt_config=self.dbt_config,
            dbt_directory=self.directory,
            dbt_gpickle_file_path=self.gpickle_file_path
        )
        postgres_dbt_models_task.set_upstream(copy_graph)

    def get_export_table_names(self):
        # List the table names of models under models/export (by .sql file name).
        exports_directory = self.directory + "/models/export"
        table_names = []
        if os.path.exists(exports_directory):
            for root, dirs, files in os.walk(exports_directory):
                for f in files:
                    if f.endswith(".sql"):
                        table_names.append(f.split('.')[0])
        return table_names
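

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original plugin): one way DbtOperator could be
# wired into a parent DAG. The dbt_config keys mirror the ones read by the
# constructors above; every value (profile name, target, directory, connection
# id, schedule) is a hypothetical placeholder, and in practice this block
# would live in its own file under dags/ rather than behind __main__.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from datetime import datetime
    from airflow import DAG

    default_args = {'owner': 'airflow', 'retries': 0}

    dbt_config = {
        'profile': 'analytics',                    # dbt profile name (placeholder)
        'target': 'prod',                          # dbt target name (placeholder)
        'dbt_profiles': '<contents of profiles.yml as a string>',
        'local_directory': '/usr/local/airflow/dbt',
        'status_config': {},
        'public_input_schema_name': 'public',
        'public_output_schema_name': 'analytics',
        'public_conn_id': 'warehouse_postgres',
    }

    dag = DAG(
        'dbt_example_dag',
        start_date=datetime(2019, 1, 1),
        schedule_interval='@daily',
        default_args=default_args,
    )

    # DbtOperator builds the whole dbt sub-DAG: profile setup, clean/deps/seed/
    # compile, graph refresh, then per-model run+test tasks.
    dbt = DbtOperator(
        dag=dag,
        task_id='dbt',
        start_date=dag.start_date,
        schedule_interval=dag.schedule_interval,
        default_args=default_args,
        dbt_project_name='analytics',
        dbt_config=dbt_config,
    )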