Created
February 29, 2012 09:59
-
-
Save tdicola/1939559 to your computer and use it in GitHub Desktop.
SQL-style query function for pandas
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pandas import * | |
import numpy as np | |
def query(select=None, from_=None, where=None, groupby=None, orderby=None,
          ascending=True):
    """
    SQL style query for columns of a DataFrame.

    Clauses are applied in SQL evaluation order: WHERE, GROUP BY, SELECT,
    ORDER BY.

    select     List of column names to select in the final output. Aggregation
               may be defined by specifying a list of [<aggregation function>,
               <column name>] instead of a column name. Aggregates only take
               effect when groupby is also given. If the same column is
               aggregated more than once, the last function listed wins.
    from_      DataFrame to query. It is never modified in place; when no
               clause applies it is returned as-is.
    where      Boolean array for filtering DataFrame.
    groupby    List of columns to apply grouping. Only used when select is
               also given.
    orderby    Column name (or list of names) to sort final results.
    ascending  Boolean (or list of booleans) to control sort order of final
               results.
    """
    result = from_
    # Apply filtering from where clause
    if where is not None:
        result = result[where]
    # Apply grouping from group by clause
    # TODO: A lot more error checking (i.e. disallow selecting non-aggregate
    # columns which aren't in group by).
    if select is not None and groupby is not None:
        # Map column -> aggregation function from the [function, column]
        # pairs in the select clause.
        aggregates = {item[1]: item[0] for item in select
                      if _is_aggregate(item)}
        # Replace each [function, column] pair with its column name. This
        # must be a real list (not a lazy map) because it is indexed with
        # below.
        select = [item[1] if _is_aggregate(item) else item for item in select]
        result = result.groupby(groupby, as_index=False).agg(aggregates)
    # Apply column selection from select clause
    if select is not None:
        result = result[select]
    # Apply sorting from order by clause. sort_values sorts by orderby for any
    # number of selected columns, and also when select is None.
    if orderby is not None:
        result = result.sort_values(by=orderby, ascending=ascending)
    return result


def _is_aggregate(item):
    """Return True if a select entry is a [function, column] aggregate spec."""
    return isinstance(item, list) and len(item) == 2
if __name__ == '__main__':
    # Demo table: A is a grouping key, B holds random floats, C is 0..4.
    data = DataFrame({'A': ['foo', 'foo', 'bar', 'bar', 'baz'],
                      'B': np.random.randn(5),
                      'C': range(5)})
    # print() calls work on both Python 2 and 3; the original print
    # statements were a SyntaxError on Python 3.
    print('Data:')
    print(data)
    print('SELECT A FROM data WHERE C >= 2:')
    print(query(select=['A'],
                from_=data,
                where=data['C'] >= 2))
    # The query below groups by A, so the label states the GROUP BY clause.
    print('SELECT A, SUM(B) FROM data GROUP BY A ORDER BY A')
    print(query(select=['A', [np.sum, 'B']],
                from_=data,
                groupby=['A'],
                orderby=['A']))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment