# -*- coding: utf-8 -*-
"""
LICENSE: BSD (same as pandas)
example use of pandas with oracle mysql postgresql sqlite
    - updated 9/18/2012 with better column name handling; couple of bug fixes.
    - used ~20 times for various ETL jobs. Mostly MySQL, but some Oracle.
to do:
    save/restore index (how to check table existence? just do select count(*)?),
    finish odbc,
    add booleans?,
    sql_server?
"""
import numpy as np
import cStringIO
import pandas.io.sql as psql
from pandas import isnull

dbtypes = {
    'mysql'     : {'DATE': 'DATE',      'DATETIME': 'DATETIME',  'INT': 'BIGINT', 'FLOAT': 'FLOAT',  'VARCHAR': 'VARCHAR'},
    'oracle'    : {'DATE': 'DATE',      'DATETIME': 'DATE',      'INT': 'NUMBER', 'FLOAT': 'NUMBER', 'VARCHAR': 'VARCHAR2'},
    'sqlite'    : {'DATE': 'TIMESTAMP', 'DATETIME': 'TIMESTAMP', 'INT': 'NUMBER', 'FLOAT': 'NUMBER', 'VARCHAR': 'VARCHAR2'},
    'postgresql': {'DATE': 'TIMESTAMP', 'DATETIME': 'TIMESTAMP', 'INT': 'BIGINT', 'FLOAT': 'REAL',   'VARCHAR': 'TEXT'},
}
# from read_frame. ?datetime objects returned? convert to datetime64?
def read_db(sql, con):
    return psql.frame_query(sql, con)

def table_exists(name=None, con=None, flavor='sqlite'):
    if flavor == 'sqlite':
        sql = "SELECT name FROM sqlite_master WHERE type='table' AND name='MYTABLE';".replace('MYTABLE', name)
    elif flavor == 'mysql':
        sql = "show tables like 'MYTABLE';".replace('MYTABLE', name)
    elif flavor == 'postgresql':
        sql = "SELECT * FROM pg_tables WHERE tablename='MYTABLE';".replace('MYTABLE', name)
    elif flavor == 'oracle':
        sql = "select table_name from user_tables where table_name='MYTABLE'".replace('MYTABLE', name.upper())
    elif flavor == 'odbc':
        raise NotImplementedError
    else:
        raise NotImplementedError
    df = read_db(sql, con)
    print sql, df
    print 'table_exists?', len(df)
    return len(df) > 0
def write_frame(frame, name=None, con=None, flavor='sqlite', if_exists='fail'):
    """
    Write records stored in a DataFrame to the specified dbms.

    if_exists:
        'fail'    - a CREATE TABLE will be attempted and will fail if the table exists
        'replace' - if a table with 'name' exists, it will be deleted first
        'append'  - assume a table with the correct schema exists and add data;
                    if there is no table or the data are bad, then fail
    """
    if if_exists == 'replace' and table_exists(name, con, flavor):
        cur = con.cursor()
        cur.execute("drop table " + name)
        cur.close()
    if if_exists in ('fail', 'replace') or (if_exists == 'append' and not table_exists(name, con, flavor)):
        # create table
        schema = get_schema(frame, name, flavor)
        if flavor == 'oracle':
            schema = schema.replace(';', '')
        cur = con.cursor()
        if flavor == 'mysql':
            cur.execute("SET sql_mode='ANSI_QUOTES';")
        print 'schema\n', schema
        cur.execute(schema)
        cur.close()
        print 'created table'
    cur = con.cursor()
    # bulk insert
    if flavor in ('sqlite', 'odbc'):
        wildcards = ','.join(['?'] * len(frame.columns))
        insert_sql = 'INSERT INTO %s VALUES (%s)' % (name, wildcards)
        data = [tuple(x) for x in frame.values]
        cur.executemany(insert_sql, data)
    elif flavor == 'oracle':
        cols = [db_colname(k) for k in frame.dtypes.index]
        colnames = ','.join(cols)
        colpos = ', '.join([':' + str(i + 1) for i, f in enumerate(cols)])
        insert_sql = 'INSERT INTO %s (%s) VALUES (%s)' % (name, colnames, colpos)
        data = [convertSequenceToDict(rec) for rec in frame.values]
        cur.executemany(insert_sql, data)
    elif flavor == 'mysql':
        wildcards = ','.join(['%s'] * len(frame.columns))
        cols = [db_colname(k) for k in frame.dtypes.index]
        colnames = ','.join(cols)
        insert_sql = 'INSERT INTO %s (%s) VALUES (%s)' % (name, colnames, wildcards)
        print insert_sql
        # replace NaN with None so the driver inserts NULL
        data = [tuple([None if isnull(v) else v for v in rw]) for rw in frame.values]
        print data[0]
        cur.executemany(insert_sql, data)
    elif flavor == 'postgresql':
        postgresql_copy_from(frame, name, con)
    else:
        raise NotImplementedError
    con.commit()
    cur.close()
    return
def nan2none(df):
    '''Convert a DataFrame's rows to tuples, replacing NaN with None so
    drivers insert NULL.  NaN never compares equal to itself, so the old
    v == np.Nan test matched nothing; use isnull() instead.'''
    return [tuple([None if isnull(v) else v for v in rw]) for rw in df.values]
def db_colname(pandas_colname):
    '''convert pandas column name to a DBMS column name
    TODO: deal with name length restrictions, esp for Oracle
    '''
    colname = pandas_colname.replace(' ', '_').strip()
    return colname
def postgresql_copy_from(df, name, con):
    # append data to an existing postgresql table using COPY
    # 1. convert df to csv, no header, no index
    output = cStringIO.StringIO()
    # work around the datetime64 to_csv() bug in this era of pandas
    have_datetime64 = False
    dtypes = df.dtypes
    for i, k in enumerate(dtypes.index):
        dt = dtypes[k]
        print 'dtype', dt, dt.itemsize
        if issubclass(dt.type, np.datetime64):
            have_datetime64 = True
    if have_datetime64:
        d2 = df.copy()
        for i, k in enumerate(dtypes.index):
            dt = dtypes[k]
            if issubclass(dt.type, np.datetime64):
                # convert datetime64 values to plain datetime
                d2[k] = [v.to_pydatetime() for v in d2[k]]
        d2.to_csv(output, sep='\t', header=False, index=False)
    else:
        df.to_csv(output, sep='\t', header=False, index=False)
    output.seek(0)  # rewind so copy_from reads from the start
    contents = output.getvalue()
    print 'contents\n', contents
    # 2. copy from
    cur = con.cursor()
    cur.copy_from(output, name)
    con.commit()
    cur.close()
    return
# source: http://www.gingerandjohn.com/archives/2004/02/26/cx_oracle-executemany-example/
def convertSequenceToDict(seq):
    """for cx_Oracle:
    For each element in the sequence, creates a dictionary item equal
    to the element and keyed by the position of the item in the sequence.
    >>> convertSequenceToDict(("Matt", 1))
    {'1': 'Matt', '2': 1}
    """
    d = {}  # avoid shadowing the dict and list builtins
    for k, v in zip(range(1, len(seq) + 1), seq):
        d[str(k)] = v
    return d
def get_schema(frame, name, flavor):
    types = dbtypes[flavor]  # deal with datatype differences
    column_types = []
    dtypes = frame.dtypes
    for i, k in enumerate(dtypes.index):
        dt = dtypes[k]
        if issubclass(dt.type, np.datetime64):
            sqltype = types['DATETIME']
        elif issubclass(dt.type, (np.integer, np.bool_)):
            sqltype = types['INT']
        elif issubclass(dt.type, np.floating):
            sqltype = types['FLOAT']
        else:
            # sample the first row by position; .iloc[0] works even when
            # the index has no label 0 (plain [0] would raise KeyError)
            sampl = frame[frame.columns[i]].iloc[0]
            if str(type(sampl)) == "<type 'datetime.datetime'>":
                sqltype = types['DATETIME']
            elif str(type(sampl)) == "<type 'datetime.date'>":
                sqltype = types['DATE']
            else:
                if flavor in ('mysql', 'oracle'):
                    # these platforms require an explicit VARCHAR size
                    size = 2 + max(len(str(a)) for a in frame[k])
                    print k, 'varchar sz', size
                    sqltype = types['VARCHAR'] + '(%d)' % size
                else:
                    sqltype = types['VARCHAR']
        colname = db_colname(k)
        column_types.append((colname, sqltype))
    columns = ',\n '.join('%s %s' % x for x in column_types)
    template_create = """CREATE TABLE %(name)s (
  %(columns)s
);"""
    create = template_create % {'name': name, 'columns': columns}
    return create
###############################################################################
def test_sqlite(name, testdf):
    print '\nsqlite, using detect_types=sqlite3.PARSE_DECLTYPES for datetimes'
    import sqlite3
    with sqlite3.connect('test.db', detect_types=sqlite3.PARSE_DECLTYPES) as conn:
        #conn.row_factory = sqlite3.Row
        write_frame(testdf, name, con=conn, flavor='sqlite', if_exists='replace')
        df_sqlite = read_db('select * from ' + name, con=conn)
        print 'loaded dataframe from sqlite', len(df_sqlite)
    print 'done with sqlite'

def test_oracle(name, testdf):
    print '\nOracle'
    import cx_Oracle
    with cx_Oracle.connect('YOURCONNECTION') as ora_conn:
        testdf['d64'] = np.datetime64(testdf['hire_date'])
        write_frame(testdf, name, con=ora_conn, flavor='oracle', if_exists='replace')
        df_ora2 = read_db('select * from ' + name, con=ora_conn)
    print 'done with oracle'
    return df_ora2

def test_postgresql(name, testdf):
    #from pg8000 import DBAPI as pg
    import psycopg2 as pg
    print '\nPostgresQL, Greenplum'
    pgcn = pg.connect(YOURCONNECTION)
    print 'df frame_query'
    try:
        write_frame(testdf, name, con=pgcn, flavor='postgresql', if_exists='replace')
        print 'pg copy_from'
        postgresql_copy_from(testdf, name, con=pgcn)
        df_gp = read_db('select * from ' + name, con=pgcn)
        print 'loaded dataframe from greenplum', len(df_gp)
    finally:
        pgcn.commit()
        pgcn.close()
    print 'done with greenplum'

def test_mysql(name, testdf):
    import MySQLdb
    print '\nmysql'
    cn = MySQLdb.connect(YOURCONNECTION)
    try:
        # use the 'name' argument rather than a hard-coded table name
        write_frame(testdf, name=name, con=cn, flavor='mysql', if_exists='replace')
        df_mysql = read_db('select * from ' + name, con=cn)
        print 'loaded dataframe from mysql', len(df_mysql)
    finally:
        cn.close()
    print 'mysql done'
##############################################################################
if __name__ == '__main__':
    from pandas import DataFrame
    from datetime import datetime
    print """Aside from sqlite, you'll need to install the driver and set a valid
connection string for each test routine."""
    test_data = {
        "name": ['Joe', 'Bob', 'Jim', 'Suzy', 'Cathy', 'Sarah'],
        "hire_date": [datetime(2012, 1, 1), datetime(2012, 2, 1), datetime(2012, 3, 1),
                      datetime(2012, 4, 1), datetime(2012, 5, 1), datetime(2012, 6, 1)],
        "erank": [1, 2, 3, 4, 5, 6],
        "score": [1.1, 2.2, 3.1, 2.5, 3.6, 1.8],
    }
    df = DataFrame(test_data)
    name = 'test_df'
    test_sqlite(name, df)
    #test_oracle(name, df)
    #test_postgresql(name, df)
    #test_mysql(name, df)
    print 'done'
I needed to add the imports:
import numpy as np
import cStringIO
and an index=False to the to_csv arguments, as well as an output.seek(0) after the to_csv.
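For reference, a minimal sketch of that corrected buffer handling (df stands in for any DataFrame):

output = cStringIO.StringIO()
df.to_csv(output, sep='\t', header=False, index=False)  # index=False keeps the index column out of the COPY data
output.seek(0)  # rewind so copy_from() reads from the start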
This line is strange
sql="SELECT name FROM sqlite_master WHERE type='table' AND name='MYTABLE';".replace('MYTABLE', name)
Why not use
sql="SELECT name FROM sqlite_master WHERE type='table' AND name='{MYTABLE}';".format(MYTABLE=name)
@brandonwillard's fixes also made it work for me. Thanks, both!
@brandonwillard: thanks! I have updated the gist.
For anyone just getting started with python, it might be helpful to know that you need a couple of imports to get the unit test to work:
if __name__ == '__main__':
    from pandas import DataFrame
    from datetime import datetime
    print """Aside from sqlite, you'll need to install the driver and set a valid
connection string for each test routine."""
I made a change to the get_schema method to prevent it from blowing up when the dataframe doesn't have a first row.
Check out my fork to see the changes.
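The fork itself isn't reproduced here; one plausible guard (an assumption on my part, not necessarily what the fork does) is a small helper that tolerates an empty frame:

def sample_first(frame, colname):
    # hypothetical helper: first value by position, or None if the frame is empty
    if len(frame) == 0:
        return None
    return frame[colname].iloc[0]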
@rriehle: thanks, imports added to gist.
I am using Python 2.7; I tried the test case and got the following errors:
INSERT INTO test_df (erank,hire_date,name,score) VALUES (%s,%s,%s,%s)
(1L, Timestamp('2012-01-01 00:00:00', tz=None), 'Joe', 1.1)
Traceback (most recent call last):
File "C:/Users/Administrator/PycharmProjects/numpytest/pandas_dbms.py", line 316, in
test_mysql(name, df)
File "C:/Users/Administrator/PycharmProjects/numpytest/pandas_dbms.py", line 283, in test_mysql
write_frame(testdf, name='test_df', con=cn, flavor='mysql', if_exists='replace')
File "C:/Users/Administrator/PycharmProjects/numpytest/pandas_dbms.py", line 115, in write_frame
cur.executemany(insert_sql, data)
File "C:\Program Files (x86)\Python\lib\site-packages\mysql\connector\cursor.py", line 557, in executemany
values.append(fmt % self._process_params(params))
File "C:\Program Files (x86)\Python\lib\site-packages\mysql\connector\cursor.py", line 355, in _process_params
"Failed processing format-parameters; %s" % err)
mysql.connector.errors.ProgrammingError: Failed processing format-parameters; 'MySQLConverter' object has no attribute '_timestamp_to_mysql'
It seems to be a datetime format error. If I use a datatype other than datetime, it works fine.
Thanks,
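One workaround (a sketch, assuming the testdf frame from the gist) is to hand the driver plain datetime objects before inserting, since this mysql.connector build has no converter registered for pandas Timestamp:

# convert pandas Timestamps to datetime.datetime before executemany
testdf['hire_date'] = [ts.to_pydatetime() for ts in testdf['hire_date']]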
I had a similar problem to chien, where get_schema looks for a row with index 0 that isn't there (in my data). iloc uses relative position.
If you use frame[ frame.columns[i] ].iloc[0] instead of frame[ frame.columns[i] ][0] on line 206, you'll be set. iloc returns the first item (0 offset), whereas [0] returns the item where the index == 0.
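A quick illustration of the difference (the frame here is made up):

from pandas import DataFrame
d = DataFrame({'x': [10, 20]}, index=[5, 6])  # index has no label 0
print d['x'].iloc[0]  # 10: first row by position
print d['x'][0]       # KeyError: there is no row labeled 0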
I should add, thank you for sharing this. It is very helpful.
isnull is not defined
This code package is awesome... I got stuck at one point: isnull is not defined, at line 109.
@estenssoros @brajesh2020 from pandas import isnull
Thanks. Added from pandas import isnull to the header code.
TypeError: Argument 'rows' has incorrect type (expected list, got tuple)
Solution: use MySQLdb to get a cursor (instead of pandas), fetch all into a tuple, then cast that as a list when creating the new DataFrame:
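A sketch of that fix, assuming a MySQLdb connection cn and the test_df table from the gist:

cur = cn.cursor()
cur.execute('select * from test_df')
rows = cur.fetchall()                       # MySQLdb returns a tuple of tuples
cols = [d[0] for d in cur.description]      # column names from the cursor metadata
df2 = DataFrame(list(rows), columns=cols)   # cast the tuple to a list for DataFrame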