Skip to content

Instantly share code, notes, and snippets.

Created September 5, 2012 16:03
Show Gist options
  • Save bbeck/3639046 to your computer and use it in GitHub Desktop.
Save bbeck/3639046 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
from __future__ import with_statement
import copy
import functools
import itertools
import sys
from decimal import Decimal
class aggregation(object):
    """Decorator that registers a function as a named aggregation.

    The decorated function takes ``(state, value)`` and returns the new
    state.  The wrapper that replaces it:

    * converts ``value`` with ``type`` first, when ``type`` is given;
    * on a group's first value (``state is None``) either returns the value
      itself (``default`` is None) or seeds the state from ``default`` --
      calling it when it is callable (e.g. a class), otherwise shallow-copying
      it so a mutable default is never shared between groups;
    * carries ``__name__`` set to the aggregation name and a ``help`` string.

    Every wrapper is stored in ``aggregations_by_name`` under its name.
    """

    # Registry of all decorated aggregations, keyed by name.  Registering the
    # same name again replaces the earlier entry.
    aggregations_by_name = {}

    def __init__(self, name, type=None, default=None, help=''):
        # ``type`` and ``help`` shadow builtins but are kept because callers
        # pass them by keyword.
        self.name = name
        self.type = type
        self.default = default
        self.help = help

    def __call__(self, fn):
        @functools.wraps(fn)
        def wrapper(state, value):
            if self.type is not None:
                value = self.type(value)
            if state is None:
                if self.default is None:
                    # No seed configured: the first value becomes the state.
                    return value
                if callable(self.default):
                    # A factory: build a fresh state object for this group.
                    state = self.default()
                else:
                    # A plain value: copy so mutable defaults aren't shared.
                    state = copy.copy(self.default)
            return fn(state, value)

        wrapper.__name__ = self.name
        wrapper.help = self.help
        aggregation.aggregations_by_name[self.name] = wrapper
        return wrapper

    @staticmethod
    def get_aggregation(name):
        """Return the registered aggregation wrapper called *name*.

        Raises KeyError if no aggregation was registered under that name.
        """
        return aggregation.aggregations_by_name[name]

    @staticmethod
    def get_aggregations():
        """Return ``(name, wrapper)`` pairs for all aggregations, sorted by name."""
        return sorted(aggregation.aggregations_by_name.items())
@aggregation('first', help='The first piece of data from a column in a group.')
def _(state, value):
    """Hand back the stored state, ignoring the incoming value.

    No default is registered, so the aggregation wrapper stores a group's
    first value as its state directly; this body then keeps returning it.
    """
    kept = state
    return kept
@aggregation('last', help='The last piece of data from a column in a group.')
def _(state, value):
    """Overwrite the state with every new value, so the final one wins."""
    newest = value
    return newest
@aggregation('count', default=0, help='The count of records in a group.')
def _(state, value):
    """Add one per record; the wrapper seeds the tally with a copy of 0."""
    state += 1
    return state
@aggregation('min', type=Decimal, help='The minimal value of a column in a group. The column must contain numeric data.')
def _(state, value):
    """Keep the smaller of the running minimum and the new Decimal value."""
    # Same selection rule as min(state, value): the new value wins only
    # when it is strictly smaller.
    return value if value < state else state
@aggregation('max', type=Decimal, help='The maximal value of a column in a group. The column must contain numeric data.')
def _(state, value):
    """Keep the larger of the running maximum and the new Decimal value."""
    # Same selection rule as max(state, value): the new value wins only
    # when it is strictly larger.
    return value if value > state else state
@aggregation('sum', type=Decimal, help='The sum of all values of a column in a group. The column must contain numeric data.')
def _(state, value):
    """Add the new Decimal value onto the running total."""
    new_total = state + value
    return new_total
class AvgData(object):
    """Running sum/count pair used as the state of the 'avg' aggregation.

    ``num`` accumulates the summed values and ``den`` the accumulated weight
    (one per value when used by 'avg').  ``str()`` renders ``num / den``
    rounded to as many decimal places as the accumulated sum carries.
    """

    def __init__(self):
        self.num = self.den = 0

    def add(self, num, den):
        """Accumulate *num* into the sum and *den* into the weight; return self."""
        self.num += num
        self.den += den
        return self

    def __str__(self):
        # Stay in Decimal end to end: the previous float() round-trip
        # silently dropped digits beyond ~15 significant figures.
        total = Decimal(self.num)
        mean = total / self.den
        if not total.is_finite():
            # NaN/Infinity have no numeric exponent; keep the old float-style
            # rendering ('nan', 'inf', '-inf').
            return '%f' % float(mean)
        exponent = total.as_tuple().exponent
        places = -exponent if exponent < 0 else 0
        # Decimal(1).scaleb(-places) is 1, 1E-1, 1E-2, ... -- the quantum for
        # `places` decimal digits; rounding follows the current context.
        rounded = mean.quantize(Decimal(1).scaleb(-places))
        # 'f' forces fixed-point output (str() would switch to 1E-7 style).
        return format(rounded, 'f')
@aggregation('avg', type=Decimal, default=AvgData, help='The average value of a column in a group. The column must contain numeric data.')
def _(state, value):
    """Fold one Decimal into the group's AvgData with a weight of one."""
    return state.add(num=value, den=1)
class DrawdownData(object):
    """Running state for the 'drawdown' aggregation.

    ``peak`` is the highest value seen so far (None before the first value)
    and ``drawdown`` the largest drop from a peak to any later value.
    """

    def __init__(self):
        self.peak = None
        self.drawdown = 0

    def add(self, value):
        """Fold *value* into the running peak and drawdown; return self."""
        # Seed the peak explicitly: max(None, value) only worked because
        # Python 2 orders None below everything; Python 3 raises TypeError.
        if self.peak is None or value > self.peak:
            self.peak = value
        self.drawdown = max(self.drawdown, self.peak - value)
        return self

    def __str__(self):
        return str(self.drawdown)
@aggregation('drawdown', type=Decimal, default=DrawdownData, help='The drawdown of a column in a group. The column must contain numeric data.')
def _(state, value):
    """Feed one Decimal into the group's DrawdownData and return it."""
    updated = state.add(value)
    return updated
def get_key_function(ids):
    """Build a function that maps a row's columns to its group key.

    *ids* are 1-based column numbers (the ``-k`` options).  The returned
    function takes a sequence of column strings and returns a tuple of the
    selected columns in *ids* order; it raises IndexError for a row with
    too few columns.

    Raises ValueError if any id is below 1 -- negative indexing would
    otherwise make ``-k 0`` silently select the last column.
    """
    # Materialize once so a one-shot iterator still works for every row.
    positions = tuple(ids)
    for column in positions:
        if column < 1:
            raise ValueError('column numbers start at 1, got %r' % (column,))
    offsets = tuple(column - 1 for column in positions)
    return lambda cols: tuple(cols[offset] for offset in offsets)
def get_aggregation_functions(specs):
    """Turn ``func:col`` command-line specs into row-level aggregators.

    Each spec names a registered aggregation and a 1-based column.  The
    result has one function ``f(state, cols)`` per spec that feeds
    ``cols[col - 1]`` to that aggregation and returns the new state.

    Raises ValueError for a malformed spec or a column below 1, and KeyError
    (from aggregation.get_aggregation) for an unknown aggregation name.
    """
    funcs = []
    for spec in specs:
        name, sep, col_text = spec.partition(':')
        try:
            col = int(col_text) if sep else 0
        except ValueError:
            col = 0  # reported by the check below
        if col < 1:
            raise ValueError('invalid aggregation spec %r; expected func:col '
                             'with col >= 1' % (spec,))
        func = aggregation.get_aggregation(name)
        # Bind func/col as defaults so each lambda keeps its own values rather
        # than the loop's final ones (late-binding closures).
        funcs.append(lambda state, cols, func=func, col=col:
                     func(state, cols[col - 1]))
    return funcs
def main(options, args):
    """Group whitespace-separated rows from stdin and print one line per group.

    ``options.keys`` holds the 1-based key columns and *args* the
    ``func:col`` aggregation specs.  Each output line is the key columns
    followed by str() of every aggregation's final state, space-separated,
    with groups in the order their keys first appeared.  Blank input lines
    are skipped.
    """
    key_func = get_key_function(options.keys)
    agg_funcs = get_aggregation_functions(args)

    # Keys in first-seen order, so output follows the input's group order.
    keys = []
    # Key -> list of per-aggregation states for that group.
    agg_states = {}

    for line in sys.stdin:
        cols = line.split()
        if not cols:
            continue  # blank line: nothing to group on
        key = key_func(cols)
        states = agg_states.get(key)
        if states is None:
            # First row of a new group: record its position (previously
            # missing, so nothing was ever printed) and start each
            # aggregation from None so it seeds itself.
            keys.append(key)
            states = [None] * len(agg_funcs)
        agg_states[key] = [func(state, cols)
                           for func, state in zip(agg_funcs, states)]

    for key in keys:
        # print(...) with a single argument behaves the same on Python 2 and 3.
        print(' '.join(str(x) for x in itertools.chain(key, agg_states[key])))
if __name__ == '__main__':
    from errno import EPIPE
    from optparse import OptionParser

    desc = r"""
Groups and aggregates data read from standard input. This script acts like the 'group by' statement
in a SQL query. Data within a group can be aggregated using the following aggregation functions:

"""
    for (name, agg) in aggregation.get_aggregations():
        desc += " %-20s %s\n" % (name, agg.help)

    parser = OptionParser(usage='Usage: %prog [options] func:col [func:col...]' + '\n\n' + desc.strip())
    parser.add_option('-k', '--key', dest='keys', action='append', type='int', metavar='KEY',
                      help='Field to use as part of a row key. Can be specified multiple times.')
    (options, args) = parser.parse_args()
    if not options.keys:
        options.keys = [1]
    if not args:
        parser.error('Missing aggregation function(s)')

    try:
        main(options, args)
    except KeyboardInterrupt:
        # User requested the program to exit
        pass
    except IOError as e:
        # EPIPE: the reader (e.g. `head`) closed the pipe early, which is a
        # normal end to a pipeline; anything else is a real error.
        if e.errno != EPIPE:
            raise
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment