Skip to content

Instantly share code, notes, and snippets.

What would you like to do?
A dictionary-ready wrapper for theano functions.
__author__ = "nasim.rahaman at"
__doc__ = """A few bells and whistles for the theano function callable.

Examples (``function`` is the wrapper class defined in this module)::

    import theano.tensor as T
    x = T.scalar()
    y = T.scalar()

    f1 = function(inputs={'x': x, 'y': y}, outputs={'z1': x + y, 'z2': x + 2*y})
    f1(x=2, y=3)
    # Output: {'z1': 5, 'z2': 8}

    f2 = function(inputs={'x': x, 'y': y}, outputs={'z12': [x + y, x + 2*y]})
    f2(x=2, y=3)
    # Output: {'z12': [5, 8]}

    f3 = function(inputs=[x, y], outputs=[x + y, x + 2*y])
    f3(2, 3)
    # Output: (5, 8)

    f4 = function(inputs=[x, y], outputs={'z1': x + y, 'z2': x + 2*y})
    f4(2, 3)
    # Output: {'z1': 5, 'z2': 8}

Can be useful for e.g. having a theano function return the gradients
of a variable w.r.t. multiple theano variables alongside (say) the cost scalar,
without having to worry about the output ordering/bookkeeping.

P.S. Dictionary outputs for theano functions are built in, but they can't (yet)
organize outputs into lists. That's what this wrapper is intended for.
"""
import theano as th
import numpy as np
class pyk(object):
    """Small list-manipulation helpers used by the ``function`` wrapper.

    Every member is a static method and is meant to be called through the class,
    e.g. ``pyk.flatten(...)``.
    """

    @staticmethod
    def obj2list(obj, ndarray2list=True):
        """Return ``obj`` as a list.

        Lists and tuples (and numpy arrays, when ``ndarray2list`` is True) are converted with
        ``list()``; any other object is wrapped in a one-element list.

        :type obj: object
        :param obj: Object to convert.
        :type ndarray2list: bool
        :param ndarray2list: If True, a numpy array is split along its first axis into a list;
                             if False, it is treated as a single object and wrapped.
        :rtype: list
        """
        listlike = (list, tuple, np.ndarray) if ndarray2list else (list, tuple)
        # Deliberately an isinstance check instead of "try: list(obj)": other iterables
        # (e.g. generators/iterators) must be wrapped as one object, not expanded.
        if isinstance(obj, listlike):
            return list(obj)
        return [obj]

    @staticmethod
    def delist(l):
        """Unwrap a one-element list or tuple; return anything else unchanged."""
        if isinstance(l, (list, tuple)) and len(l) == 1:
            return l[0]
        return l

    @staticmethod
    def unflatten(l, lenlist):
        """Fold the flat sequence ``l`` into consecutive chunks with lengths given by ``lenlist``.

        Chunks of length 1 are unwrapped (see ``delist``). For ``l = [a, b, c, d, e]`` and
        ``lenlist = [1, 1, 2, 1]``, the result is ``[a, b, [c, d], e]``. ``l`` is not modified.

        :type l: list
        :param l: Flat sequence to fold.
        :type lenlist: list of int
        :param lenlist: Length of each chunk, in order.
        :rtype: list
        :raises AssertionError: if ``sum(lenlist)`` differs from ``len(l)``.
        """
        assert len(l) == sum(lenlist), "Provided length list is not consistent with the list length."
        outlist = []
        start = 0
        for len_ in lenlist:
            # Slice rather than repeatedly popping from the front of a copy (each pop(0) is O(n)).
            outlist.append(pyk.delist(list(l[start:start + len_])))
            start += len_
        return outlist

    @staticmethod
    def flatten(*args):
        """Recursively flatten nested lists/tuples, yielding the leaves from left to right.

        Returns a generator. Only ``list`` and ``tuple`` are descended into; every other object
        (including dicts, dict views and numpy arrays) is yielded as a single leaf.
        """
        for arg in args:
            if isinstance(arg, (tuple, list)):
                for leaf in pyk.flatten(*arg):
                    yield leaf
            else:
                yield arg

    @staticmethod
    def smartlen(l):
        """Return ``len(l)`` for lists/tuples and 1 for any other object (no TypeError)."""
        if isinstance(l, (list, tuple)):
            return len(l)
        return 1
# Generic class for functions
class function(object):
    """Callable wrapper around ``theano.function`` accepting dicts (and nested lists) of inputs/outputs."""

    def __init__(self, inputs, outputs, mode=None, updates=None, givens=None, no_default_updates=False,
                 accept_inplace=False, name=None, rebuild_strict=True, allow_input_downcast=None, profile=None,
                 on_unused_input=None):
        """
        A simple wrapper for theano functions (can be used with Lasagne), with added syntactic sugar.

        Nothing is compiled here: the theano function is built by ``compile()``, which ``__call__``
        invokes automatically on first use.

        :type inputs: list or dict
        :param inputs: List of inputs, or alternatively a dict with {'name1': var1, ...}. With a dict, the
                       function must be called with keyword arguments named after the dict's keys.
        :type outputs: list or dict
        :param outputs: List of outputs, or alternatively a dict with {'name1': var1, ...}. An entry may
                        itself be a list of variables; its results are then returned as a list.
        :type mode: str or theano.compile.Mode
        :param mode: Compilation mode.
        :type updates: list or tuple or dict
        :param updates: Expressions for new SharedVariable values. Must be iterable over pairs of
                        (shared_variable, update expression).
        :type givens: list or tuple or dict
        :param givens: Substitutions to make in the computational graph. Must be iterable over pairs of
                       variables (var1, var2) where var2 replaces var1 in the computational graph.
        :type no_default_updates: bool or list
        :param no_default_updates: Forwarded to theano.function; see its documentation.
        :type accept_inplace: bool
        :param accept_inplace: Forwarded to theano.function; see its documentation.
        :type name: str
        :param name: Name of the function. Useful for profiling.
        :type rebuild_strict: bool
        :param rebuild_strict: Forwarded to theano.function; see its documentation.
        :type allow_input_downcast: bool
        :param allow_input_downcast: Whether to allow the input to be downcasted to floatX.
        :type profile: bool
        :param profile: Whether to profile the function. Forwarded to theano.function.
        :type on_unused_input: str or None
        :param on_unused_input: What to do if an input is not used ('raise', 'warn' or 'ignore').
                                None defers to theano's configured default.
        """
        # Meta
        self.inputs = inputs
        self.outputs = outputs
        self.mode = mode
        self.updates = updates
        self.givens = givens
        self.no_default_updates = no_default_updates
        self.accept_inplace = accept_inplace
        self.name = name
        self.rebuild_strict = rebuild_strict
        self.allow_input_downcast = allow_input_downcast
        self.profile = profile
        self.on_unused_input = on_unused_input
        # Function containers. _thfunction stays None until compile() runs.
        self._thfunction = None
        # NOTE(review): alias kept as in the original; its consumers are not visible here.
        self._function = self.__call__

    def compile(self):
        """Compile the underlying theano function, store it in ``self._thfunction`` and return it.

        Theano only sees flat variable lists: dict inputs are later supplied as keyword arguments
        (see ``__call__``) and dict/nested outputs are reassembled by ``__call__``.
        """
        # Fetch the input list. list() around .values() matters: pyk.flatten only descends into
        # lists/tuples, so a Python 3 dict view would otherwise be passed on as one "variable".
        if isinstance(self.inputs, list):
            inplist = self.inputs
        elif isinstance(self.inputs, dict):
            inplist = list(self.inputs.values())
        else:
            inplist = [self.inputs]
        inplist = list(pyk.flatten(inplist))

        # Fetch the output list (same conventions as for the inputs).
        if isinstance(self.outputs, list):
            outlist = self.outputs
        elif isinstance(self.outputs, dict):
            outlist = list(self.outputs.values())
        else:
            outlist = [self.outputs]
        # A single output is passed unwrapped, so theano returns a bare value instead of a 1-list.
        outlist = pyk.delist(list(pyk.flatten(outlist)))

        thfunction = th.function(inputs=inplist, outputs=outlist, mode=self.mode, updates=self.updates,
                                 givens=self.givens, no_default_updates=self.no_default_updates,
                                 accept_inplace=self.accept_inplace, name=self.name,
                                 rebuild_strict=self.rebuild_strict,
                                 allow_input_downcast=self.allow_input_downcast, profile=self.profile,
                                 on_unused_input=self.on_unused_input)
        # Write to container
        self._thfunction = thfunction
        return thfunction

    def _ordered_kwargs(self, kwargs):
        """Return the keyword-argument values in the order of ``self.inputs`` (the compile order).

        :raises TypeError: if a key of ``self.inputs`` is missing or an unknown keyword is given.
        """
        missing = [key for key in self.inputs if key not in kwargs]
        unexpected = [key for key in kwargs if key not in self.inputs]
        if missing or unexpected:
            raise TypeError("Antipasti function object got missing keyword argument(s) {} and/or "
                            "unexpected keyword argument(s) {}.".format(missing, unexpected))
        return [kwargs[key] for key in self.inputs]

    def __call__(self, *args, **kwargs):
        """Evaluate the wrapped theano function and return outputs shaped like ``self.outputs``.

        Dict outputs are returned as a dict with the same keys, list outputs as a tuple, and a lone
        result is unwrapped. Compiles the theano function first if that has not happened yet.
        """
        if self._thfunction is None:
            self.compile()

        # Don't allow args if self.inputs is a dictionary. This is because the user can not be expected to know
        # exactly how a dictionary is ordered, unless the dictionary is ordered.
        args = list(args)
        if isinstance(self.inputs, dict):
            assert not args, "Antipasti function object expects keyword arguments because the " \
                             "provided input was a dict."
        if isinstance(self.inputs, list):
            assert not kwargs, "Keywords could not be parsed by the Antipasti function object."

        # Flatten args or kwargs. Keyword values must be arranged in the order the inputs were
        # compiled in, not in call-site order, or f(y=3, x=2) would silently swap the inputs.
        if args:
            funcargs = list(pyk.flatten(args))
        elif isinstance(self.inputs, dict):
            funcargs = list(pyk.flatten(self._ordered_kwargs(kwargs)))
        else:
            funcargs = list(pyk.flatten(list(kwargs.values())))

        # Evaluate function
        outlist = pyk.obj2list(self._thfunction(*funcargs), ndarray2list=False)

        # Parse output list
        expoutputs = list(self.outputs.values()) if isinstance(self.outputs, dict) else self.outputs
        expoutputs = pyk.obj2list(expoutputs, ndarray2list=False)
        # Make sure the theano function has returned the correct number of outputs
        assert len(outlist) == len(list(pyk.flatten(expoutputs))), \
            "Number of outputs returned by the theano function " \
            "is not consistent with the number of expected " \
            "outputs."

        # Unflatten theano function output (outlist) using the length of each expected entry.
        lenlist = [pyk.smartlen(expoutput) for expoutput in expoutputs]
        outputs = pyk.unflatten(outlist, lenlist)

        # Write to dictionary if self.outputs is a dictionary
        if isinstance(self.outputs, dict):
            outputs = dict(zip(self.outputs.keys(), outputs))
        elif isinstance(self.outputs, list):
            outputs = tuple(outputs)
        outputs = pyk.delist(outputs)
        return outputs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment