Skip to content

Instantly share code, notes, and snippets.

@jayhale

jayhale/db.py Secret

Created October 28, 2022 18:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save jayhale/c5f08dcd1656db1b82e3177425911091 to your computer and use it in GitHub Desktop.
Example SQLAlchemy resource and IO manager for Dagster
from base64 import b64decode
from contextlib import contextmanager
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from dagster import Field, StringSource, resource
from snowflake.sqlalchemy import URL as SnowflakeURL
from sqlalchemy import create_engine
from sqlalchemy.engine import URL
# Configuration options shared by every SQLAlchemy-backed resource.  All
# fields are optional so a caller may supply either a complete URL or the
# individual components it is built from.
generic_config_schema = {
    option: Field(StringSource, is_required=False)
    for option in (
        "username",
        "password",
        "host",
        "port",
        "database",
        "url",
        "schema",
    )
}
class SqlAlchemyConnection:
    """Thin wrapper around a SQLAlchemy engine built from Dagster config.

    The connection URL is either taken verbatim from ``config["url"]`` or
    assembled from the individual components listed in ``valid_url_args``.
    """

    # Config keys that may be forwarded to ``URL.create``.
    valid_url_args = ("username", "password", "host", "port", "database")

    def __init__(self, config):
        """Build the engine from a Dagster resource config mapping."""
        # Prefer an explicit URL; otherwise assemble one from components.
        if "url" in config:
            self.url = config.get("url")
        else:
            self.url = self._build_url(self._get_url_args(config))
        # Optional default schema applied to every connection.
        self.schema = config.get("schema", None)
        # Driver-specific connect arguments (subclass hook).
        self.connect_args = self._build_connect_args(config)
        self.engine = create_engine(self.url, connect_args=self.connect_args)

    def _get_url_args(self, config):
        """Return the subset of ``config`` usable as URL components."""
        return {k: config.get(k) for k in self.valid_url_args if k in config}

    def _build_url(self, url_args):
        """Build a SQLAlchemy URL object from component kwargs."""
        return URL.create(**url_args)

    def _build_connect_args(self, config):
        """Hook for subclasses to supply driver-specific connect_args."""
        return {}

    @contextmanager
    def get_connection(self):
        """Yield a connection, closing it even when the caller's body raises.

        BUG FIX: the connection previously leaked if an exception was raised
        inside the ``with`` block, because ``close()`` only ran after a
        successful ``yield``.
        """
        connection = self.engine.connect()
        try:
            if self.schema:
                # Route unqualified (schema-less) table names to the
                # configured default schema.
                connection.execution_options(
                    schema_translate_map={None: self.schema}
                )
            yield connection
        finally:
            connection.close()

    def has_table(self, table_name):
        """Return True if ``table_name`` exists in the target database.

        NOTE(review): ``Engine.has_table`` is deprecated in SQLAlchemy 1.4+;
        migrate to ``sqlalchemy.inspect(engine).has_table`` when upgrading.
        """
        return self.engine.has_table(table_name)
class RedshiftConnection(SqlAlchemyConnection):
    """SQLAlchemy connection preconfigured for the Redshift driver."""

    def _get_url_args(self, config):
        """Add the Redshift driver name to the generic URL components."""
        return {
            **super()._get_url_args(config),
            "drivername": "redshift+redshift_connector",
        }
@resource(
    config_schema=generic_config_schema,
    description="A resource for connecting to a Redshift instance",
)
def redshift_resource(context):
    """Dagster resource factory that yields a ``RedshiftConnection``."""
    connection = RedshiftConnection(context.resource_config)
    return connection
