-
-
Save jayhale/c5f08dcd1656db1b82e3177425911091 to your computer and use it in GitHub Desktop.
Example SQLAlchemy resource and IO manager for Dagster
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from base64 import b64decode | |
from contextlib import contextmanager | |
from cryptography.hazmat.backends import default_backend | |
from cryptography.hazmat.primitives import serialization | |
from dagster import Field, StringSource, resource | |
from snowflake.sqlalchemy import URL as SnowflakeURL | |
from sqlalchemy import create_engine | |
from sqlalchemy.engine import URL | |
# Connection parameters shared by every SQLAlchemy-backed resource. All are
# optional so callers may supply either a complete "url" or its components.
generic_config_schema = {
    key: Field(StringSource, is_required=False)
    for key in ("username", "password", "host", "port", "database", "url", "schema")
}
class SqlAlchemyConnection:
    """Thin wrapper around a SQLAlchemy engine built from Dagster resource config.

    The connection URL is taken verbatim from ``config["url"]`` when present;
    otherwise it is assembled from the individual components listed in
    ``valid_url_args``. Subclasses customize URL construction and driver-level
    connect arguments by overriding the ``_get_url_args`` / ``_build_url`` /
    ``_build_connect_args`` hooks.
    """

    # Config keys that may be forwarded to sqlalchemy.engine.URL.create().
    valid_url_args = ("username", "password", "host", "port", "database")

    def __init__(self, config):
        # Prefer an explicit URL; otherwise build one from the parts provided.
        if "url" in config:
            self.url = config.get("url")
        else:
            self.url = self._build_url(self._get_url_args(config))
        # Optional default schema, applied per-connection in get_connection().
        self.schema = config.get("schema", None)
        # Driver-specific connect args (subclass hook; empty by default).
        self.connect_args = self._build_connect_args(config)
        self.engine = create_engine(self.url, connect_args=self.connect_args)

    def _get_url_args(self, config):
        """Return the subset of ``config`` usable as URL keyword arguments."""
        return {k: config.get(k) for k in self.valid_url_args if k in config}

    def _build_url(self, url_args):
        """Build a sqlalchemy URL from keyword arguments; subclasses may override."""
        return URL.create(**url_args)

    def _build_connect_args(self, config):
        """Hook for driver-specific connect args; the base class needs none."""
        return {}

    @contextmanager
    def get_connection(self):
        """Yield a live connection, guaranteeing it is closed even on error."""
        connection = self.engine.connect()
        try:
            if self.schema:
                # Route schema-unqualified table names to the configured schema.
                connection.execution_options(schema_translate_map={None: self.schema})
            yield connection
        finally:
            # FIX: previously the connection leaked whenever the caller raised
            # inside the `with` block; close() now always runs.
            connection.close()

    def has_table(self, table_name):
        """Return True if ``table_name`` exists in the connected database."""
        # FIX: Engine.has_table() is deprecated since SQLAlchemy 1.4 (removed
        # in 2.0); go through the inspector instead. This module already
        # requires 1.4+ (URL.create); import locally to keep the edit
        # self-contained.
        from sqlalchemy import inspect

        return inspect(self.engine).has_table(table_name)
class RedshiftConnection(SqlAlchemyConnection):
    """SqlAlchemyConnection preconfigured for the redshift_connector driver."""

    def _get_url_args(self, config):
        # Take the generic URL components and pin the Redshift driver name.
        return {
            **super()._get_url_args(config),
            "drivername": "redshift+redshift_connector",
        }
@resource(
    config_schema=generic_config_schema,
    description="A resource for connecting to a Redshift instance",
)
def redshift_resource(context):
    """Dagster resource that yields a RedshiftConnection built from config."""
    config = context.resource_config
    return RedshiftConnection(config)
# Snowflake adds account/warehouse/role plus key-pair auth settings on top of
# the generic connection fields.
snowflake_config_schema = generic_config_schema | {
    "account": Field(StringSource, is_required=True),
    "warehouse": Field(StringSource, is_required=False),
    "role": Field(StringSource, is_required=False),
    # Key-pair auth: either a path to a PEM key file, or a base64-encoded DER
    # key inline. FIX: "private_key_filename" is read by
    # SnowflakeConnection._build_connect_args but was missing from this
    # schema, so Dagster rejected configs that used it.
    "private_key_filename": Field(StringSource, is_required=False),
    "private_key": Field(StringSource, is_required=False),
    "private_key_password": Field(StringSource, is_required=False),
}
class SnowflakeConnection(SqlAlchemyConnection):
    """SqlAlchemyConnection for Snowflake with optional key-pair authentication.

    A private key may be supplied either as a PEM file on disk
    (``private_key_filename``) or as a base64-encoded DER blob in config
    (``private_key``), optionally protected by ``private_key_password``.
    The file takes precedence when both are configured.
    """

    # Snowflake URLs accept more components than the generic base class.
    # (Tuple for consistency with SqlAlchemyConnection.valid_url_args.)
    valid_url_args = (
        "account",
        "username",
        "password",
        "database",
        "schema",
        "warehouse",
        "role",
    )

    def _build_url(self, url_args):
        """Build a snowflake.sqlalchemy URL; 'account' is mandatory."""
        if not url_args.get("account"):
            # FIX: was a bare `assert`, which is stripped under `python -O`;
            # raise explicitly so the validation always runs.
            raise ValueError("Missing required 'account' Snowflake URL parameter")
        # Snowflake's URL helper names the username parameter "user".
        modified_url_args = dict(url_args)
        if "username" in modified_url_args:
            modified_url_args["user"] = modified_url_args.pop("username")
        return SnowflakeURL(**modified_url_args)

    def _key_password(self, config):
        """Return the key passphrase as bytes, or None when unset."""
        password = config.get("private_key_password", None)
        return password.encode() if password is not None else None

    def _load_key_from_file(self, config):
        """Load a PEM-encoded private key from ``private_key_filename``."""
        with open(config.get("private_key_filename"), "rb") as f:
            return serialization.load_pem_private_key(
                f.read(),
                password=self._key_password(config),
                backend=default_backend(),
            )

    def _load_key_from_config(self, config):
        """Load a base64-encoded DER private key from ``private_key``."""
        return serialization.load_der_private_key(
            b64decode(config.get("private_key")),
            password=self._key_password(config),
            backend=default_backend(),
        )

    def _build_connect_args(self, config):
        """Return snowflake-connector connect args (key-pair auth, if any)."""
        connect_args = {}
        if "private_key_filename" in config or "private_key" in config:
            # Prefer the key file when both sources are available.
            if "private_key_filename" in config:
                key = self._load_key_from_file(config)
            else:
                key = self._load_key_from_config(config)
            # The connector expects the key re-serialized as unencrypted
            # PKCS#8 DER bytes.
            connect_args["private_key"] = key.private_bytes(
                encoding=serialization.Encoding.DER,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=serialization.NoEncryption(),
            )
        return connect_args
@resource(
    config_schema=snowflake_config_schema,
    description="A resource for connecting to a Snowflake instance",
)
def snowflake_resource(context):
    """Dagster resource that yields a SnowflakeConnection built from config."""
    config = context.resource_config
    return SnowflakeConnection(config)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from abc import abstractmethod | |
