Skip to content

Instantly share code, notes, and snippets.

Created August 13, 2018 22:23
Show Gist options
  • Save rvaidya/f02c0e72e296b2906c76f0a94399d01e to your computer and use it in GitHub Desktop.
Save rvaidya/f02c0e72e296b2906c76f0a94399d01e to your computer and use it in GitHub Desktop.
Dask DataFrame `read_sql_table` using SQLAlchemy reflection to detect column types
from dask.dataframe import read_sql_table
import pandas as pd
import numpy as np
from sqlalchemy import create_engine, schema
from config import database_config
# Adapted from pandas' SQLTable._get_dtype, modified to return concrete
# numpy/pandas dtypes so the result can build an empty Dask ``meta`` frame.
def _get_dtype(column, sqltype):
    """Map a SQLAlchemy column type to the pandas dtype used for Dask meta.

    Parameters
    ----------
    column : sqlalchemy.schema.Column
        The reflected column. Only ``column.nullable`` is consulted, to
        decide whether an integer column must widen to float.
    sqltype : sqlalchemy.types.TypeEngine
        The column's SQL type (normally ``column.type``).

    Returns
    -------
    A value accepted by ``pd.Series(dtype=...)``: a Python type, a
    ``numpy.dtype`` or a pandas extension-dtype instance. Types not
    recognised below map to ``object``.
    """
    from sqlalchemy.types import (TIMESTAMP, Boolean, Date, DateTime, Float,
                                  Integer)

    if isinstance(sqltype, Float):
        return float
    if isinstance(sqltype, Integer):
        # NULLs cannot be stored in a numpy int64 column; pandas represents
        # them as NaN, which forces float64.
        if column.nullable:
            return float
        # TODO: Refine integer size.
        return np.dtype('int64')
    # TIMESTAMP subclasses DateTime, so it must be tested first.
    if isinstance(sqltype, TIMESTAMP):
        if not sqltype.timezone:
            return np.dtype('datetime64[ns]')
        # pandas converts TIMESTAMP WITH TIME ZONE values to UTC when reading.
        # The bare DatetimeTZDtype class (a sentinel in pandas) is not a
        # usable dtype on its own because it needs a concrete tz.
        return pd.DatetimeTZDtype(tz='UTC')
    if isinstance(sqltype, DateTime):
        return np.dtype('datetime64[ns]')
    if isinstance(sqltype, Date):
        # pandas' read_sql_table parses DATE columns into datetime64[ns].
        # NOTE(review): a raw query path may yield datetime.date objects
        # instead -- confirm against what Dask actually returns.
        return np.dtype('datetime64[ns]')
    if isinstance(sqltype, Boolean):
        return bool
    return object
def database_table_request(db_type: str, db_server: str, database: str,
                           table: str, index_col: str = None,
                           npartitions: int = 1):
    """Load a database table into a Dask DataFrame using reflected dtypes.

    The table schema is reflected with SQLAlchemy and turned into an empty
    pandas frame with matching column dtypes. That frame is handed to Dask as
    ``meta``, so Dask does not have to infer dtypes from sampled rows.

    Parameters
    ----------
    db_type : str
        Key passed to ``database_config`` to look up the SQLAlchemy dialect
        name, the username and the password.
    db_server : str
        Database server host, inserted verbatim after the ``@`` in the URL.
    database : str
        Database name.
    table : str
        Table name, matched case-insensitively against the reflected tables.
    index_col : str, optional
        Column to use as the Dask index. Defaults to the first primary-key
        column, in table column order.
    npartitions : int, default 1
        Number of partitions passed to Dask.

    Returns
    -------
    dask.dataframe.DataFrame
        The lazily-loaded table.

    Raises
    ------
    KeyError
        If no table named ``table`` (case-insensitively) was reflected.
    ValueError
        If ``index_col`` is not given and the table has no primary key.
    """
    from urllib.parse import quote

    dialect = database_config.engine(db_type)
    # Percent-encode the credentials. Characters such as '@', ':' or '/' in a
    # password would otherwise corrupt the URL.
    username = quote(database_config.username(db_type), safe='')
    password = quote(database_config.password(db_type), safe='')
    db_uri = f'{dialect}://{username}:{password}@{db_server}/{database}'

    # Reflect the schema. MetaData(bind=..., reflect=True) is deprecated and
    # removed in newer SQLAlchemy, while MetaData().reflect(bind=...) works
    # across versions. The engine is disposed afterwards so its connection
    # pool is not left open; Dask builds its own engine from ``db_uri``.
    engine = create_engine(db_uri)
    try:
        db_metadata = schema.MetaData()
        db_metadata.reflect(bind=engine)
    finally:
        engine.dispose()

    db_tables = {name.lower(): tbl for name, tbl in db_metadata.tables.items()}
    try:
        db_table = db_tables[table.lower()]
    except KeyError:
        raise KeyError(
            f'Table {table!r} not found in database {database!r}') from None

    # Default the index to the first primary-key column.
    if index_col is None:
        index_col = next(
            (column.name for column in db_table.columns if column.primary_key),
            None)
        if index_col is None:
            raise ValueError(
                f'Table {db_table.name!r} has no primary key; '
                'pass index_col explicitly')

    # Build an empty frame describing the output for Dask's ``meta``. The
    # index column becomes the frame index, so it is not a data column.
    meta = pd.DataFrame({
        column.name: pd.Series(dtype=_get_dtype(column, column.type))
        for column in db_table.columns
        if column.name != index_col
    })
    index_column = db_table.columns.get(index_col)
    if index_column is not None:
        # Match the named, typed index that Dask builds from ``index_col``.
        meta.index = pd.Index(
            [], dtype=_get_dtype(index_column, index_column.type),
            name=index_col)

    # Pass the reflected (correctly-cased) table name, since ``table`` was
    # matched case-insensitively.
    return read_sql_table(db_table.name, db_uri, index_col,
                          meta=meta, npartitions=npartitions)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment