@shelhamer
Last active October 7, 2018 16:12
Euclidean Loss as a Python Layer

name: 'EuclideanExample'
layer {
  type: 'DummyData'
  name: 'x'
  top: 'x'
  dummy_data_param {
    shape: { dim: 10 dim: 3 dim: 2 }
    data_filler: { type: 'gaussian' }
  }
}
layer {
  type: 'DummyData'
  name: 'y'
  top: 'y'
  dummy_data_param {
    shape: { dim: 10 dim: 3 dim: 2 }
    data_filler: { type: 'gaussian' }
  }
}
# include InnerProduct layers for parameters
# so the net will need backward
layer {
  type: 'InnerProduct'
  name: 'ipx'
  top: 'ipx'
  bottom: 'x'
  inner_product_param {
    num_output: 10
    weight_filler { type: 'xavier' }
  }
}
layer {
  type: 'InnerProduct'
  name: 'ipy'
  top: 'ipy'
  bottom: 'y'
  inner_product_param {
    num_output: 10
    weight_filler { type: 'xavier' }
  }
}
layer {
  type: 'Python'
  name: 'loss'
  top: 'loss'
  bottom: 'ipx'
  bottom: 'ipy'
  python_param {
    # the module name -- usually the filename -- that needs to be in $PYTHONPATH
    module: 'pyloss'
    # the layer name -- the class name in the module
    layer: 'EuclideanLossLayer'
  }
  # set loss weight so Caffe knows this is a loss layer
  loss_weight: 1
}
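The two InnerProduct layers flatten each (10, 3, 2) input to (10, 6) and map it to 10 outputs, so `ipx` and `ipy` each hold 100 values; that total is what the `count` check in `EuclideanLossLayer.reshape` compares. The sketch below mimics that forward pass in plain NumPy, without Caffe. It assumes Caffe's default Xavier settings (uniform in plus or minus sqrt(3 / fan_in), with fan_in = 6 here) and the default zero bias; the helper names `xavier_uniform` and `inner_product` are hypothetical, not Caffe API.

```python
import numpy as np

def xavier_uniform(num_output, fan_in, rng):
    # Mirrors the default Xavier filler: U(-sqrt(3/fan_in), +sqrt(3/fan_in))
    scale = np.sqrt(3.0 / fan_in)
    return rng.uniform(-scale, scale, size=(num_output, fan_in)).astype(np.float32)

def inner_product(x, weight, bias):
    # InnerProduct (default axis=1) flattens every axis after the first
    flat = x.reshape(x.shape[0], -1)
    return flat @ weight.T + bias

rng = np.random.default_rng(0)
x = rng.standard_normal((10, 3, 2)).astype(np.float32)  # like the 'x' DummyData blob
y = rng.standard_normal((10, 3, 2)).astype(np.float32)  # like the 'y' DummyData blob

fan_in = 3 * 2
w_x = xavier_uniform(10, fan_in, rng)
w_y = xavier_uniform(10, fan_in, rng)
b = np.zeros(10, dtype=np.float32)  # default bias filler: constant 0

ipx = inner_product(x, w_x, b)
ipy = inner_product(y, w_y, b)
print(ipx.shape, ipy.shape, ipx.size == ipy.size)  # (10, 10) (10, 10) True
```

Because both InnerProduct tops end up with the same count, the loss layer's dimension check passes even though its inputs are learned projections rather than the raw data.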

import caffe
import numpy as np


class EuclideanLossLayer(caffe.Layer):

    def setup(self, bottom, top):
        # check input pair
        if len(bottom) != 2:
            raise Exception("Need two inputs to compute distance.")

    def reshape(self, bottom, top):
        # check input dimensions match
        if bottom[0].count != bottom[1].count:
            raise Exception("Inputs must have the same dimension.")
        # difference is shape of inputs
        self.diff = np.zeros_like(bottom[0].data, dtype=np.float32)
        # loss output is scalar
        top[0].reshape(1)

    def forward(self, bottom, top):
        self.diff[...] = bottom[0].data - bottom[1].data
        top[0].data[...] = np.sum(self.diff**2) / bottom[0].num / 2.

    def backward(self, top, propagate_down, bottom):
        for i in range(2):
            if not propagate_down[i]:
                continue
            if i == 0:
                sign = 1
            else:
                sign = -1
            bottom[i].diff[...] = sign * self.diff / bottom[i].num
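The forward pass computes `E = sum((x - y)**2) / (2 * N)` with `N = bottom[0].num`, so `dE/dx = (x - y) / N` and `dE/dy = -(x - y) / N`, which is what `backward` writes using `sign = 1` and `sign = -1`. A minimal sketch, in plain NumPy with hypothetical helper names, that checks those gradients against central finite differences is below. One assumption worth flagging: Caffe's built-in C++ EuclideanLoss layer also multiplies its gradient by `top[0].diff`, which Caffe sets to the layer's `loss_weight`; the Python layer above omits that factor, which makes no difference here because `loss_weight` is 1. The sketch exposes it as a `loss_weight` argument.

```python
import numpy as np

def euclidean_forward(a, b):
    # Same formula as EuclideanLossLayer.forward: sum of squared differences / N / 2
    n = a.shape[0]
    diff = a - b
    return np.sum(diff ** 2) / n / 2.0, diff

def euclidean_backward(diff, n, loss_weight=1.0):
    # Gradients w.r.t. both inputs, scaled by the loss weight (assumed to be 1 here)
    grad_a = loss_weight * diff / n
    return grad_a, -grad_a

def numeric_grad(f, a, eps=1e-6):
    # Central finite differences, perturbing one element at a time (modifies a in place)
    grad = np.zeros_like(a)
    for idx in np.ndindex(*a.shape):
        orig = a[idx]
        a[idx] = orig + eps
        plus = f(a)
        a[idx] = orig - eps
        minus = f(a)
        a[idx] = orig
        grad[idx] = (plus - minus) / (2 * eps)
    return grad

rng = np.random.default_rng(0)
a = rng.standard_normal((10, 10))  # stands in for ipx
b = rng.standard_normal((10, 10))  # stands in for ipy

loss, diff = euclidean_forward(a, b)
grad_a, grad_b = euclidean_backward(diff, a.shape[0])

num_a = numeric_grad(lambda t: euclidean_forward(t, b)[0], a.copy())
num_b = numeric_grad(lambda t: euclidean_forward(a, t)[0], b.copy())
print(np.allclose(grad_a, num_a, atol=1e-5), np.allclose(grad_b, num_b, atol=1e-5))  # True True
```

The check runs in float64 so that rounding error stays far below the tolerance; the loss is quadratic, so central differences are exact apart from rounding.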

ghost commented Oct 6, 2015

Consider an input layer like so:

    layer {
      type: 'Python'
      name: 'x'
      top: 'x'
      python_param {
          # the module name -- usually the filename -- that needs to be in $PYTHONPATH
          module: 'pyinput'
          # the layer name -- the class name in the module
          layer: 'PythonDummyData'
      }
    }

What would the appropriate pyinput.py file look like to simply emit random dummy data, like the DummyData layers in the example above? Once I know that, I think I could build on it for my more complicated preprocessing needs.

Edit:

My guess, though I'm pretty sure it doesn't work:

import caffe
import numpy as np

class PythonDummyData(caffe.Layer):
    def setup(self, bottom, top):
        assert(len(top)==1) # one output

    def reshape(self, bottom, top):
        pass

    def forward(self, bottom, top):
        top[0].data[...] = np.random.randn((10, 3, 2))

    def backward(self, top, propagate_down, bottom):
        pass

Edit 2:

If I get this working, I'd be happy to contribute an example to the repo.
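One reason the guess above fails is the call `np.random.randn((10, 3, 2))`: `randn` takes each dimension as a separate integer argument, so passing a shape tuple raises `TypeError`. The short NumPy-only sketch below shows the difference, plus `np.random.standard_normal`, the variant that does accept a shape tuple.

```python
import numpy as np

shape = (10, 3, 2)

# randn takes the dimensions as separate integer arguments...
ok = np.random.randn(*shape)
print(ok.shape)  # (10, 3, 2)

# ...so passing the tuple itself raises TypeError
try:
    np.random.randn(shape)
    tuple_rejected = False
except TypeError as err:
    tuple_rejected = True
    print("TypeError:", err)

# standard_normal is the variant that accepts a shape tuple
also_ok = np.random.standard_normal(shape)
print(also_ok.shape)  # (10, 3, 2)
```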

ghost commented Oct 6, 2015

No more errors. Next I'm going to try to turn this into an actual learning example and see whether it works properly:

import caffe
import numpy as np

class PythonDummyData(caffe.Layer):
    def setup(self, bottom, top):
        assert(len(top)==1) # one output

    def reshape(self, bottom, top):
        top[0].reshape(10,3,2)

    def forward(self, bottom, top):
        top[0].data[...] = np.random.randn(10, 3, 2)

    def backward(self, top, propagate_down, bottom):
        pass
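The working version differs from the first guess in two ways: it passes the dimensions to `randn` separately, and it sizes the top blob in `reshape`. Caffe calls `setup` once and calls `reshape` before every `forward`, so `forward` can assume the top blob already has shape (10, 3, 2). The sketch below exercises the same method bodies outside Caffe; `FakeBlob`, `PythonDummyDataSketch`, and `run_once` are hypothetical stand-ins (not pycaffe classes) that model only a `.data` array and a `.reshape()` method, and it shows that skipping `reshape` makes the assignment in `forward` fail under NumPy broadcasting rules.

```python
import numpy as np

class FakeBlob:
    """Hypothetical stand-in for a pycaffe blob: only .data and .reshape()."""
    def __init__(self):
        self.data = np.zeros((0,), dtype=np.float32)  # starts empty

    def reshape(self, *shape):
        self.data = np.zeros(shape, dtype=np.float32)

class PythonDummyDataSketch:
    # Same method bodies as the working PythonDummyData, minus caffe.Layer
    def setup(self, bottom, top):
        assert len(top) == 1  # one output

    def reshape(self, bottom, top):
        top[0].reshape(10, 3, 2)

    def forward(self, bottom, top):
        top[0].data[...] = np.random.randn(10, 3, 2)

def run_once(layer, skip_reshape=False):
    # Call order modeled on Caffe: setup, then reshape, then forward
    top = [FakeBlob()]
    layer.setup([], top)
    if not skip_reshape:
        layer.reshape([], top)
    layer.forward([], top)
    return top[0].data

layer = PythonDummyDataSketch()
print(run_once(layer).shape)  # (10, 3, 2)

try:
    run_once(layer, skip_reshape=True)
    failed_without_reshape = False
except ValueError:
    # a (10, 3, 2) array cannot be broadcast into a blob of shape (0,)
    failed_without_reshape = True
print(failed_without_reshape)  # True
```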

@KorolevDmitry

Hello, I am trying to use this Python layer, but I can't get it to run.
I am getting this crash:
    I1007 17:48:31.366592 30357 layer_factory.hpp:77] Creating layer loss
    *** Aborted at 1475851711 (unix time) try "date -d @1475851711" if you are using GNU date ***
    PC: @ 0x7f32895f1156 (unknown)
    *** SIGSEGV (@0x0) received by PID 30357 (TID 0x7f328b07fa40) from PID 0; stack trace: ***
        @ 0x7f328883ecb0 (unknown)
        @ 0x7f32895f1156 (unknown)
        @ 0x7f3289b43dfe (unknown)
        @ 0x7f32429d0d9c google::protobuf::MessageLite::ParseFromArray()
        @ 0x7f3242a1f652 google::protobuf::EncodedDescriptorDatabase::Add()
        @ 0x7f32429da012 google::protobuf::DescriptorPool::InternalAddGeneratedFile()
        @ 0x7f3242a2b33e google::protobuf::protobuf_AddDesc_google_2fprotobuf_2fdescriptor_2eproto()
        @ 0x7f3242a5aa75 google::protobuf::StaticDescriptorInitializer_google_2fprotobuf_2fdescriptor_2eproto::StaticDescriptorInitializer_google_2fprotobuf_2fdescriptor_2eproto()
        @ 0x7f3242a56beb __static_initialization_and_destruction_0()
        @ 0x7f3242a56c00 _GLOBAL__sub_I_descriptor.pb.cc
        @ 0x7f328aeca10a (unknown)
        @ 0x7f328aeca1f3 (unknown)
        @ 0x7f328aecec30 (unknown)
        @ 0x7f328aec9fc4 (unknown)
        @ 0x7f328aece37b (unknown)
        @ 0x7f327d91b02b (unknown)
        @ 0x7f328aec9fc4 (unknown)
        @ 0x7f327d91b62d (unknown)
        @ 0x7f327d91b0c1 (unknown)
        @ 0x7f3288f412ae (unknown)
        @ 0x7f3288f09dae (unknown)
        @ 0x7f3288f88729 (unknown)
        @ 0x7f3288ebccbf (unknown)
        @ 0x7f3288f81d66 (unknown)
        @ 0x7f3288e47a3f (unknown)
        @ 0x7f3288f12d43 (unknown)
        @ 0x7f3288f8b577 (unknown)
        @ 0x7f3288f6dc13 (unknown)
        @ 0x7f3288f7154d (unknown)
        @ 0x7f3288f71682 (unknown)
        @ 0x7f3288f71a2c (unknown)
        @ 0x7f3288f88016 (unknown)
    Segmentation fault (core dumped)

I really don't know what to do. Could you suggest anything?