from dagster import Field, IOManager, MetadataValue, StringSource | |
from resources.db import SnowflakeConnection, snowflake_config_schema | |
class SnowflakeIOManagerBase(IOManager):
    """Base IO manager that lands op/asset outputs in Snowflake via a stage.

    ``handle_output`` serializes the object to a local file (subclass hook),
    PUTs it into a per-asset internal stage, loads it into a RAW/VARIANT
    table (subclass hook), and removes the staged file. ``load_input`` is
    left to subclasses.
    """

    # File open mode used by subclasses when writing the serialized object.
    write_mode = "w"
    # Naming conventions: stage and table names derive from the asset name.
    stage_prefix = "stage_raw_"
    table_prefix = "raw_"
    # Joins asset-key path elements into a single flat object name.
    key_delimiter = "__"
    # Set by subclasses, e.g. ".json" and "type = JSON compression = AUTO".
    file_extension = None
    file_format = None

    def __init__(self, config, base_dir=None):
        # `config` must satisfy snowflake_config_schema; `base_dir` is the
        # local directory where files are serialized before upload.
        self.db = SnowflakeConnection(config)
        self.base_dir = base_dir

    def get_path(self, context):
        """Return the local file path for this output's serialized form."""
        if context.has_asset_key:
            path = context.get_asset_identifier()
        else:
            path = context.get_identifier()
        full_path = os.path.join(self.base_dir, *path)
        if self.file_extension:
            full_path = full_path + self.file_extension
        return full_path

    def get_asset_name(self, context):
        """Return a flat name for the output, joining asset-key path parts."""
        if context.has_asset_key:
            # Always drop path elements that refer to our architectural schemas
            path = [
                p
                for p in context.asset_key.path
                if p.lower() not in ["bronze", "silver", "gold"]
            ]
            return self.key_delimiter.join(path)
        else:
            # Non-asset outputs fall back to the plain output name.
            return context.name

    @abstractmethod
    def store_serialized_object(self, path, obj):
        """Method to store the object to disk prior to staging.

        Implementations must return the written file's size in bytes
        (consumed by handle_output's metadata).
        """

    def get_stage_name(self, context):
        """Return the Snowflake stage name for this output."""
        return self.stage_prefix + self.get_asset_name(context)

    def create_stage(self, stage_name, connection):
        """Create the named stage if absent; requires ``file_format`` set."""
        assert (
            self.file_format
        ), "A Snowflake file format expression must be set on SnowflakeIOManager.file_format"
        # NOTE(review): stage/table names are interpolated directly into SQL;
        # they derive from asset keys, which are assumed trusted here.
        query = f"create stage if not exists {stage_name} file_format = ({self.file_format})"
        connection.execute(query)

    def stage_serialized_object(self, stage_name, serialized_file_path, connection):
        """Method to stage the serialized object to Snowflake"""
        file_abspath = os.path.abspath(serialized_file_path)
        query = f"put file://{file_abspath} @{stage_name}"
        connection.execute(query)

    def get_table_name(self, context):
        """Return the Snowflake target table name for this output."""
        return self.table_prefix + self.get_asset_name(context)

    def create_table(self, table_name, connection):
        """Create the raw landing table (VARIANT + load timestamp) if absent."""
        query = f"""
            create table if not exists {table_name} (
                RAW VARIANT,
                LOADED_TS TIMESTAMP_NTZ(9)
            );
        """
        connection.execute(query)
        return table_name

    @abstractmethod
    def load_staged_object(self, table_name, stage_name, file_name, connection):
        """Method to load the staged object into a table on Snowflake"""

    def clean_up(self, stage_name, file_name, connection):
        """Clean up after loading the staged file by removing the file on Snowflake"""
        query = f"""remove @{stage_name}/{file_name};"""
        connection.execute(query)

    def handle_output(self, context, obj):
        """Serialize, stage, and load ``obj``; attach Snowflake metadata."""
        serialized_file_path = self.get_path(context)
        file_name = os.path.basename(serialized_file_path)
        # Subclass hook returns the on-disk size in bytes.
        file_size_bytes = self.store_serialized_object(serialized_file_path, obj)
        # Order matters: stage must exist before PUT; table before load;
        # the staged file is removed only after a successful load.
        with self.db.get_connection() as c:
            stage_name = self.get_stage_name(context)
            self.create_stage(stage_name, c)
            self.stage_serialized_object(stage_name, serialized_file_path, c)
            table_name = self.get_table_name(context)
            self.create_table(table_name, c)
            self.load_staged_object(table_name, stage_name, file_name, c)
            self.clean_up(stage_name, file_name, c)
        # NOTE(review): assumes account is exactly "<locator>.<region>" — a
        # value with zero or 2+ dots raises ValueError here. TODO confirm
        # against org-name account identifiers.
        account, region = context.resource_config["account"].split(".")
        context.add_output_metadata(
            {
                "snowflake data page": MetadataValue.url(
                    f"https://app.snowflake.com/{region}/{account}/#/data/databases/"
                    f"{context.resource_config['database'].upper()}/schemas/"
                    f"{context.resource_config['schema'].upper()}/table/{table_name.upper()}"
                ),
                "file size (bytes)": file_size_bytes,
                "local file path": serialized_file_path,
                "stage name": stage_name,
                "table name": table_name,
            }
        )
# IO-manager config = Snowflake connection config plus an optional local
# directory used to serialize outputs before upload.
snowflake_io_manager_config_schema = {
    **snowflake_config_schema,
    "base_dir": Field(StringSource, is_required=False),
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import os | |
from dagster import io_manager | |
from dagster._utils import mkdir_p | |
from .snowflake_io_manager_base import ( | |
SnowflakeIOManagerBase, | |
snowflake_io_manager_config_schema, | |
) | |
from .utils import JSONDateTimeEncoder | |
class SnowflakeMergeJsonIOManager(SnowflakeIOManagerBase):
    """IO manager that stores outputs as JSON and MERGEs them into Snowflake.

    Rows are matched on the raw JSON value, so re-running an op inserts only
    values not already present in the target table.
    """

    file_extension = ".json"
    file_format = "type = JSON compression = AUTO"

    def store_serialized_object(self, path, obj):
        """Write ``obj`` as JSON to ``path``; return the file size in bytes."""
        parent = os.path.dirname(path)
        if parent:
            # FIX: replaces dagster._utils.mkdir_p — a private API — with the
            # stdlib equivalent; exist_ok makes it a no-op when the directory
            # already exists. The guard skips an empty dirname (cwd-relative
            # path), which os.makedirs would reject.
            os.makedirs(parent, exist_ok=True)
        with open(path, self.write_mode) as f:
            json.dump(obj, f, cls=JSONDateTimeEncoder)
        return os.path.getsize(path)

    def load_staged_object(self, table_name, stage_name, file_name, connection):
        """MERGE the staged (gzipped) JSON file into ``table_name``.

        PUT compresses uploads by default, hence the ``.gz`` suffix on the
        stage path. Distinct flattened JSON values are inserted with the
        current timestamp; existing values are left untouched.
        """
        query = f"""
            merge into {table_name} as target
            using (
                select
                    parsed.value as RAW,
                    sysdate() as LOADED_TS
                from @{stage_name}/{file_name}.gz as f,
                lateral flatten(input => parse_json(f.$1)) as parsed
                group by parsed.value
            ) as source
            on target.RAW = source.RAW
            when not matched then
                insert (target.RAW, target.LOADED_TS) values (source.RAW, source.LOADED_TS);
        """
        connection.execute(query)

    def clean_up(self, stage_name, file_name, connection):
        """Clean up after loading the staged file by removing the file on Snowflake"""
        # Overrides the base to account for the ``.gz`` suffix PUT adds.
        query = f"""remove @{stage_name}/{file_name}.gz;"""
        connection.execute(query)

    def load_input(self, context):
        """Reading inputs back out of Snowflake is not supported."""
        raise NotImplementedError
@io_manager(config_schema=snowflake_io_manager_config_schema)
def snowflake_merge_json_io_manager(context):
    """Build a SnowflakeMergeJsonIOManager from resource config.

    Falls back to the Dagster instance's storage directory when no
    ``base_dir`` is configured.
    """
    cfg = context.resource_config
    base_dir = cfg.get("base_dir", context.instance.storage_directory())
    return SnowflakeMergeJsonIOManager(config=cfg, base_dir=base_dir)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment