Skip to content

Instantly share code, notes, and snippets.

@cshimmin
Created June 13, 2019 21:17
Show Gist options
  • Save cshimmin/5f96a693530fccdcafff4ce42a483d2f to your computer and use it in GitHub Desktop.
Save cshimmin/5f96a693530fccdcafff4ce42a483d2f to your computer and use it in GitHub Desktop.
Keras layer that rotates the 2D vectors stored in the last axis of a tensor.
class RotationLayer(layers.Layer):
    """Keras layer that applies a fixed 2x2 rotation matrix to the last axis of its input.

    The matrix is ``R = [[cos(theta), -sin(theta)], [sin(theta), cos(theta)]]``.
    It is applied as ``K.dot(x, R)``, so each trailing 2-vector is treated as a
    row vector and right-multiplied by ``R``. With this convention a positive
    ``theta`` rotates clockwise (e.g. ``theta = pi/2`` maps ``[1, 0]`` to
    ``[0, -1]``). Because ``R`` is 2x2, the last axis of the input must have
    size 2.

    Args:
        theta: Rotation angle, passed to ``tf.cos`` / ``tf.sin`` (radians).
        learning_phase_only: If True, the rotation is applied only in the
            training phase (via ``K.in_train_phase``); otherwise the input is
            returned unchanged. If False, the rotation is always applied.
        **kwargs: Forwarded to the base ``layers.Layer`` constructor
            (e.g. ``name``).
    """

    def __init__(self, theta, learning_phase_only=True, **kwargs):
        # Initialise the base Layer first so Keras's internal state and
        # attribute tracking exist before this subclass sets attributes.
        super(RotationLayer, self).__init__(**kwargs)
        self.theta = theta
        # The rotation matrix below is fixed at 2x2, so only 2-D vectors are
        # supported. (Previously this read an undefined name and raised.)
        self.vec_dim = 2
        # cos(theta) * I + sin(theta) * J, where J = [[0, -1], [1, 0]].
        self.R = (K.constant([[1, 0], [0, 1]]) * tf.cos(theta)
                  + K.constant([[0, -1], [1, 0]]) * tf.sin(theta))
        self.uses_learning_phase = learning_phase_only

    def call(self, x, training=None):
        """Rotate ``x`` (last axis of size 2), or pass it through outside training.

        Args:
            x: Input tensor whose last axis has size 2.
            training: Optional training-phase flag forwarded to
                ``K.in_train_phase``; only consulted when
                ``learning_phase_only`` was True.

        Returns:
            ``K.dot(x, R)`` when rotating, otherwise ``x`` unchanged.
        """
        if self.uses_learning_phase:
            return K.in_train_phase(K.dot(x, self.R), x, training=training)
        return K.dot(x, self.R)

    def get_config(self):
        """Return the constructor arguments so Keras can re-create the layer.

        Without this override, saving/loading a model that contains this
        layer cannot reconstruct it, because ``theta`` is a required
        constructor argument.

        NOTE(review): assumes ``theta`` is a plain Python number; a tensor
        ``theta`` would not be serializable. TODO confirm how callers pass it.
        """
        config = super(RotationLayer, self).get_config()
        config.update({
            'theta': self.theta,
            'learning_phase_only': self.uses_learning_phase,
        })
        return config
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment