Created
December 13, 2020 04:18
-
-
Save mengwangk/6f156794c8b60c3e129f393bb919b233 to your computer and use it in GitHub Desktop.
loader.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# AUTOGENERATED! DO NOT EDIT! File to edit: 01_data.loader.ipynb (unless otherwise specified).
# Public names exported by this module.
__all__ = [
    'auto_str', 'GetAttr', 'ObjectFactory', 'DbTargetProvider', 'FileSourceProvider', 'DatabaseTarget',
    'FileSource', 'PgSqlDbBuilder', 'PgSqlDb', 'MySqlDbBuilder', 'MySqlDb', 'create_excel_file_source',
    'create_csv_file_source', 'ExcelSource', 'CSVSource', 'db_targets', 'file_sources', 'ingest',
]
# Cell
import logging
import os
from enum import Enum, auto
from urllib.parse import quote

import pandas as pd
import pymysql
import pymysql.cursors
from sqlalchemy import create_engine, inspect
# Configure the root logger at import time: timestamp, level name, then the message, at DEBUG level.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s %(levelname)s(): %(message)s",
)
# Cell
def auto_str(cls):
    "Class decorator: install a `__str__` that renders `ClassName(attr=value, ...)` from `vars(self)`"
    def __str__(self):
        # One `name=value` fragment per instance attribute, in `vars()` order.
        fields = ", ".join(f"{name!s}={value!s}" for name, value in vars(self).items())
        return f"{type(self).__name__!s}({fields})"

    cls.__str__ = __str__
    return cls
# Cell
class GetAttr:
    """Inherit from this to have all attr accesses in `self._xtra` passed down to `self.default`

    The delegate object is looked up under the attribute named by `_default`
    (``'default'`` unless a subclass overrides it). If the instance defines
    `_xtra`, only the names it contains are delegated; otherwise every name
    that does not start with ``__`` is.
    """
    _default = 'default'

    def _component_attr_filter(self, k):
        "Return True when attribute name `k` may be forwarded to the delegate"
        # Never forward dunders, nor the bookkeeping names themselves.
        if k.startswith('__') or k in ('_xtra', self._default):
            return False
        xtra = getattr(self, '_xtra', None)
        return xtra is None or k in xtra

    def _dir(self):
        "List the delegate's attribute names that this object forwards"
        delegate = getattr(self, self._default, None)
        # No delegate assigned yet: nothing to forward (mirrors `__getattr__`).
        if delegate is None:
            return []
        return [k for k in dir(delegate) if self._component_attr_filter(k)]

    def __getattr__(self, k):
        "Forward a failed attribute lookup to the delegate, else raise `AttributeError`"
        if self._component_attr_filter(k):
            attr = getattr(self, self._default, None)
            if attr is not None:
                return getattr(attr, k)
        raise AttributeError(k)

    def __dir__(self):
        "Class attributes, instance attributes and forwarded delegate attributes"
        # Built inline: the `custom_dir` helper previously called here is not
        # defined or imported in this module, so `dir(obj)` raised NameError.
        return dir(type(self)) + list(self.__dict__.keys()) + self._dir()

    def __setstate__(self, data):
        "Restore pickled state by updating the instance `__dict__` directly"
        self.__dict__.update(data)
# Cell
class ObjectFactory:
    "Generic object factory: a registry of builder callables looked up by key"

    def __init__(self):
        # key -> callable that builds the object
        self._builders = {}

    def register_builder(self, key, builder):
        "Associate `builder` with `key`, replacing any previous registration"
        self._builders[key] = builder

    def create(self, key, **kwargs):
        "Call the builder registered under `key` with `kwargs`; raise `ValueError(key)` if there is none"
        builder = self._builders.get(key)
        if builder:
            return builder(**kwargs)
        raise ValueError(key)
# Cell
class DbTargetProvider(ObjectFactory):
    "Database provider: factory whose builders produce database targets"

    def get(self, id, **kwargs):
        """Return the database interface built by the builder registered under `id`."""
        target = self.create(id, **kwargs)
        return target
# Cell
class FileSourceProvider(ObjectFactory):
    "Supported file sources: factory whose builders produce file sources"

    def get(self, id, **kwargs):
        """Return the file interface built by the builder registered under `id`."""
        source = self.create(id, **kwargs)
        return source
# Cell
class DatabaseTarget(Enum):
    "Keys identifying the database kinds that can be registered as targets"
    # Values are assigned automatically in declaration order.
    PostgreSQL = auto()
    MySQL = auto()
# Cell
class FileSource(Enum):
    "Keys identifying the file formats that can be registered as sources"
    # Values are assigned automatically in declaration order.
    CSV = auto()
    Excel = auto()
# Cell
class PgSqlDbBuilder:
    """PostgreSQL database builder.

    Callable that creates a `PgSqlDb` on first use and hands back that same
    object on every later call; arguments of later calls are not applied.
    """

    def __init__(self):
        # Cached `PgSqlDb`, created lazily by the first call.
        self._instance = None

    def __call__(self, host, port, db, user, password, **_ignored):
        """Return the cached `PgSqlDb`, building it from the arguments if needed."""
        if self._instance:
            return self._instance
        self._instance = PgSqlDb(host, port, db, user, password)
        return self._instance
@auto_str
class PgSqlDb:
    """PostgreSQL database destination.

    Holds the connection settings and turns them into a SQLAlchemy URL / engine.
    `user` and `password` must be given as plain (not URL-encoded) text; they
    are percent-encoded when the connection string is built.

    NOTE: `auto_str` renders every instance attribute, so `str()` of this
    object includes the password — avoid logging it.
    """

    def __init__(self, host, port, db, user, password):
        self._host = host
        self._port = port
        self._db = db
        self._user = user
        self._password = password

    def get_engine(self):
        """Create and return sqlalchemy engine."""
        return create_engine(self.get_conn_str())

    def get_conn_str(self):
        """Return the connection string."""
        # Percent-encode the credentials: characters such as '@', ':' or '/'
        # would otherwise be parsed as URL delimiters. `safe=""` also encodes
        # '/', and spaces become %20 rather than '+'.
        user = quote(str(self._user), safe="")
        password = quote(str(self._password), safe="")
        return f"postgresql+psycopg2://{user}:{password}@{self._host}:{self._port}/{self._db}"
# Cell
class MySqlDbBuilder:
    """MySQL database builder.

    Callable that creates a `MySqlDb` on first use and hands back that same
    object on every later call; arguments of later calls are not applied.
    """

    def __init__(self):
        # Cached `MySqlDb`, created lazily by the first call.
        self._instance = None

    def __call__(self, host, port, db, user, password, **_ignored):
        """Return the cached `MySqlDb`, building it from the arguments if needed."""
        if self._instance:
            return self._instance
        self._instance = MySqlDb(host, port, db, user, password)
        return self._instance
@auto_str
class MySqlDb:
    """MySQL database destination.

    Holds the connection settings and turns them into a SQLAlchemy URL / engine.
    `user` and `password` must be given as plain (not URL-encoded) text; they
    are percent-encoded when the connection string is built.

    NOTE: `auto_str` renders every instance attribute, so `str()` of this
    object includes the password — avoid logging it.
    """

    def __init__(self, host, port, db, user, password):
        self._host = host
        self._port = port
        self._db = db
        self._user = user
        self._password = password

    def get_engine(self):
        """Create and return sqlalchemy engine."""
        return create_engine(self.get_conn_str())

    def get_conn_str(self):
        """Return the connection string."""
        # Percent-encode the credentials: characters such as '@', ':' or '/'
        # would otherwise be parsed as URL delimiters. `safe=""` also encodes
        # '/', and spaces become %20 rather than '+'.
        user = quote(str(self._user), safe="")
        password = quote(str(self._password), safe="")
        return f"mysql+pymysql://{user}:{password}@{self._host}:{self._port}/{self._db}?charset=utf8mb4"
# Cell
def create_excel_file_source(file_path, **options):
    """Build an `ExcelSource` for `file_path`, forwarding any extra keyword options."""
    source = ExcelSource(file_path, **options)
    return source
def create_csv_file_source(file_path, **options):
    """Build a `CSVSource` for `file_path`, forwarding any extra keyword options."""
    source = CSVSource(file_path, **options)
    return source
class ExcelSource:
    """Excel file source.

    Extra keyword arguments given at construction are forwarded to
    `pandas.read_excel`. The reader engine defaults to ``'openpyxl'`` and can
    be overridden by passing ``engine=...``.
    """

    def __init__(self, file_path, **args):
        self._file_path = file_path
        self._args = args

    def filepath(self):
        """Return the path this source reads from."""
        return self._file_path

    def _reader_options(self):
        """Return the keyword arguments for `pandas.read_excel`."""
        # Default first so a caller-supplied `engine` wins; passing the engine
        # as a separate keyword raised "multiple values for keyword argument".
        return {"engine": "openpyxl", **self._args}

    def get_data(self):
        """Read the file and return a `DataFrame`"""
        return pd.read_excel(self._file_path, **self._reader_options())
class CSVSource:
    """CSV file source.

    Extra keyword arguments given at construction are forwarded to
    `pandas.read_csv`. The parser engine defaults to ``None`` (pandas picks
    one) and can be overridden by passing ``engine=...``.
    """

    def __init__(self, file_path, **args):
        self._file_path = file_path
        self._args = args

    def filepath(self):
        """Return the path this source reads from."""
        return self._file_path

    def _reader_options(self):
        """Return the keyword arguments for `pandas.read_csv`."""
        # Default first so a caller-supplied `engine` wins; passing the engine
        # as a separate keyword raised "multiple values for keyword argument".
        return {"engine": None, **self._args}

    def get_data(self):
        """Read the file and return a `DataFrame`"""
        return pd.read_csv(self._file_path, **self._reader_options())
# Cell
# Register supported database providers: one builder per `DatabaseTarget` member.
db_targets = DbTargetProvider()
for _target, _builder in (
    (DatabaseTarget.PostgreSQL, PgSqlDbBuilder()),
    (DatabaseTarget.MySQL, MySqlDbBuilder()),
):
    db_targets.register_builder(_target, _builder)
del _target, _builder  # keep the loop variables out of the module namespace
# Cell
# Register supported file types: one factory function per `FileSource` member.
file_sources = FileSourceProvider()
for _kind, _factory in (
    (FileSource.Excel, create_excel_file_source),
    (FileSource.CSV, create_csv_file_source),
):
    file_sources.register_builder(_kind, _factory)
del _kind, _factory  # keep the loop variables out of the module namespace
# Cell
def ingest(file_source, target_db, table_name, *, if_exists='append', method='multi', schema=None):
    """Ingest the file into the database table.

    The column types of the existing target table are inspected and passed to
    `DataFrame.to_sql` as `dtype`; the data is written in chunks of 500 rows
    without the index. A per-column distinct-value summary is printed afterwards.

    Args:
        file_source: object exposing `get_data()` (returns a `DataFrame`) and `filepath()`.
        target_db: object exposing `get_engine()` (returns a SQLAlchemy engine).
        table_name: name of the target table.
        if_exists: forwarded to `DataFrame.to_sql`.
        method: forwarded to `DataFrame.to_sql`.
        schema: schema of the target table, used both for the inspection and
            for the write.
    """
    # Create db engine
    engine = target_db.get_engine()
    try:
        # Inspect the target table schema
        columns = inspect(engine).get_columns(table_name, schema=schema)
        dtypes = {column["name"]: column["type"] for column in columns}
        logging.info(dtypes)

        # Load the file into database
        df = file_source.get_data()
        df.to_sql(
            table_name,
            engine,
            # Write to the same schema that was inspected; previously `schema`
            # was dropped here and rows went to the connection's default schema.
            schema=schema,
            if_exists=if_exists,
            method=method,
            chunksize=500,
            index=False,
            dtype=dtypes,
        )
    finally:
        # Release the pooled connections held by the engine obtained above.
        engine.dispose()

    # TODO - Validation
    print(f"\nTotal records in {file_source.filepath()} - {len(df)}")
    for c in df.columns:
        print(f"{c} - {df[c].nunique()}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment