Skip to content

Instantly share code, notes, and snippets.

@PavlosMelissinos
Last active May 25, 2017 09:55
Show Gist options
  • Save PavlosMelissinos/6ddb1ba86f51214477a76ce670d18c97 to your computer and use it in GitHub Desktop.
Save PavlosMelissinos/6ddb1ba86f51214477a76ce670d18c97 to your computer and use it in GitHub Desktop.
Rudimentary implementation of unpooling for keras
from keras import backend as K
from keras.layers.convolutional import UpSampling2D
from keras.layers.pooling import MaxPooling2D
class MaxPoolingMask2D(MaxPooling2D):
    """Max-pooling layer that emits the location mask instead of the pooled values.

    The layer runs an ordinary 2D max pooling, upsamples the pooled result
    back by ``pool_size`` and compares it element-wise with the input.  The
    output is a boolean tensor with the same shape as the input that is True
    wherever the input equals the maximum of its pooling window.

    Notes:
        * When several positions in a window share the maximum value, all of
          them are marked True (the comparison is a plain equality test).
        * ``K.tf`` is used directly, so this presumably only works with the
          TensorFlow backend -- TODO confirm for other backends.
    """

    def __init__(self, pool_size=(2, 2), strides=None, border_mode='valid',
                 dim_ordering='default', **kwargs):
        """Forward all arguments unchanged to ``MaxPooling2D``."""
        super(MaxPoolingMask2D, self).__init__(pool_size, strides, border_mode,
                                               dim_ordering, **kwargs)

    def _pooling_function(self, inputs, pool_size, strides,
                          border_mode, dim_ordering):
        """Return a boolean mask marking the max positions of ``inputs``.

        Args:
            inputs: tensor to pool.
            pool_size: pooling window, also used as the upsampling factor.
            strides: pooling strides, passed through to ``K.pool2d``.
            border_mode: padding mode, passed through to ``K.pool2d``.
            dim_ordering: axis layout, passed to both ``K.pool2d`` and the
                upsampling layer so that the same axes are pooled and
                repeated.

        Returns:
            Boolean tensor with the same static shape as ``inputs``.
        """
        pooled = K.pool2d(inputs, pool_size, strides, border_mode,
                          dim_ordering, pool_mode='max')
        # Upsample along the same axes that were pooled; without the explicit
        # dim_ordering the layer would fall back to the global default.
        upsampled = UpSampling2D(size=pool_size,
                                 dim_ordering=dim_ordering)(pooled)
        index_mask = K.tf.equal(inputs, upsampled)
        # The mask is only meaningful if upsampling restores the input shape
        # exactly (presumably requires strides == pool_size and spatial dims
        # divisible by pool_size -- TODO confirm).
        assert index_mask.get_shape().as_list() == inputs.get_shape().as_list(), \
            'upsampled pooling result does not match the input shape'
        return index_mask

    def get_output_shape_for(self, input_shape):
        """The mask has the same shape as the layer input."""
        return input_shape
def unpooling(inputs):
    """Unpool ``inputs[0]`` using the boolean max-location mask ``inputs[1]``.

    Steps:
      1. Naively upsample the pooled tensor (repeat elements).
      2. Keep only the values at positions where the mask is True and set
         every other position to zero.

    Args:
        inputs: two-element sequence ``[x, mask]`` where ``x`` is the pooled
            tensor and ``mask`` is the boolean mask at the pre-pooling
            resolution.

    Returns:
        Tensor shaped like the upsampled ``x`` holding the upsampled values
        where ``mask`` is True and zeros elsewhere.

    Note:
        Axes 1 and 2 are used as the spatial axes, which assumes a
        channels-last layout -- TODO confirm against the caller.
    """
    x = inputs[0]
    mask = inputs[1]
    mask_shape = mask.get_shape().as_list()
    x_shape = x.get_shape().as_list()
    # Floor division keeps the upsampling factors integral; true division
    # would yield floats on Python 3.
    pool_size = (mask_shape[1] // x_shape[1], mask_shape[2] // x_shape[2])
    on_success = UpSampling2D(size=pool_size)(x)
    on_fail = K.zeros_like(on_success)
    return K.tf.where(mask, on_success, on_fail)
def unpooling_output_shape(input_shape):
    """Return the shape of the second merged input as the output shape.

    ``input_shape`` is indexable; element 1 is returned unchanged.
    """
    second_input_shape = input_shape[1]
    return second_input_shape
from keras.engine.topology import merge
from keras.layers.pooling import MaxPooling2D
from .keras_unpooling import MaxPoolingMask2D # replace this with the actual location of the class
# Usage example: `inp` (the input tensor) and `pool_size` (an int used for
# both spatial dimensions) are placeholders that are not defined in this
# snippet.
# NOTE(review): `unpooling` and `unpooling_output_shape` are used below but
# are not imported above -- presumably they must be imported from the same
# module as MaxPoolingMask2D.
# Regular max pooling of the input.
x = MaxPooling2D(pool_size=(pool_size, pool_size))(inp)
# Mask layer applied to the same input with the same pool size.
mask = MaxPoolingMask2D(pool_size=(pool_size, pool_size))(inp) # index selector
# Do stuff...
# Merge the pooled tensor with the mask, using `unpooling` as the merge
# function and `unpooling_output_shape` to report the output shape.
x = merge([x, mask], mode=unpooling, output_shape=unpooling_output_shape)
@PavlosMelissinos
Copy link
Author

PavlosMelissinos commented May 5, 2017

This code targets Keras 1.x (I think I was using 1.1 at the time). It might need some extra work to run on Keras 2+.

@allanzelener
Copy link

Tensorflow has tf.nn.max_pool_with_argmax which may be better optimized for what you're trying to do.

I'd also guess that using something like tf.scatter_nd, which writes the given values into a new zero-initialized tensor at the given indices, would be more efficient than comparing large sparse tensors using tf.where.

@PavlosMelissinos
Copy link
Author

PavlosMelissinos commented May 23, 2017

Thanks @allanzelener, those are some really helpful suggestions. Apologies for the late reply; I didn't get any notification that you had posted here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment