Skip to content

Instantly share code, notes, and snippets.

@libfun
Last active September 28, 2016 14:02
Show Gist options
  • Save libfun/dadd3b0208bfe53249fecb7a29c7c906 to your computer and use it in GitHub Desktop.
Save libfun/dadd3b0208bfe53249fecb7a29c7c906 to your computer and use it in GitHub Desktop.
import theano.tensor as T
from lasagne.layers import Layer
class Unpool3DLayer(Layer):
    """
    3D Unpooling layer

    This layer performs unpooling (upscaling) over the last three axes
    (axes 2, 3 and 4) of a 5D tensor. The first two axes (presumably
    batch and channels) are left unchanged.

    Parameters
    ----------
    incoming : a :class:`Layer` instance or tuple
        The layer feeding into this layer, or the expected input shape.
    ds : integer or iterable of three integers
        The upscaling factor along each of axes 2, 3 and 4. If an integer,
        it is promoted to a uniform region ``(ds, ds, ds)``. If an iterable,
        it must have exactly three positive elements.
    mode : {'repeat', 'bed_of_nails'}
        Unpooling mode. ``'repeat'`` copies every input value into all cells
        of its ``ds`` block. ``'bed_of_nails'`` places every input value in
        the first (lowest-index) cell of its block and fills the remaining
        cells with zeros. Default is ``'repeat'``.
    **kwargs
        Any additional keyword arguments are passed to the :class:`Layer`
        superclass.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported modes, or if ``ds`` does not
        describe three positive factors.
    """

    def __init__(self, incoming, ds, mode='repeat', **kwargs):
        super(Unpool3DLayer, self).__init__(incoming, **kwargs)
        if mode not in ('repeat', 'bed_of_nails'):
            # Previously any unknown mode silently fell back to 'repeat'.
            raise ValueError(
                "mode must be 'repeat' or 'bed_of_nails', got %r" % (mode,))
        self.mode = mode
        if isinstance(ds, int):
            # Promote a scalar factor to a uniform 3D region, as documented.
            ds = (ds, ds, ds)
        else:
            ds = tuple(ds)
        if len(ds) != 3:
            raise ValueError('ds must have len == 3')
        if any(d < 1 for d in ds):
            raise ValueError('ds must contain positive factors, got %r' % (ds,))
        self.ds = ds

    def get_output_shape_for(self, input_shape):
        """Return ``input_shape`` with axes 2, 3 and 4 scaled by ``self.ds``.

        Works both for static shape tuples (which may contain ``None`` for
        unknown sizes) and for symbolic shapes, whose elements are
        multiplied symbolically.
        """
        output_shape = list(input_shape)
        for axis, factor in zip((2, 3, 4), self.ds):
            # Unknown (None) sizes stay unknown instead of raising TypeError.
            if output_shape[axis] is not None:
                output_shape[axis] = output_shape[axis] * factor
        return tuple(output_shape)

    def get_output_for(self, input, **kwargs):
        """Compute the unpooled output for the symbolic 5D ``input``."""
        d0, d1, d2 = self.ds
        if self.mode == 'bed_of_nails':
            output_shape = self.get_output_shape_for(input.shape)
            # Keep the input's dtype rather than forcing float32.
            res = T.zeros(shape=output_shape, dtype=input.dtype)
            # Each strided slice has exactly the input's size along that
            # axis (n * d cells sampled every d), so the input drops in
            # directly at the first cell of every block.
            return T.set_subtensor(res[:, :, ::d0, ::d1, ::d2], input)
        return input.repeat(d0, axis=2).repeat(d1, axis=3).repeat(d2, axis=4)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment