Skip to content

Instantly share code, notes, and snippets.

@tdicola
Created February 29, 2012 09:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tdicola/1939559 to your computer and use it in GitHub Desktop.
Save tdicola/1939559 to your computer and use it in GitHub Desktop.
SQL-style query function for pandas
from pandas import *
import numpy as np
def query(select=None, from_=None, where=None, groupby=None, orderby=None,
          ascending=True):
    """
    SQL style query for columns of a DataFrame.

    select      List of column names to select in the final output.
                Aggregation may be defined by specifying a list of
                [<aggregation function>, <column name>] instead of a column
                name.  Aggregates are only applied when groupby is also given.
    from_       DataFrame to query.
    where       Boolean array for filtering DataFrame.
    groupby     List of columns to apply grouping.  Ignored when select is
                None.
    orderby     Column name (or list of column names) to sort final results.
                May name a column that is not part of select.
    ascending   Boolean to control sort order of final results.

    Returns the filtered / grouped / sorted / projected DataFrame.
    """
    result = from_
    # Apply filtering from where clause
    if where is not None:
        result = result[where]
    # Apply grouping from group by clause
    # TODO: A lot more error checking (i.e. disallow selecting non-aggregate
    # columns which aren't in group by).
    if select is not None and groupby is not None:
        # Find columns to aggregate based on lists of [function, column] in
        # select clause
        aggregates = {x[1]: x[0] for x in select
                      if isinstance(x, list) and len(x) == 2}
        # Update select clause with all column names (including aggregates).
        # Built as a real list (not a lazy map object) so it can be used for
        # column selection on Python 3 as well.
        select = [x[1] if isinstance(x, list) else x for x in select]
        result = result.groupby(groupby, as_index=False).agg(aggregates)
    # Apply sorting from order by clause.  Sorting happens before the column
    # projection so that, as in SQL, the sort key does not have to be one of
    # the selected columns; it also honours orderby/ascending regardless of
    # how many columns are selected.
    if orderby is not None:
        result = result.sort_values(by=orderby, ascending=ascending)
    # Apply column selection from select clause
    if select is not None:
        result = result[select]
    return result
if __name__ == '__main__':
    # Small demonstration of query() on a five-row DataFrame.
    data = DataFrame({'A': ['foo', 'foo', 'bar', 'bar', 'baz'],
                      'B': np.random.randn(5),
                      'C': list(range(5))})
    # Single-argument print(...) calls parse and behave the same on
    # Python 2 and Python 3; print('') emits the blank separator line.
    print('Data:')
    print(data)
    print('')
    print('SELECT A FROM data WHERE C >= 2:')
    print(query(select=['A'],
                from_=data,
                where=data['C'] >= 2))
    print('')
    # The label spells out the GROUP BY clause that the call below applies.
    print('SELECT A, SUM(B) FROM data GROUP BY A ORDER BY A:')
    print(query(select=['A', [np.sum, 'B']],
                from_=data,
                groupby=['A'],
                orderby=['A']))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment