omega|ml snowflake datasets plugin
from omegaml.backends.basedata import BaseDataBackend
from base64 import b64encode, b64decode
from sqlalchemy import create_engine
import pandas as pd

# version of this plugin
version = '0.1.3'
class SnowflakeDataBackend(BaseDataBackend):
    """
    Snowflake plugin for omegaml

    Installation:
        copy/paste the above into a cell, execute, then run this to register.
        Alternatively install getgist

            !pip install getgist
            !getgist

    Pre-Requisites:
        make sure you have the following packages installed

            !pip install -U snowflake-sqlalchemy==1.2.1

    Usage:
        # define your snowflake connection
        snowflake_constr = f'snowflake://{user}:{password}@{account}/'

        # store in any of three ways
        # -- just the connection
        om.datasets.put(snowflake_constr, 'mysnowflake')
        om.datasets.get('mysnowflake')
        => the sql connection object

        # -- store the connection with a predefined sql
        om.datasets.put(snowflake_constr, 'mysnowflake', sql='select ....')
        om.datasets.get('mysnowflake')
        => returns a pandas DataFrame. specify chunksize to return an iterable
           of DataFrames

        # -- copy the result of the snowflake query to omegaml
        om.datasets.put(snowflake_constr, 'mysnowflake', sql='select ...', copy=True)
        om.datasets.get('mysnowflake')
        => returns a pandas DataFrame (without executing any additional queries)
        => can also use om.datasets.getl('mysnowflake') to return an MDataFrame

    Advanced:
        om.datasets.put() supports the following additional keyword arguments

        chunksize=int            number of rows to read from snowflake in one chunk,
                                 defaults to 10000
        parse_dates=['col', ...] list of column names to parse for date, time or
                                 datetime, see pd.read_sql for details
        transform=callable       a callable that is passed the DataFrame of each chunk
                                 before it is inserted into the database. use to provide
                                 custom transformations. only works with copy=True

        as well as any other kwargs supported by pd.read_sql
    """
    KIND = 'snowflake.conx'

    @classmethod
    def supports(cls, obj, name, *args, **kwargs):
        return isinstance(obj, str) and obj.startswith('snowflake')
    def get(self, name, sql=None, chunksize=None, *args, **kwargs):
        meta = self.data_store.metadata(name)
        connection_str = meta.kind_meta.get('snowflake_connection')
        sql = sql or meta.kind_meta.get('sql')
        chunksize = chunksize or meta.kind_meta.get('chunksize')
        if connection_str:
            connection = self.get_connection(connection_str)
        else:
            raise ValueError('no connection string')
        if sql:
            return pd.read_sql(sql, connection, chunksize=chunksize)
        return connection
    def put(self, obj, name, sql=None, copy=False, append=True, chunksize=None,
            transform=None, *args, **kwargs):
        attributes = kwargs.pop('attributes', None) or {}
        kind_meta = {
            'snowflake_connection': str(obj),
            'sql': sql,
            'chunksize': chunksize,
        }
        if copy:
            if not sql:
                raise ValueError('a valid SQL statement is required with copy=True')
            metadata = self.copy_from_sql(sql, obj, name, chunksize=chunksize,
                                          append=append, transform=transform, **kwargs)
        else:
            metadata = self.data_store.metadata(name)
            if metadata is not None:
                metadata.kind_meta.update(kind_meta)
            else:
                metadata = self.data_store.make_metadata(name, self.KIND,
                                                         kind_meta=kind_meta,
                                                         attributes=attributes)
        metadata.attributes.update(attributes)
        return metadata.save()
    def get_connection(self, connection_str):
        connection = None
        try:
            engine = create_engine(connection_str)
            connection = engine.connect()
            # sanity-check the connection before handing it out
            connection.execute('select current_version()').fetchone()
        except Exception:
            if connection is not None:
                connection.close()
            raise
        return connection
    def copy_from_sql(self, sql, connstr, name, chunksize=10000,
                      append=False, transform=None, **kwargs):
        connection = self.get_connection(connstr)
        chunksize = chunksize or 10000  # avoid None
        pditer = pd.read_sql(sql, connection, chunksize=chunksize, **kwargs)
        try:
            import tqdm
        except ImportError:
            meta = self._chunked_insert(pditer, name, append=append,
                                        transform=transform)
        else:
            with tqdm.tqdm(unit='rows') as pbar:
                meta = self._chunked_insert(pditer, name, append=append,
                                            transform=transform, pbar=pbar)
        return meta
    def _chunked_insert(self, pditer, name, append=True, transform=None, pbar=None):
        meta = None
        for i, df in enumerate(pditer):
            if pbar is not None:
                pbar.update(len(df))
            should_append = (i > 0) or append
            if transform:
                df = transform(df)
            try:
                meta = self.data_store.put(df, name, append=should_append)
            except Exception as e:
                rows = df.iloc[0:10].to_dict()
                raise ValueError("{e}: {rows}".format(**locals()))
        return meta
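
# Illustrative sketch, not part of the plugin: `lowercase_columns` is a
# hypothetical transform callable as described in the Advanced section of the
# docstring above; adapt the logic to your own data.
#
#   def lowercase_columns(df):
#       # normalize column names before each chunk is inserted
#       df.columns = [c.lower() for c in df.columns]
#       return df
#
#   om.datasets.put(snowflake_constr, 'mysnowflake', sql='select ...',
#                   copy=True, transform=lowercase_columns)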
""" | |
omegaml patch to fast_insert | |
version: 0.10.*, 0.11.* | |
""" | |
import os | |
import math | |
from multiprocessing import Pool | |
from omegaml.store.fastinsert import dfchunker, repeat, insert_chunk | |
# single instance multiprocessing pool | |
pool = None | |
def fast_insert(df, omstore, name, chunk_size=int(1e4)):
    """
    fast insert of dataframe to mongodb

    Depending on size use single-process or multiprocessing. Typically
    multiprocessing is faster on datasets with > 10'000 data elements
    (rows x columns). Note this may max out your CPU and may use
    processor count * chunksize of additional memory. The chunksize is
    set to 10'000. The processor count is the default used by multiprocessing,
    typically the number of CPUs reported by the operating system.

    :param df: dataframe
    :param omstore: the OmegaStore to use. will be used to get the mongo_url
    :param name: the dataset name in OmegaStore to use. will be used to get the
        collection name from the omstore
    """
    global pool
    if len(df) * len(df.columns) > chunk_size:
        mongo_url = omstore.mongo_url
        collection_name = omstore.collection(name).name
        # we crossed the upper limit of single threaded processing, use a Pool
        # use the cached pool
        cores = max(1, math.ceil(os.cpu_count() / 2))
        pool = pool or Pool(processes=cores)
        jobs = zip(dfchunker(df, size=chunk_size),
                   repeat(mongo_url), repeat(collection_name))
        pool.map(insert_chunk, (job for job in jobs))
    else:
        # still within bounds for single threaded inserts
        omstore.collection(name).insert_many(df.to_dict(orient='records'))
# apply the fix
from omegaml import version as omversion

if any(omversion.startswith(v) for v in ('0.10', '0.11')):
    print(f"*** applying fast_insert patch to omegaml-{omversion}")
    from omegaml.store import base
    base.fast_insert = fast_insert

# load the sqlalchemy snowflake dialect
# source: https://stackoverflow.com/a/60726909/890242
from sqlalchemy.dialects import registry
registry.register('snowflake', 'snowflake.sqlalchemy', 'dialect')

print(f"snowflake plugin {version}: to install execute the following line of code")
print("> om.datasets.register_backend(SnowflakeDataBackend.KIND, SnowflakeDataBackend)")