@jorisvandenbossche
Last active April 9, 2022 00:55
Patched version of pandas.io.sql to support PostgreSQL
"""
Patched version to support PostgreSQL
(original version: https://github.com/pydata/pandas/blob/v0.13.1/pandas/io/sql.py)
Changes relative to the original:
- added _write_postgresql
- updated table_exists
- updated get_sqltype
- updated get_schema
Collection of query wrappers / abstractions to both facilitate data
retrieval and to reduce dependency on DB-specific API.
"""
from __future__ import print_function
from datetime import datetime, date
from pandas.compat import range, lzip, map, zip
import pandas.compat as compat
import numpy as np
import traceback
from pandas.core.datetools import format as date_format
from pandas.core.api import DataFrame, isnull
#------------------------------------------------------------------------------
# Helper execution function
def execute(sql, con, retry=True, cur=None, params=None):
"""
Execute the given SQL query using the provided connection object.
Parameters
----------
sql: string
Query to be executed
con: database connection instance
Database connection. Must implement PEP249 (Database API v2.0).
retry: bool
Not currently implemented
cur: database cursor, optional
Must implement PEP249 (Database API v2.0). If cursor is not provided,
one will be obtained from the database connection.
params: list or tuple, optional
List of parameters to pass to execute method.
Returns
-------
Cursor object
"""
try:
if cur is None:
cur = con.cursor()
if params is None:
cur.execute(sql)
else:
cur.execute(sql, params)
return cur
except Exception:
try:
con.rollback()
except Exception: # pragma: no cover
pass
print('Error on sql %s' % sql)
raise
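# Illustrative usage of `execute` (not part of the original gist; the table
# name and connection are placeholders). Note that the parameter placeholder
# style depends on the driver: '?' for sqlite3, '%s' for psycopg2/MySQLdb.
def _example_execute(con):  # pragma: no cover - sketch only
    cur = execute("SELECT * FROM my_table WHERE id = ?", con, params=(1,))
    rows = cur.fetchall()
    cur.close()
    return rows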
def _safe_fetch(cur):
try:
result = cur.fetchall()
if not isinstance(result, list):
result = list(result)
return result
except Exception as e: # pragma: no cover
excName = e.__class__.__name__
if excName == 'OperationalError':
return []
def tquery(sql, con=None, cur=None, retry=True):
"""
Returns list of tuples corresponding to each row in given sql
query.
If only one column selected, then plain list is returned.
Parameters
----------
sql: string
SQL query to be executed
con: SQLConnection or DB API 2.0-compliant connection
cur: DB API 2.0 cursor
Provide a specific connection or a specific cursor if you are executing a
lot of sequential statements and want to commit outside.
"""
cur = execute(sql, con, cur=cur)
result = _safe_fetch(cur)
if con is not None:
try:
cur.close()
con.commit()
except Exception as e:
excName = e.__class__.__name__
if excName == 'OperationalError': # pragma: no cover
print('Failed to commit, may need to restart interpreter')
else:
raise
traceback.print_exc()
if retry:
return tquery(sql, con=con, retry=False)
if result and len(result[0]) == 1:
# python 3 compat
result = list(lzip(*result)[0])
elif result is None: # pragma: no cover
result = []
return result
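# Illustrative usage of `tquery` (not part of the original gist; table and
# column names are placeholders). A single selected column comes back as a
# plain list, multiple columns as a list of tuples.
def _example_tquery(con):  # pragma: no cover - sketch only
    names = tquery("SELECT name FROM my_table", con)       # e.g. ['a', 'b']
    pairs = tquery("SELECT id, name FROM my_table", con)   # e.g. [(1, 'a')]
    return names, pairs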
def uquery(sql, con=None, cur=None, retry=True, params=None):
"""
Does the same thing as tquery, but instead of returning results, it
returns the number of rows affected. Good for update queries.
"""
cur = execute(sql, con, cur=cur, retry=retry, params=params)
result = cur.rowcount
try:
con.commit()
except Exception as e:
excName = e.__class__.__name__
if excName != 'OperationalError':
raise
traceback.print_exc()
if retry:
print('Looks like your connection failed, reconnecting...')
return uquery(sql, con, retry=False)
return result
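# Illustrative usage of `uquery` (not part of the original gist; the UPDATE
# statement and its '%s' placeholders assume a psycopg2/MySQLdb connection).
def _example_uquery(con):  # pragma: no cover - sketch only
    n_updated = uquery("UPDATE my_table SET name = %s WHERE id = %s", con,
                       params=('new name', 1))
    return n_updated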
def read_frame(sql, con, index_col=None, coerce_float=True, params=None):
"""
Returns a DataFrame corresponding to the result set of the query
string.
Optionally provide an index_col parameter to use one of the
columns as the index. Otherwise will be 0 to len(results) - 1.
Parameters
----------
sql: string
SQL query to be executed
con: DB connection object, optional
index_col: string, optional
column name to use for the returned DataFrame object.
coerce_float : boolean, default True
Attempt to convert values to non-string, non-numeric objects (like
decimal.Decimal) to floating point, useful for SQL result sets
params: list or tuple, optional
List of parameters to pass to execute method.
"""
cur = execute(sql, con, params=params)
rows = _safe_fetch(cur)
columns = [col_desc[0] for col_desc in cur.description]
cur.close()
con.commit()
result = DataFrame.from_records(rows, columns=columns,
coerce_float=coerce_float)
if index_col is not None:
result = result.set_index(index_col)
return result
frame_query = read_frame
read_sql = read_frame
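# Illustrative usage of `read_frame` (not part of the original gist; the query
# and column names are placeholders). `index_col` promotes the chosen column
# to the DataFrame index instead of the default 0..len(results)-1 range.
def _example_read_frame(con):  # pragma: no cover - sketch only
    frame = read_frame("SELECT id, name, value FROM my_table", con,
                       index_col='id')
    return frame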
def write_frame(frame, name, con, flavor='sqlite', if_exists='fail', **kwargs):
"""
Write records stored in a DataFrame to a SQL database.
Parameters
----------
frame: DataFrame
name: name of SQL table
con: an open SQL database connection object
flavor: {'sqlite', 'mysql', 'postgresql'}, default 'sqlite'
if_exists: {'fail', 'replace', 'append'}, default 'fail'
fail: If table exists, raise a ValueError.
replace: If table exists, drop it, recreate it, and insert data.
append: If table exists, insert data. Create if does not exist.
"""
if 'append' in kwargs:
import warnings
warnings.warn("append is deprecated, use if_exists instead",
FutureWarning)
if kwargs['append']:
if_exists = 'append'
else:
if_exists = 'fail'
if if_exists not in ('fail', 'replace', 'append'):
raise ValueError("'%s' is not valid for if_exists" % if_exists)
exists = table_exists(name, con, flavor)
if if_exists == 'fail' and exists:
raise ValueError("Table '%s' already exists." % name)
# creation/replacement dependent on the table existing and if_exists criteria
create = None
if exists:
if if_exists == 'fail':
raise ValueError("Table '%s' already exists." % name)
elif if_exists == 'replace':
cur = con.cursor()
cur.execute("DROP TABLE %s;" % name)
cur.close()
create = get_schema(frame, name, flavor)
else:
create = get_schema(frame, name, flavor)
if create is not None:
cur = con.cursor()
cur.execute(create)
cur.close()
cur = con.cursor()
# Replace spaces in DataFrame column names with _.
safe_names = [s.replace(' ', '_').strip() for s in frame.columns]
flavor_picker = {'sqlite' : _write_sqlite,
'mysql' : _write_mysql,
'postgresql' : _write_postgresql}
func = flavor_picker.get(flavor, None)
if func is None:
raise NotImplementedError
func(frame, name, safe_names, cur)
cur.close()
con.commit()
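# Illustrative usage of `write_frame` with the added PostgreSQL flavor (not
# part of the original gist; the connection is assumed to be a psycopg2
# connection and the table name is made up).
def _example_write_frame(con):  # pragma: no cover - sketch only
    frame = DataFrame({'a': [1, 2, 3], 'b': [0.1, 0.2, 0.3]})
    write_frame(frame, 'patched_sql_demo', con,
                flavor='postgresql', if_exists='replace')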
def _write_sqlite(frame, table, names, cur):
bracketed_names = ['[' + column + ']' for column in names]
col_names = ','.join(bracketed_names)
wildcards = ','.join(['?'] * len(names))
insert_query = 'INSERT INTO %s (%s) VALUES (%s)' % (
table, col_names, wildcards)
# pandas types are badly handled if there is only 1 column ( Issue #3628 )
if not len(frame.columns) == 1:
data = [tuple(x) for x in frame.values]
else:
data = [tuple(x) for x in frame.values.tolist()]
cur.executemany(insert_query, data)
def _write_mysql(frame, table, names, cur):
bracketed_names = ['`' + column + '`' for column in names]
col_names = ','.join(bracketed_names)
wildcards = ','.join([r'%s'] * len(names))
insert_query = "INSERT INTO %s (%s) VALUES (%s)" % (
table, col_names, wildcards)
data = [tuple(x) for x in frame.values]
cur.executemany(insert_query, data)
def _write_postgresql(frame, table, names, cur):
bracketed_names = ['"' + column + '"' for column in names]
col_names = ','.join(bracketed_names)
wildcards = ','.join([r'%s'] * len(names))
insert_query = 'INSERT INTO public.%s (%s) VALUES (%s)' % (
table, col_names, wildcards)
data = [tuple(x) for x in frame.values]
print(insert_query)
print(data)
cur.executemany(insert_query, data)
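# For illustration (names made up): a frame with columns ['a', 'b'] written to
# table 'demo' generates
#   INSERT INTO public.demo ("a","b") VALUES (%s,%s)
# which is passed to cur.executemany together with the list of row tuples.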
def table_exists(name, con, flavor):
flavor_map = {
'sqlite': ("SELECT name FROM sqlite_master "
"WHERE type='table' AND name='%s';") % name,
'mysql' : "SHOW TABLES LIKE '%s'" % name,
'postgresql' : "SELECT * FROM pg_catalog.pg_tables where tablename = '%s'" % name}
query = flavor_map.get(flavor, None)
if query is None:
    raise NotImplementedError
return len(tquery(query, con)) > 0
def get_sqltype(pytype, flavor):
sqltype = {'mysql': 'VARCHAR (63)',
'sqlite': 'TEXT',
'postgresql': 'VARCHAR (63)'}
if issubclass(pytype, np.floating):
sqltype['mysql'] = 'FLOAT'
sqltype['sqlite'] = 'REAL'
sqltype['postgresql'] = 'double precision'
if issubclass(pytype, np.integer):
#TODO: Refine integer size.
sqltype['mysql'] = 'BIGINT'
sqltype['sqlite'] = 'INTEGER'
sqltype['postgresql'] = 'integer'
if issubclass(pytype, np.datetime64) or pytype is datetime:
# Caution: np.datetime64 is also a subclass of np.number.
sqltype['mysql'] = 'DATETIME'
sqltype['sqlite'] = 'TIMESTAMP'
sqltype['postgresql'] = 'timestamp'
# `date` here is the class imported from the datetime module above.
if pytype is date:
sqltype['mysql'] = 'DATE'
sqltype['sqlite'] = 'TIMESTAMP'
sqltype['postgresql'] = 'date'
if issubclass(pytype, np.bool_):
sqltype['sqlite'] = 'INTEGER'
sqltype['postgresql'] = 'boolean'
return sqltype[flavor]
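# For example (illustrative): get_sqltype(np.float64, 'postgresql') returns
# 'double precision', get_sqltype(np.int64, 'mysql') returns 'BIGINT', and any
# type matching none of the branches falls back to the VARCHAR/TEXT default.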
def get_schema(frame, name, flavor, keys=None):
"Return a CREATE TABLE statement to suit the contents of a DataFrame."
lookup_type = lambda dtype: get_sqltype(dtype.type, flavor)
# Replace spaces in DataFrame column names with _.
safe_columns = [s.replace(' ', '_').strip() for s in frame.dtypes.index]
column_types = lzip(safe_columns, map(lookup_type, frame.dtypes))
if flavor == 'sqlite':
columns = ',\n '.join('[%s] %s' % x for x in column_types)
elif flavor == 'postgresql':
columns = ',\n '.join('"%s" %s' % x for x in column_types)
else:
columns = ',\n '.join('`%s` %s' % x for x in column_types)
keystr = ''
if keys is not None:
if isinstance(keys, compat.string_types):
keys = (keys,)
keystr = ', PRIMARY KEY (%s)' % ','.join(keys)
template = """CREATE TABLE %(name)s (
%(columns)s
%(keystr)s
);"""
create_statement = template % {'name': name, 'columns': columns,
'keystr': keystr}
return create_statement
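# Illustrative output (column names made up): for a frame with an int64 column
# 'id' and a float64 column 'score',
#   get_schema(frame, 'demo', 'postgresql', keys='id')
# returns roughly
#   CREATE TABLE demo (
#     "id" integer,
#     "score" double precision
#     , PRIMARY KEY (id)
#   );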
def sequence2dict(seq):
"""Helper function for cx_Oracle.
For each element in the sequence, creates a dictionary item equal
to the element and keyed by the position of the item in the list.
>>> sequence2dict(("Matt", 1))
{'1': 'Matt', '2': 1}
Source:
http://www.gingerandjohn.com/archives/2004/02/26/cx_oracle-executemany-example/
"""
d = {}
for k, v in zip(range(1, 1 + len(seq)), seq):
d[str(k)] = v
return d
@BAM-BAM-BAM

This code won't work if you try writing to a schema other than "public" in PostgreSQL.

I think the best way to deal with that issue is to add a "schema" argument to read_frame() and write_frame(); a rough sketch is below.
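One possible shape for that, as an untested sketch (the `schema` parameter and its default are assumptions, not something the gist provides):

def _write_postgresql(frame, table, names, cur, schema='public'):
    # Quote the target as "<schema>"."<table>" instead of hard-coding public.
    bracketed_names = ['"' + column + '"' for column in names]
    col_names = ','.join(bracketed_names)
    wildcards = ','.join(['%s'] * len(names))
    insert_query = 'INSERT INTO "%s"."%s" (%s) VALUES (%s)' % (
        schema, table, col_names, wildcards)
    data = [tuple(x) for x in frame.values]
    cur.executemany(insert_query, data)

write_frame (and table_exists / get_schema) would then need a matching schema argument to forward, defaulting to 'public' so existing callers keep working.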
