@mrocklin
Created June 5, 2012 01:10
A quick function to compute the shapes of all variables in a Theano Env
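import theano


def shape_of_variables(env, input_shapes):
    """
    Compute the numeric shape of every variable in a Theano Env.

    Inputs:
        env          - the theano.Env in question
        input_shapes - a dict mapping each input variable to its shape tuple
    Outputs:
        shapes       - a dict mapping each variable to its shape tuple

    WARNING: this modifies env in place (a ShapeFeature is attached if
    one is not already present), so the function is not pure.

    >>> import theano
    >>> x = theano.tensor.matrix('x')
    >>> y = x[512:]; y.name = 'y'
    >>> env = theano.Env([x], [y])
    >>> shape_of_variables(env, {x: (1024, 1024)})
    {y: (512, 1024), x: (1024, 1024)}
    """
    # Attach a ShapeFeature so every variable in env gets symbolic shape info.
    if not hasattr(env, 'shape_feature'):
        env.extend(theano.tensor.opt.ShapeFeature())

    # Symbolic dimensions of the inputs, flattened in input order.
    input_dims = [dimension for inp in env.inputs
                  for dimension in env.shape_feature.shape_of[inp]]

    # Symbolic dimensions of every variable the feature tracks, flattened.
    output_dims = [dimension for shape in env.shape_feature.shape_of.values()
                   for dimension in shape]

    # Compile a single function from the input dims to all dims.
    compute_shapes = theano.function(input_dims, output_dims)

    # Feed in the numeric input shapes, in the same flattened order.
    numeric_input_dims = [dim for inp in env.inputs
                          for dim in input_shapes[inp]]
    numeric_output_dims = compute_shapes(*numeric_input_dims)

    # Map each symbolic dim to its value, then rebuild per-variable shapes.
    sym_to_num_dict = dict(zip(output_dims, numeric_output_dims))
    return {var: tuple(sym_to_num_dict[dim]
                       for dim in env.shape_feature.shape_of[var])
            for var in env.shape_feature.shape_of}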
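The bookkeeping in this function (flatten the symbolic dims, evaluate them all in one call, then zip the results back into per-variable tuples) can be illustrated without Theano. The sketch below is a toy, not Theano's API: `Dim`, `toy_shape_of_variables`, and the string variable names are hypothetical stand-ins, and a recursive `evaluate` call plays the role of the compiled `theano.function`.

```python
class Dim:
    """Hypothetical symbolic dimension: a free input dim, or a function of other dims."""

    def __init__(self, name, fn=None, deps=()):
        self.name, self.fn, self.deps = name, fn, tuple(deps)

    def evaluate(self, bindings):
        # Free input dims are looked up directly; derived dims are computed from their deps.
        if self.fn is None:
            return bindings[self]
        return self.fn(*(d.evaluate(bindings) for d in self.deps))


def toy_shape_of_variables(inputs, shape_of, input_shapes):
    # Flatten the symbolic dims of the inputs and of every tracked variable.
    input_dims = [d for inp in inputs for d in shape_of[inp]]
    output_dims = [d for shape in shape_of.values() for d in shape]
    # Bind numeric input dims positionally, as compute_shapes(*numeric_input_dims) does.
    numeric_input_dims = [n for inp in inputs for n in input_shapes[inp]]
    bindings = dict(zip(input_dims, numeric_input_dims))
    # Stand-in for the compiled function: evaluate every symbolic dim.
    numeric_output_dims = [d.evaluate(bindings) for d in output_dims]
    sym_to_num = dict(zip(output_dims, numeric_output_dims))
    # Rebuild a numeric shape tuple for each variable.
    return {var: tuple(sym_to_num[d] for d in shape) for var, shape in shape_of.items()}


# Mirror the docstring example: x is a matrix, y = x[512:].
x0, x1 = Dim('x0'), Dim('x1')
y0 = Dim('y0', fn=lambda n: max(n - 512, 0), deps=(x0,))
shape_of = {'x': (x0, x1), 'y': (y0, x1)}
print(toy_shape_of_variables(['x'], shape_of, {'x': (1024, 1024)}))
# {'x': (1024, 1024), 'y': (512, 1024)}
```

Here `y0` hard-codes the length of `x[512:]` as `max(n - 512, 0)`, matching how Python slicing clamps an out-of-range start; in the real function, that expression comes from Theano's ShapeFeature instead of being written by hand.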