@christopher-beckham
Last active June 2, 2017 15:17
import theano
from lasagne.layers import Layer
from lasagne.nonlinearities import softmax


class SpatialSoftmaxLayer(Layer):
    """
    Softmax layer that computes the softmax over pixels in the same location,
    i.e., over the channel axis. This layer will automatically use the cuDNN
    version of this softmax if it is available.

    Parameters
    ----------
    incoming : a :class:`Layer`
        The layer feeding into this layer; its output must be 4D
        (batch, channels, rows, columns).
    dnn_softmax_mode : str
        If cuDNN is enabled, the algorithm to use for that implementation:
        'accurate' (default) or 'fast'.
    """
    def __init__(self, incoming, dnn_softmax_mode='accurate', **kwargs):
        super(SpatialSoftmaxLayer, self).__init__(incoming, **kwargs)
        self.use_dnn = False
        self.input_shape = incoming.output_shape
        self.dnn_softmax_mode = dnn_softmax_mode
        try:
            # Old CUDA backend (theano.sandbox.cuda); newer Theano
            # releases no longer ship it, hence the ImportError guard.
            from theano.sandbox.cuda import dnn
            if theano.sandbox.cuda.cuda_enabled and dnn.dnn_available():
                self.use_dnn = True
                self.dnn_softmax = dnn.GpuDnnSoftmax
        except ImportError:
            pass

    def get_output_for(self, input, **kwargs):
        if self.use_dnn:
            # 'channel' mode normalises over C separately for each (b, h, w).
            return self.dnn_softmax('bc01', algo=self.dnn_softmax_mode,
                                    mode='channel')(input)
        else:
            # Move channels last, flatten to (pixels, channels), apply a
            # row-wise softmax, then restore the (b, c, h, w) layout.
            bs, c, h, w = self.input_shape
            ds1 = input.dimshuffle((0, 2, 3, 1))
            rs1 = ds1.reshape((-1, c))
            softm = softmax(rs1)
            rs2 = softm.reshape((-1, h, w, c))
            ds2 = rs2.dimshuffle((0, 3, 1, 2))
            return ds2
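The non-cuDNN branch moves the channel axis last, flattens every pixel into one row, applies a row-wise softmax, and then undoes the reshape and the transpose. The sketch below uses NumPy as a stand-in for Theano (an assumption made only so it runs without a GPU; `softmax` and `spatial_softmax_reshape` here are hypothetical helpers, not Lasagne API) and checks that this path gives the same result as a softmax taken directly over axis 1 of a `(batch, channels, height, width)` array.

```python
import numpy as np


def softmax(x, axis):
    """Numerically stabilised softmax along `axis` (hypothetical helper)."""
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def spatial_softmax_reshape(x):
    """NumPy mirror of the layer's fallback path for a 4D (b, c, h, w) array."""
    b, c, h, w = x.shape
    ds1 = x.transpose(0, 2, 3, 1)        # (b, h, w, c): channels last
    rs1 = ds1.reshape(-1, c)             # one row per pixel
    softm = softmax(rs1, axis=1)         # softmax over each pixel's channels
    rs2 = softm.reshape(-1, h, w, c)     # back to (b, h, w, c)
    return rs2.transpose(0, 3, 1, 2)     # back to (b, c, h, w)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5, 3, 4))
out = spatial_softmax_reshape(x)

print(np.allclose(out, softmax(x, axis=1)))   # True
print(np.allclose(out.sum(axis=1), 1.0))      # True
```

In other words, the reshape dance is just one way of expressing a softmax over the channel axis; it needs the static `c`, `h` and `w` only because it reshapes.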
hmeine commented May 22, 2017

Could you please give this code a public license (e.g. BSD, Apache, …)?
Looks sane for the time being. In Theano I found several issues and PRs (such as Theano/Theano#5719) that try to introduce better softmax variants into Theano itself. Eventually, they'll make their way into Theano and Lasagne.


hmeine commented Jun 2, 2017

I found two problems with this code:

  • It assumes 4D arrays, i.e., it does not work with 3D CNNs (5D inputs).
  • It does not seem to work with FCNs (fully convolutional networks), where self.input_shape contains None values.
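Both problems come from unpacking the static `self.input_shape` into exactly four integers. One possible workaround, sketched here in NumPy under the assumption that the same idea carries over to the Theano graph, is to skip the reshape entirely: subtract the per-location maximum and normalise along axis 1 with `keepdims=True`, which works for any rank and any batch or spatial size. `channel_softmax` is a hypothetical name, not part of Lasagne or Theano.

```python
import numpy as np


def channel_softmax(x):
    """Softmax over axis 1 (channels) of an array of any rank >= 2.

    No reshape is needed, so batch and spatial sizes never have to be
    known in advance (the situation an FCN with variable input size is in).
    """
    # Subtract the per-location maximum for numerical stability.
    e = np.exp(x - x.max(axis=1, keepdims=True))
    # Normalise so the channels at every location sum to one.
    return e / e.sum(axis=1, keepdims=True)


rng = np.random.default_rng(1)
img = rng.standard_normal((2, 4, 7, 9))      # 4D input, as in a 2D CNN
vol = rng.standard_normal((3, 4, 5, 6, 8))   # 5D input, as in a 3D CNN

for arr in (img, vol):
    out = channel_softmax(arr)
    print(out.shape == arr.shape, np.allclose(out.sum(axis=1), 1.0))  # True True
```

In Theano the same expression should be writable on the symbolic input with `T.exp`, `input.max(axis=1, keepdims=True)` and `.sum(axis=1, keepdims=True)`; alternatively, the original reshape approach could read its sizes from the symbolic `input.shape` instead of `self.input_shape`. Either way this sketch only covers the non-cuDNN branch, since the cuDNN call above is fixed to the 4D 'bc01' layout.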
