Last active
May 25, 2017 09:55
-
-
Save PavlosMelissinos/6ddb1ba86f51214477a76ce670d18c97 to your computer and use it in GitHub Desktop.
Rudimentary implementation of unpooling for keras
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras import backend as K
from keras.layers.convolutional import UpSampling2D
from keras.layers.pooling import MaxPooling2D
class MaxPoolingMask2D(MaxPooling2D):
    """Max-pooling variant that outputs a boolean mask instead of pooled values.

    The mask has the same shape as the layer input and is True wherever an
    input element equals the maximum of its pooling window. It is meant to be
    paired with a regular ``MaxPooling2D`` on the same input so the mask can
    later drive an unpooling step.

    NOTE(review): the mask is built by upsampling the pooled tensor by
    ``pool_size``, so this presumably only works when the pooling shrinks each
    spatial dimension by exactly ``pool_size`` (strides equal to the pool
    size, dimensions divisible by it) -- confirm against callers.
    """

    def __init__(self, pool_size=(2, 2), strides=None, border_mode='valid',
                 dim_ordering='default', **kwargs):
        """Forward all arguments unchanged to ``MaxPooling2D``."""
        super(MaxPoolingMask2D, self).__init__(pool_size, strides, border_mode,
                                               dim_ordering, **kwargs)

    def _pooling_function(self, inputs, pool_size, strides,
                          border_mode, dim_ordering):
        """Return a boolean tensor marking the positions of window maxima.

        Args:
            inputs: tensor to pool.
            pool_size: pooling window size, also used as the upsampling factor.
            strides: pooling strides, passed to ``K.pool2d``.
            border_mode: padding mode, passed to ``K.pool2d``.
            dim_ordering: axis ordering used for both pooling and upsampling.

        Returns:
            Boolean tensor with the same static shape as ``inputs``.
        """
        pooled = K.pool2d(inputs, pool_size, strides, border_mode,
                          dim_ordering, pool_mode='max')
        # Upsample with the SAME dim_ordering that was used for pooling;
        # otherwise the wrong axes are repeated whenever the layer's ordering
        # differs from the globally configured default.
        upsampled = UpSampling2D(size=pool_size,
                                 dim_ordering=dim_ordering)(pooled)
        # An element is "selected" when it equals its window's max. Ties
        # inside one window therefore produce several True entries.
        index_mask = K.tf.equal(inputs, upsampled)
        # Sanity check: the mask must line up element-for-element with inputs.
        assert index_mask.get_shape().as_list() == inputs.get_shape().as_list()
        return index_mask

    def get_output_shape_for(self, input_shape):
        """The mask has the same shape as the input, so return it unchanged."""
        return input_shape
def unpooling(inputs):
    """Unpool a tensor using a boolean mask of max positions.

    Intended to be moved to a separate layer if it works. The steps are:

    1. Do naive upsampling (repeat elements).
    2. Keep only the values selected by the mask (stored indices) and set
       the rest to zero.

    Args:
        inputs: two-element sequence ``[x, mask]`` where ``x`` is the pooled
            tensor and ``mask`` is the boolean mask at the unpooled
            resolution.

    Returns:
        Tensor shaped like the upsampled ``x``, holding the upsampled values
        where ``mask`` is True and zeros elsewhere.
    """
    x = inputs[0]
    mask = inputs[1]
    mask_shape = mask.get_shape().as_list()
    x_shape = x.get_shape().as_list()
    # Integer division: UpSampling2D needs integer repeat factors, and plain
    # `/` yields floats on Python 3.
    # NOTE(review): axes 1 and 2 are treated as the spatial axes (looks like
    # channels-last ordering) and must have known static sizes -- confirm.
    pool_size = (mask_shape[1] // x_shape[1], mask_shape[2] // x_shape[2])
    on_success = UpSampling2D(size=pool_size)(x)
    on_fail = K.zeros_like(on_success)
    return K.tf.where(mask, on_success, on_fail)
def unpooling_output_shape(input_shape):
    """Return the output shape of ``unpooling`` for the given input shapes.

    ``input_shape`` is indexed like the ``[x, mask]`` input pair; the result
    is its second entry, i.e. the shape that belongs to the mask.
    """
    mask_shape = input_shape[1]
    return mask_shape
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.engine.topology import merge | |
from keras.layers.pooling import MaxPooling2D | |
from .keras_unpooling import MaxPoolingMask2D # replace this with the actual location of the class | |
# Usage example (Keras 1.x functional API).
# `inp` (the input tensor) and `pool_size` (an int) must be defined by the
# surrounding model code.
# `unpooling` and `unpooling_output_shape` are used below, so import them too;
# replace the module path with the actual location of those functions.
from .keras_unpooling import unpooling, unpooling_output_shape

# Pool the input and, in parallel, record where each window's max came from.
x = MaxPooling2D(pool_size=(pool_size, pool_size))(inp)
mask = MaxPoolingMask2D(pool_size=(pool_size, pool_size))(inp)  # index selector
# Do stuff...
# Unpool: upsample `x` and keep only the positions selected by `mask`.
x = merge([x, mask], mode=unpooling, output_shape=unpooling_output_shape)
TensorFlow has `tf.nn.max_pool_with_argmax`, which may be better optimized for what you're trying to do. I'd also guess that using something like `tf.scatter_nd`, which modifies a tensor in-place at given indices, would be more efficient than comparing large sparse tensors using `tf.where`.
Thanks @allanzelener, those are some really helpful suggestions. Apologies for not responding for all this time but I didn't get any kind of notification that you've posted here.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
The code targets Keras 1.x (I think I was using 1.1 at the time). It might need some extra work to get it working on Keras 2+.