# Snowflake adds account/warehouse/role plus key-pair-auth options on top of
# the generic SQLAlchemy config.
snowflake_config_schema = generic_config_schema | {
    "account": Field(StringSource, is_required=True),
    "warehouse": Field(StringSource, is_required=False),
    "role": Field(StringSource, is_required=False),
    # Key-pair authentication: either a path to a PEM key file or a
    # base64-encoded DER key.  CONSISTENCY FIX: ``private_key_filename`` is
    # read by ``SnowflakeConnection._build_connect_args`` but was missing
    # from the schema, so Dagster would reject configs that set it.
    "private_key_filename": Field(StringSource, is_required=False),
    "private_key": Field(StringSource, is_required=False),
    "private_key_password": Field(StringSource, is_required=False),
}
class SnowflakeConnection(SqlAlchemyConnection):
    """SQLAlchemy connection for Snowflake with optional key-pair auth."""

    # Snowflake URLs accept more components than the generic base class.
    valid_url_args = [
        "account",
        "username",
        "password",
        "database",
        "schema",
        "warehouse",
        "role",
    ]

    def _build_url(self, url_args):
        """Build a Snowflake connection URL from component kwargs.

        Raises:
            ValueError: if the required ``account`` parameter is missing.
        """
        # BUG FIX: this check was previously an ``assert``, which is
        # stripped when Python runs with -O; validate explicitly instead.
        if url_args.get("account", "") == "":
            raise ValueError("Missing required 'account' Snowflake URL parameter")
        # Snowflake renames the username parameter to user
        modified_url_args = dict(url_args)
        if "username" in modified_url_args:
            modified_url_args["user"] = modified_url_args.pop("username")
        return SnowflakeURL(**modified_url_args)

    @staticmethod
    def _key_password(config):
        """Return the private-key passphrase as bytes, or None if unset."""
        pk_pass = config.get("private_key_password", None)
        return pk_pass.encode() if pk_pass is not None else None

    def _build_connect_args(self, config):
        """Build connect_args, loading a private key when configured.

        The key may come from a PEM file (``private_key_filename``) or a
        base64-encoded DER blob (``private_key``); the file takes precedence
        when both are present.  The loaded key is re-serialized to the
        unencrypted PKCS#8 DER form the Snowflake driver expects.
        """
        connect_args = {}
        if "private_key_filename" in config or "private_key" in config:
            pk_pass = self._key_password(config)
            # Prefer private key files if both are available
            if "private_key_filename" in config:
                with open(config.get("private_key_filename"), "rb") as f:
                    key = serialization.load_pem_private_key(
                        f.read(), password=pk_pass, backend=default_backend()
                    )
            else:
                pk_bytes = b64decode(config.get("private_key"))
                key = serialization.load_der_private_key(
                    pk_bytes, password=pk_pass, backend=default_backend()
                )
            connect_args["private_key"] = key.private_bytes(
                encoding=serialization.Encoding.DER,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=serialization.NoEncryption(),
            )
        return connect_args
@resource(
    config_schema=snowflake_config_schema,
    description="A resource for connecting to a Snowflake instance",
)
def snowflake_resource(context):
    """Dagster resource factory that yields a ``SnowflakeConnection``."""
    connection = SnowflakeConnection(context.resource_config)
    return connection
import os
from abc import abstractmethod
from dagster import Field, IOManager, MetadataValue, StringSource
from resources.db import SnowflakeConnection, snowflake_config_schema
class SnowflakeIOManagerBase(IOManager):
    """Base IO manager that serializes outputs locally, stages them to a
    Snowflake stage, and loads them into a raw landing table.

    Subclasses must implement ``store_serialized_object`` and
    ``load_staged_object`` and set ``file_extension``/``file_format``.
    """

    write_mode = "w"
    stage_prefix = "stage_raw_"
    table_prefix = "raw_"
    key_delimiter = "__"
    file_extension = None  # e.g. ".json"; set by subclasses
    file_format = None  # Snowflake FILE_FORMAT expression; set by subclasses

    def __init__(self, config, base_dir=None):
        """Create the Snowflake connection and remember the local base dir."""
        self.db = SnowflakeConnection(config)
        self.base_dir = base_dir

    def get_path(self, context):
        """Return the local file path used to serialize the output."""
        if context.has_asset_key:
            path = context.get_asset_identifier()
        else:
            path = context.get_identifier()
        full_path = os.path.join(self.base_dir, *path)
        if self.file_extension:
            full_path = full_path + self.file_extension
        return full_path

    def get_asset_name(self, context):
        """Return a flat name for the asset/output, joined by ``key_delimiter``."""
        if context.has_asset_key:
            # Always drop path elements that refer to our architectural schemas
            path = [
                p
                for p in context.asset_key.path
                if p.lower() not in ["bronze", "silver", "gold"]
            ]
            return self.key_delimiter.join(path)
        else:
            return context.name

    @abstractmethod
    def store_serialized_object(self, path, obj):
        """Method to store the object to disk prior to staging"""

    def get_stage_name(self, context):
        """Return the Snowflake stage name for this asset."""
        return self.stage_prefix + self.get_asset_name(context)

    def create_stage(self, stage_name, connection):
        """Create the Snowflake stage if it does not already exist.

        Raises:
            ValueError: if ``file_format`` was not set by the subclass.
        """
        # BUG FIX: was an ``assert``, which is stripped under -O.
        if not self.file_format:
            raise ValueError(
                "A Snowflake file format expression must be set on SnowflakeIOManager.file_format"
            )
        # NOTE(review): identifiers are interpolated into SQL; they derive
        # from asset keys defined in code, not untrusted input.
        query = f"create stage if not exists {stage_name} file_format = ({self.file_format})"
        connection.execute(query)

    def stage_serialized_object(self, stage_name, serialized_file_path, connection):
        """Method to stage the serialized object to Snowflake"""
        file_abspath = os.path.abspath(serialized_file_path)
        query = f"put file://{file_abspath} @{stage_name}"
        connection.execute(query)

    def get_table_name(self, context):
        """Return the Snowflake table name for this asset."""
        return self.table_prefix + self.get_asset_name(context)

    def create_table(self, table_name, connection):
        """Create the raw landing table (VARIANT + load timestamp) if absent."""
        query = f"""
        create table if not exists {table_name} (
            RAW VARIANT,
            LOADED_TS TIMESTAMP_NTZ(9)
        );
        """
        connection.execute(query)
        return table_name

    @abstractmethod
    def load_staged_object(self, table_name, stage_name, file_name, connection):
        """Method to load the staged object into a table on Snowflake"""

    def clean_up(self, stage_name, file_name, connection):
        """Clean up after loading the staged file by removing the file on Snowflake"""
        query = f"""remove @{stage_name}/{file_name};"""
        connection.execute(query)

    def handle_output(self, context, obj):
        """Serialize ``obj`` locally, stage and load it into Snowflake, and
        attach Snowsight/link metadata to the Dagster output."""
        serialized_file_path = self.get_path(context)
        file_name = os.path.basename(serialized_file_path)
        file_size_bytes = self.store_serialized_object(serialized_file_path, obj)
        with self.db.get_connection() as c:
            stage_name = self.get_stage_name(context)
            self.create_stage(stage_name, c)
            self.stage_serialized_object(stage_name, serialized_file_path, c)
            table_name = self.get_table_name(context)
            self.create_table(table_name, c)
            self.load_staged_object(table_name, stage_name, file_name, c)
            self.clean_up(stage_name, file_name, c)
        # BUG FIX: account locators may contain more than one dot (e.g.
        # "xy12345.us-east-1.aws"); the previous bare two-target unpack of
        # split(".") raised ValueError for those.  Split once so everything
        # after the account id is treated as the region segment, which is
        # how Snowsight URLs are structured.
        account, region = context.resource_config["account"].split(".", 1)
        context.add_output_metadata(
            {
                "snowflake data page": MetadataValue.url(
                    f"https://app.snowflake.com/{region}/{account}/#/data/databases/"
                    f"{context.resource_config['database'].upper()}/schemas/"
                    f"{context.resource_config['schema'].upper()}/table/{table_name.upper()}"
                ),
                "file size (bytes)": file_size_bytes,
                "local file path": serialized_file_path,
                "stage name": stage_name,
                "table name": table_name,
            }
        )
# IO-manager config = the Snowflake connection config plus an optional local
# directory for serialized output files.
snowflake_io_manager_config_schema = {
    **snowflake_config_schema,
    "base_dir": Field(StringSource, is_required=False),
}
import json
import os
from dagster import io_manager
from dagster._utils import mkdir_p
from .snowflake_io_manager_base import (
SnowflakeIOManagerBase,
snowflake_io_manager_config_schema,
)
from .utils import JSONDateTimeEncoder
class SnowflakeMergeJsonIOManager(SnowflakeIOManagerBase):
    """IO manager that writes outputs as JSON and merges them into Snowflake.

    Rows are matched on the raw VARIANT value, so re-running a load only
    inserts records that are not already present in the landing table.
    """

    file_extension = ".json"
    file_format = "type = JSON compression = AUTO"

    def store_serialized_object(self, path, obj):
        """Serialize ``obj`` as JSON at ``path`` and return its size in bytes."""
        mkdir_p(os.path.dirname(path))
        with open(path, self.write_mode) as fh:
            json.dump(obj, fh, cls=JSONDateTimeEncoder)
        return os.path.getsize(path)

    def load_staged_object(self, table_name, stage_name, file_name, connection):
        """Merge the staged (gzip-compressed) JSON file into ``table_name``."""
        merge_sql = f"""
        merge into {table_name} as target
        using (
            select
                parsed.value as RAW,
                sysdate() as LOADED_TS
            from @{stage_name}/{file_name}.gz as f,
            lateral flatten(input => parse_json(f.$1)) as parsed
            group by parsed.value
        ) as source
        on target.RAW = source.RAW
        when not matched then
            insert (target.RAW, target.LOADED_TS) values (source.RAW, source.LOADED_TS);
        """
        connection.execute(merge_sql)

    def clean_up(self, stage_name, file_name, connection):
        """Clean up after loading the staged file by removing the file on Snowflake"""
        # PUT auto-compresses the staged file, hence the ".gz" suffix here.
        connection.execute(f"""remove @{stage_name}/{file_name}.gz;""")

    def load_input(self, context):
        """Reading inputs back from Snowflake is not supported."""
        raise NotImplementedError
@io_manager(config_schema=snowflake_io_manager_config_schema)
def snowflake_merge_json_io_manager(context):
    """Build a ``SnowflakeMergeJsonIOManager`` from the io-manager config.

    Falls back to the Dagster instance's storage directory when no
    ``base_dir`` is configured.
    """
    # The default is computed eagerly, matching the original call order.
    default_dir = context.instance.storage_directory()
    chosen_dir = context.resource_config.get("base_dir", default_dir)
    return SnowflakeMergeJsonIOManager(
        config=context.resource_config, base_dir=chosen_dir
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment