Skip to content

Instantly share code, notes, and snippets.

@has2k1
Last active August 29, 2015 13:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save has2k1/9637948 to your computer and use it in GitHub Desktop.
Save has2k1/9637948 to your computer and use it in GitHub Desktop.
import numpy as np
class geom(object):
    """Base class of all geoms"""
    DEFAULT_AES = dict()
    REQUIRED_AES = set()
    DEFAULT_PARAMS = dict()

    data = None
    aes = None
    manual_aes = None
    params = None

    def __init__(self, *args, **kwargs):
        # assign data, aes, manual_aes, params
        pass

    def __radd__(self, gg):
        # add layer to ggplot object
        pass

    def _rename_aes(self, data, translations):
        """
        Rename aesthetics from the ggplot2 API to matplotlib keywords.

        Helper for the geoms.

        Parameters
        ----------
        data : dict or dataframe
            Mapping of aesthetic name -> value(s), or a dataframe whose
            columns are aesthetic names.
        translations : dict
            Maps ggplot2 aesthetic names to matplotlib keyword names,
            e.g. {'shape': 'marker'}. Names not in it are kept as-is.

        Returns
        -------
        out : dataframe with renamed columns (a new frame) when data is a
            dataframe, otherwise a new dict with renamed keys.
        """
        if not translations:
            return data
        if hasattr(data, 'columns'):
            # dataframe: rename() returns a new frame, data is untouched
            return data.rename(columns=translations)
        # Every key is renamed in one pass, so chained mappings such as
        # {'color': 'edgecolor', 'fill': 'color'} do not depend on order.
        return {translations.get(key, key): value
                for key, value in data.items()}

    def plot_layer(self, data, ax):
        # abstract function to be implemented by each geom
        pass
#
class geom_point(geom):
    """Scatter-plot geom: one marker per row of the layer data."""
    DEFAULT_AES = {'alpha': 1, 'color': 'black', 'fill': None,
                   'shape': 'o', 'size': 20}
    REQUIRED_AES = {'x', 'y'}
    DEFAULT_PARAMS = {'stat': 'identity', 'position': 'identity',
                      'cmap': None, 'label': ''}

    def plot_layer(self, data, ax):
        """
        Draw the layer with ``ax.scatter``.

        Parameters
        ----------
        data : dataframe
            Layer data. Columns are presumably named after aesthetics
            (x, y, color, ...) -- TODO confirm against the caller.
        ax : axes
            Object with a matplotlib-style ``scatter(**kwargs)`` method.
        """
        translations = {'size': 's', 'shape': 'marker',
                        'color': 'edgecolor', 'fill': 'color'}
        # TODO: Not sure if position adjustments are applied before or
        # after the grouping.
        if self.params['position'] == 'jitter':
            # position adjustments write into the frame; keep the
            # caller's data intact
            data = position_jitter().adjust(data.copy())

        # Each distinct combination of the mapped aesthetics below gets its
        # own scatter() call with those aesthetics passed as scalars
        # (matplotlib's `marker` takes a single value per call). Sorted
        # into a list: groupby needs a list key, and order must be stable.
        group_aes = sorted(set(self.aes or ()) &
                           {'color', 'fill', 'shape', 'alpha', 'size'})
        if group_aes:
            groups = (group for _, group in data.groupby(group_aes))
        else:
            # groupby([]) raises, so plot everything in one call
            groups = [data]

        for group in groups:
            kwargs = {}
            for col in group.columns:
                if col in group_aes:
                    # constant within the group
                    kwargs[col] = group[col].iloc[0]
                else:
                    kwargs[col] = group[col].values
            ax.scatter(**self._rename_aes(kwargs, translations))
class stat(object):
    """
    Base class for every stat.

    A stat receives the layer data in ``compute`` and hands back the data
    to draw. Subclasses override the class-level constants and ``compute``.
    """
    # aesthetics the stat needs in its input
    REQUIRED_AES = set()
    # parameter defaults for the stat
    DEFAULT_PARAMS = {}
    # extra columns created by the stat
    CREATES = set()

    # placeholders; __init__ is meant to fill these in
    data = aes = params = None

    def __init__(self, *args, **kwargs):
        # TODO: assign data, aes, params
        pass

    def __radd__(self, gg):
        # TODO: add this layer to the ggplot object
        return None

    def compute(self, data):
        """Return the stat's output; the base version is a pass-through."""
        # subclasses implement the actual computation
        return data
# An example of a stat
class stat_identity(stat):
    """Stat that hands the layer data through unchanged."""
    DEFAULT_PARAMS = dict(geom='point', position='identity',
                          width=None, height=None)

    def compute(self, data):
        # identity: nothing to compute
        result = data
        return result
class position(object):
    """Base class for all positions"""
    # Aesthetics that map onto the x and y scales
    X = {'x', 'xmin', 'xmax', 'xend', 'xintercept'}
    Y = {'y', 'ymin', 'ymax', 'yend', 'yintercept'}

    def __init__(self, width=None, height=None, **kwargs):
        # 'w' / 'h' are accepted as aliases and take precedence
        self.width = kwargs.get('w', width)
        self.height = kwargs.get('h', height)

    def adjust(self, data):
        """
        Positions must override this function

        How?
        ----
        Make the necessary adjustments to the columns in the dataframe.
        Create the position transformation functions and
        use self._transform_position() to do the rest.
        See: position_jitter.adjust()
        """
        return data

    def _transform_position(self, data, trans_x=None, trans_y=None):
        """
        Transform all the variables that map onto the x and y scales.

        Helper function for self.adjust

        Parameters
        ----------
        data : dataframe
            Modified in place and also returned.
        trans_x : function
            Transforms x scale mappings. Applied to each selected
            column via ``DataFrame.apply``.
        trans_y : function
            Transforms y scale mappings. Applied to each selected
            column via ``DataFrame.apply``.

        Returns
        -------
        data : the same dataframe, with x/y scale columns transformed
        """
        if trans_x:
            xs = [name for name in data.columns if name in self.X]
            if xs:
                data[xs] = data[xs].apply(trans_x)
        if trans_y:
            ys = [name for name in data.columns if name in self.Y]
            if ys:
                data[ys] = data[ys].apply(trans_y)
        return data
class position_identity(position):
    """
    Position that leaves the data where it is.

    Relies on the inherited ``position.adjust``, which returns the data
    unchanged.
    """
class position_jitter(position):
    """Randomly displace x/y values by a small uniform amount."""

    def adjust(self, data):
        """
        Add uniform noise to every column on the x and y scales.

        A width/height of None defaults to 40% of the resolution of
        data['x'] / data['y']. A width/height of 0 disables jittering
        in that direction.

        Parameters
        ----------
        data : dataframe
            Must have 'x' and 'y' columns when width/height is None.

        Returns
        -------
        data : dataframe, transformed by self._transform_position
        """
        # Use locals: writing the computed defaults back to self.width /
        # self.height would make later calls reuse this data's values.
        width = self.width
        if width is None:
            width = resolution(data['x'], zero=False) * .4
        height = self.height
        if height is None:
            height = resolution(data['y'], zero=False) * .4

        # width/height are the half-width of the noise (jitter's `amount`),
        # not jitter's `factor`
        trans_x = (lambda x: jitter(x, amount=width)) if width else None
        trans_y = (lambda y: jitter(y, amount=height)) if height else None
        return self._transform_position(data, trans_x, trans_y)
########### Ported functions for position='jitter' ##############
def resolution(x, zero=True):
    """
    Compute the resolution of a data vector

    Resolution is the smallest non-zero distance between adjacent
    distinct values.

    Parameters
    ----------
    x : 1D array_like
    zero : Boolean
        If True, 0 is treated as one of the values, so the distance
        between 0 and the value closest to it also counts.

    Result
    ------
    res : resolution of x
        1 if x is an integer array, has no finite values, or has an
        (effectively) zero range. Non-finite values (NaN, inf) are ignored.
    """
    x = np.asarray(x)
    # (unsigned) integers
    if x.dtype.kind in ('i', 'u'):
        return 1
    x = x.astype(float)
    x = x[np.isfinite(x)]
    # nothing to measure, or an effective range of zero
    if x.size == 0 or np.ptp(x) < np.finfo(float).resolution:
        return 1
    if zero:
        x = np.append(x, 0.)
    # np.unique sorts and drops duplicates, so every gap is non-zero
    return np.diff(np.unique(x)).min()
def jitter(x, factor=1, amount=None, random_state=None):
    """
    Add a small amount of uniform noise to values in an array_like

    Port of R's ``jitter``.

    Parameters
    ----------
    x : array_like
        Numeric values. Lists and tuples are converted to numpy arrays;
        other array-likes (e.g. a pandas Series) keep their type.
    factor : float
        Scales the noise when ``amount`` is None or 0.
    amount : float, optional
        Noise is drawn from ``uniform(-amount, amount)``.
        If None: ``factor/5`` times the smallest gap between the distinct
        (rounded) finite values. If 0: ``factor * z / 50``, where z is
        the finite range of x.
    random_state : numpy RandomState or Generator, optional
        Source of the noise; defaults to the global ``numpy.random``.

    Returns
    -------
    out : x plus the noise (empty input is returned as-is)
    """
    if len(x) == 0:
        return x
    if isinstance(x, (list, tuple)):
        x = np.array(x)

    values = np.asarray(x, dtype=float)
    finite = values[np.isfinite(values)]

    # z: scale of the data -- its finite range, else |min|, else 1
    z = np.ptp(finite) if finite.size else 0
    if z == 0 and finite.size:
        z = abs(finite.min())
    if z == 0:
        z = 1

    if amount is None:
        # Round relative to the magnitude of z (3 - floor(log10(z))
        # decimals, as R does) before measuring gaps between distinct
        # values. Values stay floats: casting to int would merge them.
        digits = int(3 - np.floor(np.log10(z)))
        xx = np.unique(np.round(finite, digits))
        if xx.size > 1:
            d = np.diff(xx).min()
        elif xx.size == 1 and xx[0] != 0:
            d = xx[0] / 10.
        else:
            d = z / 10.
        amount = factor / 5. * abs(d)
    elif amount == 0:
        amount = factor * (z / 50.)

    rng = np.random if random_state is None else random_state
    return x + rng.uniform(-amount, amount, len(x))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment