Skip to content

Instantly share code, notes, and snippets.

@josharian
Created December 29, 2011 02:30
Show Gist options
  • Save josharian/1531289 to your computer and use it in GitHub Desktop.
Sample theano trampoline for supporting generic variables
#!/usr/bin/env python
"""
sample theano trampoline
"""
import numpy
import theano
import theano.tensor as T
def landing_key(value):
    """Return the cache key under which *value*'s landed variable is stored.

    Currently just the Python type of *value*.
    """
    # TODO: Include numpy shape, dtype, stride; autocasting?
    return type(value)
class BounceVariable(object):
    """Placeholder input variable that "lands" as a concrete theano variable.

    One theano variable is created (and cached) per landing key, so the same
    BounceVariable can participate in functions over scalars and matrices.
    """

    def __init__(self, name):
        self.name = name
        # Maps landing_key(value) -> the theano variable created for it.
        self.landing_cache = {}

    def land(self, value):
        """Return the cached theano variable matching *value*'s type, creating it on first use."""
        key = landing_key(value)
        try:
            return self.landing_cache[key]
        except KeyError:
            if isinstance(value, (int, float)):
                landed = T.dscalar(self.name)
            else:
                landed = T.dmatrix(self.name)
            self.landing_cache[key] = landed
            return landed

    def __add__(self, other):
        # Defer the addition: the receiver is the first operand.
        return BounceOp("__add__", self, other)
class BounceOp(object):
    """Deferred operation node: records an op name plus operands, lands later.

    The expression tree is built symbolically (via __add__ / __getattr__) and
    only turned into concrete theano expressions when generate() is called.
    """

    def __init__(self, wrapped_op, *variables):
        # wrapped_op: a method name (str); TODO also support "standalone" ops.
        self.wrapped_op = wrapped_op
        self.variables = variables

    def generate(self, inputs_dict):
        """Recursively land every operand, then apply the wrapped op.

        inputs_dict maps each BounceVariable in the tree to the concrete
        value it should land against.
        """
        landed_variables = []
        for variable in self.variables:
            if isinstance(variable, BounceVariable):
                landed = variable.land(inputs_dict[variable])
            elif isinstance(variable, BounceOp):
                landed = variable.generate(inputs_dict)
            elif isinstance(variable, (int, float, numpy.ndarray, list)):
                # Plain constants pass through unchanged.
                landed = variable
            else:
                raise AssertionError()
            landed_variables.append(landed)
        # TODO: Recurse to new ops
        if isinstance(self.wrapped_op, str):
            # Method-style op: the first operand is the receiver.
            return getattr(landed_variables[0], self.wrapped_op)(*landed_variables[1:])
        else:
            # TODO: Handle "standalone" ops
            raise NotImplementedError()

    def __add__(self, other):
        return BounceOp("__add__", self, other)

    def __getattr__(self, name):
        # BUG FIX: the original defined the inner function but never returned
        # it (attribute access always yielded None), and its `self` parameter
        # shadowed the instance, dropping the receiver from the operands.
        # Mirror __add__: the receiver is the first operand of the deferred op.
        # (Also renamed the closure, which shadowed the builtin `callable`.)
        def deferred(*args):
            return BounceOp(name, self, *args)
        return deferred
class FunctionTrampoline(object):
    """Compiles and caches one theano function per input-type signature.

    Each call lands the symbolic inputs/output against the concrete argument
    types, compiles a theano function for that signature on first use, and
    dispatches subsequent calls with matching types to the cached function.
    """

    def __init__(self, inputs, output):
        self.inputs = inputs    # ordered sequence of BounceVariables
        self.output = output    # root of the deferred expression tree
        # Maps tuple of landing keys -> compiled theano function.
        self.cached_functions = {}

    def __call__(self, *args):
        land_key = tuple(map(landing_key, args))
        if land_key not in self.cached_functions:
            inputs_dict = dict(zip(self.inputs, args))
            # BUG FIX: build the landed-input list by iterating self.inputs in
            # declaration order, not by iterating the dict. Dict iteration
            # order was arbitrary (Python 2), so the compiled function's
            # positional inputs could mismatch *args; `.iteritems()` is also
            # Python 2 only.
            landed_inputs_list = [inp.land(arg) for inp, arg in zip(self.inputs, args)]
            landed_output = self.output.generate(inputs_dict)
            self.cached_functions[land_key] = theano.function(landed_inputs_list, landed_output)
        return self.cached_functions[land_key](*args)
def create_scalar_adder():
    """Build a theano function that adds three double scalars."""
    x = T.dscalar('x')
    y = T.dscalar('y')
    # BUG FIX: the third variable was mislabeled 'y'; name it 'w' to match
    # the local variable (and the matrix/generic adders).
    w = T.dscalar('w')
    z = x + y + w
    f = theano.function([x, y, w], z)
    return f
def create_matrix_adder():
    """Build a theano function that adds three double matrices."""
    a, b, c = (T.dmatrix(label) for label in ('x', 'y', 'w'))
    total = a + b + c
    return theano.function([a, b, c], total)
def create_generic_adder():
    """Build a trampoline function that adds three generic (type-agnostic) variables."""
    operands = [BounceVariable(label) for label in ('x', 'y', 'w')]
    first, second, third = operands
    return FunctionTrampoline(operands, first + second + third)
def main():
    """Smoke-test the fixed-type adders and the generic trampoline."""
    scalar_adder = create_scalar_adder()
    assert 9 == scalar_adder(2, 3, 4)

    matrix_adder = create_matrix_adder()
    assert [[9]] == matrix_adder([[2]], [[3]], [[4]])

    # Each fixed-type adder must reject arguments of the other type: the call
    # should raise TypeError before the unconditional AssertionError is reached.
    for adder, bad_args in ((scalar_adder, ([[2]], [[3]], [[4]])),
                            (matrix_adder, (2, 3, 4))):
        try:
            adder(*bad_args)
            raise AssertionError
        except TypeError:
            pass

    # The generic adder handles both scalar and matrix arguments.
    generic_adder = create_generic_adder()
    assert 9 == generic_adder(2, 3, 4)
    assert [[9]] == generic_adder([[2]], [[3]], [[4]])


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment