Skip to content

Instantly share code, notes, and snippets.

@FrancescoSaverioZuppichini
Created April 17, 2019 12:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save FrancescoSaverioZuppichini/1d040f72de7409a8858eed17dcdb74c8 to your computer and use it in GitHub Desktop.
class Kernel(Layer):
    """One learnable Gaussian weighting kernel for MoNet (Monti et al., CVPR 2017).

    Given edge pseudo-coordinates, produces a sparse, row-normalized
    [coarsen_shape, coarsen_shape] weight matrix: the Gaussian of eq. (11)
    of the paper, followed by a row-wise softmax normalization.
    """

    def __init__(self, r, **kwargs):
        # r: dimensionality of the pseudo-coordinates (and of mu/sigma).
        self.r = r
        super(Kernel, self).__init__(**kwargs)

    def build(self, input_shape):
        # Trainable mean and scale of the Gaussian, one value per
        # pseudo-coordinate dimension.
        self.mu = self.add_weight(name='mu',
                                  shape=(1, self.r),
                                  initializer='uniform',
                                  trainable=True)
        self.sigma = self.add_weight(name='sigma',
                                     shape=(1, self.r),
                                     initializer='uniform',
                                     trainable=True)
        super(Kernel, self).build(input_shape)

    def call(self, x, u):
        # u packs the sparse graph structure: edge row/col indices, the
        # per-edge pseudo-coordinate values, and the (coarsened) vertex count.
        indices_row_, indices_col_, values_u_, corasen_shape = u
        if DEBUG: print("[INFO] x.shape={}".format(x.shape))
        # Equation (11): log-Gaussian weight, -1/2 * sum((u - mu)^2 / sigma^2).
        # BUGFIX: the original passed the *positive* quadratic to softmax,
        # which gives the largest weight to the most *distant* neighbors;
        # softmax(-0.5 * d) is exactly the row-normalized Gaussian weight.
        exp = -0.5 * tf.reduce_sum(tf.square(values_u_ - self.mu)/tf.square(self.sigma), axis=1)
        if DEBUG: print("[INFO] exp.shape={}".format(exp.shape))
        gaussians = tf.SparseTensor(indices=np.vstack([indices_row_, indices_col_]).T,
                                    values=exp,
                                    dense_shape=[corasen_shape]*2)
        if DEBUG: print("[INFO] gaussians.shape={}".format(gaussians.shape))
        # SparseTensor indices must be in canonical row-major order
        # before sparse ops are applied.
        gaussians = tf.sparse.reorder(gaussians)
        # applying softmax to normalize each row
        gaussians = tf.sparse.softmax(gaussians)
        return gaussians

    def compute_output_shape(self, input_shape):
        # BUGFIX: the original returned (input_shape[0], self.output_dim),
        # but `self.output_dim` is never defined, so any call raised
        # AttributeError.  call() returns a square sparse matrix over the
        # vertices; the input's vertex axis is the best static estimate.
        return (input_shape[1], input_shape[1])
class MonetConv(Layer):
    """MoNet graph convolution: k Gaussian kernels plus a dense mixing matrix.

    Each kernel aggregates neighbor features using its own Gaussian weighting
    of the pseudo-coordinates `u`; the k aggregates are concatenated along the
    feature axis and mixed by a learned matrix W into `hidden_size` output
    features per vertex.
    """

    def __init__(self, r, u, hidden_size, k, **kwargs):
        # r: pseudo-coordinate dimensionality; u: sparse graph structure
        # (indices, values, coarsened size); hidden_size: output features per
        # vertex; k: number of Gaussian kernels.
        self.r = r
        self.u = u
        self.hidden_size = hidden_size
        self.k = k
        super(MonetConv, self).__init__(**kwargs)

    def build(self, input_shape):
        self.kernels = [Kernel(self.r) for _ in range(self.k)]
        # Mixing matrix: concatenated kernel outputs (Fin * k) -> hidden_size.
        self.W = self.add_weight(name='W',
                                 shape=(input_shape[-1] * self.k, self.hidden_size),
                                 initializer='uniform',
                                 trainable=True)
        super(MonetConv, self).build(input_shape)

    def run_kernels(self, x, u):
        # Apply every kernel and stack the results on a new leading axis.
        gaussians = tf.stack([self.run_kernel(k, x, u) for k in self.kernels])
        # [N_KERNELS, BATCH_SIZE, VERTECES_N, FEATURES_N]
        if DEBUG: print("[INFO] gaussians.shape={}".format(gaussians.shape))
        return gaussians

    def run_kernel(self, kernel, x, u):
        batch_size, vertices_n, features_n = x.shape
        gaussians = kernel(x, u)
        # Flatten the batch into the feature axis so one sparse matmul covers
        # all batches.  BUGFIX: the batch size was hard-coded to 10; `-1`
        # lets TensorFlow infer it, so any batch size works.
        # NOTE(review): reshaping [B, V, F] straight to [V, B*F] assumes a
        # vertex-major layout -- confirm against the data pipeline.
        x_flat = tf.reshape(x, [vertices_n, -1])
        if DEBUG: print("[INFO] x_flat.shape={}".format(x_flat.shape)) # remove transpose
        # equation (9)
        D_f = tf.sparse.sparse_dense_matmul(gaussians, x_flat) # shape = M x Fin*N
        if DEBUG: print("[INFO] D_f.shape={}".format(D_f.shape))
        D_f = tf.transpose(tf.reshape(D_f, [vertices_n, features_n, -1]), [2,0,1])
        # [BATCH_SIZE, VERTECES_N, FEATURES_N]
        if DEBUG: print("[INFO] D_f.shape={}".format(D_f.shape))
        return D_f

    def call(self, x):
        batch_size, vertices_n, features_n = x.shape
        if DEBUG: print("[INFO] x.shape={}".format(x.shape))
        x = self.run_kernels(x, self.u)
        # Concatenate the k kernel outputs along the feature axis.
        # BUGFIX: the batch size was hard-coded to 10; infer it with -1.
        x = tf.reshape(x, [-1, features_n * self.k])
        if DEBUG: print("[INFO] x.shape={}".format(x.shape))
        x = x @ self.W
        if DEBUG: print("[INFO] x.shape={}".format(x.shape))
        x = tf.reshape(x, [-1, vertices_n, self.hidden_size])
        if DEBUG: print("[INFO] out x.shape={}".format(x.shape))
        return x

    def compute_output_shape(self, input_shape):
        # BUGFIX: call() returns a rank-3 [batch, vertices, hidden] tensor,
        # but the original reported a rank-2 shape that dropped the batch axis.
        return (input_shape[0], input_shape[1], self.hidden_size)
class MPool1(Layer):
    """Max pooling of factor p along the vertex axis of a [N, M, F] tensor.

    p should be a power of 2, which the vertex reordering performed during
    graph coarsening makes possible.
    """

    def __init__(self, p, **kwargs):
        # p: pooling factor; p == 1 makes the layer a no-op.
        self.p = p
        super(MPool1, self).__init__(**kwargs)

    def call(self, x):
        """Max pooling of size p. Should be a power of 2 (this is possible thanks to the reordering we previously did)."""
        if self.p > 1:
            x = tf.expand_dims(x, 3)  # shape = N x M x F x 1
            x = tf.nn.max_pool(x, ksize=[1,self.p,1,1], strides=[1, self.p,1,1], padding='SAME')
            x = tf.squeeze(x, [3])  # shape = N x M/p x F
        return x

    def compute_output_shape(self, input_shape):
        # BUGFIX: the original returned the input shape unchanged, but with
        # 'SAME' padding and stride p the vertex axis shrinks to ceil(M / p).
        if self.p > 1 and input_shape[1] is not None:
            return (input_shape[0], -(-input_shape[1] // self.p), input_shape[2])
        return input_shape
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment