Skip to content

Instantly share code, notes, and snippets.

@eickenberg
Last active August 14, 2016 21:53
Show Gist options
  • Save eickenberg/f1a0e368961ef6d05b5b to your computer and use it in GitHub Desktop.
Save eickenberg/f1a0e368961ef6d05b5b to your computer and use it in GitHub Desktop.
mgrid and ogrid for theano
# Implements functionality corresponding to numpy.mgrid / numpy.ogrid for symbolic theano variables
# Author: Michael Eickenberg, michael.eickenberg@nsup.org
import theano
import theano.tensor as T
class _nd_grid(object):
    """Implements the mgrid and ogrid functionality for theano tensor
    variables.

    Instances are used through indexing with slices, for example
    ``grid[0:2:1, 1:10:1]``. Each slice describes one axis as
    ``start:stop:step``; an omitted start defaults to 0 and an omitted step
    defaults to 1. The stop value is mandatory.

    Indexing with a tuple of ``n`` slices returns a list of ``n`` symbolic
    tensors, one per axis. Indexing with a single bare slice returns the
    corresponding 1-d range directly, as numpy's mgrid / ogrid do.

    Parameters
    ==========
    sparse : boolean, optional, default=False
        Specifying False leads to the equivalent of numpy's mgrid
        functionality. Specifying True leads to the equivalent of ogrid.
    """

    def __init__(self, sparse=False):
        self.sparse = sparse

    def __getitem__(self, *args):
        key = args[0]

        # ``grid[a:b]`` hands over a bare slice rather than a tuple of slices.
        single_axis = isinstance(key, slice)
        if single_axis:
            key = (key,)
        ndim = len(key)

        ranges = []
        for sl in key:
            if sl.stop is None:
                # Without this check the range builder would receive
                # ``stop=None`` and fall back to its one-argument form,
                # treating the start value as the stop.
                raise ValueError("grid slices must specify a stop value")
            start = 0 if sl.start is None else sl.start
            step = 1 if sl.step is None else sl.step
            ranges.append(T.arange(start, sl.stop, step))

        if single_axis:
            return ranges[0]

        # Reshape axis j to (1, ..., len_j, ..., 1) so that the ranges
        # broadcast against each other.
        shapes = [tuple([1] * j + [r.shape[0]] + [1] * (ndim - 1 - j))
                  for j, r in enumerate(ranges)]
        ranges = [r.reshape(shape) for r, shape in zip(ranges, shapes)]

        if self.sparse:
            return ranges

        # Dense grids: multiply each range by ones shaped like every other
        # axis, which broadcasts it to the full grid shape.
        ones = [T.ones_like(r) for r in ranges]
        grids = []
        for i in range(ndim):
            grid = 1
            for j in range(ndim):
                grid = grid * (ranges[j] if j == i else ones[j])
            grids.append(grid)
        return grids
# Ready-to-use indexers: ``ogrid`` yields open (broadcastable) grids,
# ``mgrid`` yields dense grids.
ogrid = _nd_grid(sparse=True)
mgrid = _nd_grid(sparse=False)
def test__mgrid__ogrid():
    """Compare the symbolic mgrid / ogrid against numpy's on float and
    integer slices.

    Each symbolic grid is evaluated and checked against the matching numpy
    grid to 6 decimals.
    """
    import numpy as np
    from numpy.testing import assert_array_almost_equal

    # Index expressions are written once and shared between the numpy
    # reference and the symbolic implementation.
    float_index = np.s_[0:1:.1, 1:10:1., 10:100:10.]
    int_index = np.s_[0:2:1, 1:10:1, 10:100:10]

    cases = [
        (np.mgrid[float_index], mgrid[float_index]),
        (np.mgrid[int_index], mgrid[int_index]),
        (np.ogrid[float_index], ogrid[float_index]),
        (np.ogrid[int_index], ogrid[int_index]),
    ]

    for expected, symbolic in cases:
        # zip() stops at the shorter sequence, so a missing axis would
        # otherwise go unnoticed.
        assert len(expected) == len(symbolic)
        for v1, v2 in zip(expected, symbolic):
            assert_array_almost_equal(v1, v2.eval(), decimal=6)
# Run the self-test when this file is executed as a script.
if __name__ == "__main__":
    test__mgrid__ogrid()
@Franck-Dernoncourt
Copy link

Thanks for sharing! Just an update: mgrid and ogrid were introduced in Theano 0.8 (released on 2016-03-21).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment