Skip to content

Instantly share code, notes, and snippets.

@radi-cho
Created September 10, 2022 20:08
Show Gist options
  • Save radi-cho/42f59fcc17228c4b1ec818ab2e233a10 to your computer and use it in GitHub Desktop.
Save radi-cho/42f59fcc17228c4b1ec818ab2e233a10 to your computer and use it in GitHub Desktop.
class gMLPLayer(layers.Layer):
def __init__(self, num_patches, embedding_dim, dropout_rate, *args, **kwargs):
super(gMLPLayer, self).__init__(*args, **kwargs)
self.channel_projection1 = keras.Sequential(
[
layers.Dense(units=embedding_dim * 2),
layers.ReLU(),
layers.Dropout(rate=dropout_rate),
]
)
self.channel_projection2 = layers.Dense(units=embedding_dim)
self.spatial_projection = layers.Dense(
units=num_patches, bias_initializer="Ones"
)
self.normalize1 = layers.LayerNormalization(epsilon=1e-6)
self.normalize2 = layers.LayerNormalization(epsilon=1e-6)
def spatial_gating_unit(self, x):
u, v = tf.split(x, num_or_size_splits=2, axis=2)
v = self.normalize2(v)
v_channels = tf.linalg.matrix_transpose(v)
v_projected = self.spatial_projection(v_channels)
v_projected = tf.linalg.matrix_transpose(v_projected)
return u * v_projected
def call(self, inputs):
x = self.normalize1(inputs)
x_projected = self.channel_projection1(x)
x_spatial = self.spatial_gating_unit(x_projected)
x_projected = self.channel_projection2(x_spatial)
return x + x_projected
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment