Skip to content

Instantly share code, notes, and snippets.

Created July 25, 2023 13:43
Show Gist options
  • Save Cospel/a343e741608e48b367c74a4863f7b812 to your computer and use it in GitHub Desktop.
Save Cospel/a343e741608e48b367c74a4863f7b812 to your computer and use it in GitHub Desktop.
Simple ResNet-18 and VGG-like architectures built from Group Equivariant CNN layers
import numpy as np
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Input, Activation, MaxPooling2D
from keras_gcnn.layers import GConv2D, GBatchNorm, GroupPool
def SimpleVGG(input_shape, num_classes, h_input='Z2', h_output='C4'):
    """Build a small VGG-style stack of group convolutions ending in a 256-d descriptor.

    Args:
        input_shape: Shape handed to the Keras ``Input`` layer.
        num_classes: Not used by this function; the final layer is a fixed
            256-unit ``Dense`` named ``"descriptor"``.
        h_input: Group of the network input, forwarded to the first ``GConv2D``.
        h_output: Group used by every layer after the first convolution.

    Returns:
        A ``tf.keras`` model mapping ``inputs`` to the linear descriptor output.
    """
    inputs = tf.keras.layers.Input(shape=input_shape)

    # First stage differs from the rest: it maps h_input -> h_output, uses
    # 'valid' padding, and has no batch normalisation.
    net = GConv2D(32, h_input=h_input, h_output=h_output, kernel_size=3, padding='valid')(inputs)
    net = Activation('relu')(net)
    net = MaxPooling2D(pool_size=2, strides=2)(net)

    # Remaining stages share one recipe; only the filter count and the
    # pooling window vary: (filters, pool_size).
    stages = ((64, 2), (128, 2), (128, 2), (256, 3), (256, 3))
    for n_filters, pool in stages:
        net = GConv2D(n_filters, h_input=h_output, h_output=h_output, kernel_size=3, padding='same')(net)
        net = GBatchNorm(h=h_output)(net)
        net = Activation('relu')(net)
        net = MaxPooling2D(pool_size=pool, strides=2)(net)

    # Head: pool over the group, then over space, then a linear projection.
    net = GroupPool(h_output)(net)
    net = tf.keras.layers.GlobalAveragePooling2D()(net)
    descriptor = tf.keras.layers.Dense(
        256,
        activation=None,
        name="descriptor",
        kernel_initializer="he_uniform",
        kernel_regularizer=tf.keras.regularizers.l2(1e-5),
    )(net)

    return tf.keras.models.Model(inputs=inputs, outputs=descriptor)
def ResNet18(input_shape, num_classes, h_input='Z2', h_output='C4'):
    """Build a ResNet-18-shaped network from group convolutions, ending in a 256-d descriptor.

    Args:
        input_shape: Shape handed to the Keras ``Input`` layer.
        num_classes: Not used by this function; the final layer is a fixed
            256-unit ``Dense`` named ``"descriptor"``.
        h_input: Group of the network input, forwarded to the stem ``GConv2D``.
        h_output: Group used by the stem output and every residual block.

    Returns:
        A ``tf.keras`` model mapping ``inputs`` to the linear descriptor output.
    """
    inputs = tf.keras.layers.Input(shape=input_shape)

    # Stem: 7x7 stride-2 convolution followed by a 3x3 stride-2 max pool.
    net = GConv2D(64, h_input=h_input, h_output=h_output, kernel_size=7, strides=2, padding='same')(inputs)
    net = GBatchNorm(h=h_output)(net)
    net = tf.keras.layers.Activation('relu')(net)
    net = tf.keras.layers.MaxPooling2D(pool_size=3, strides=2, padding='same')(net)

    # Eight residual blocks as (width, stride); a stride of 2 opens each new
    # width except the first.
    block_plan = (
        (24, 1), (24, 1),
        (48, 2), (48, 1),
        (64, 2), (64, 1),
        (128, 2), (128, 1),
    )
    for width, stride in block_plan:
        net = ResidualBlock(net, filters=[width, width], h_input=h_output, h_output=h_output, strides=stride)

    # Head: pool over the group, then over space, then a linear projection.
    net = GroupPool(h_output)(net)
    net = tf.keras.layers.GlobalAveragePooling2D()(net)
    descriptor = tf.keras.layers.Dense(
        256,
        activation=None,
        name="descriptor",
        kernel_initializer="he_uniform",
        kernel_regularizer=tf.keras.regularizers.l2(1e-5),
    )(net)

    return tf.keras.models.Model(inputs=inputs, outputs=descriptor)
def ResidualBlock(x, filters, h_input, h_output, strides):
    """Apply a two-convolution residual block built from group convolutions.

    Args:
        x: Input Keras tensor.
        filters: Two-element sequence; ``filters[0]`` and ``filters[1]`` are
            the filter counts of the first and second ``GConv2D``.
        h_input: Group of ``x``.
        h_output: Group of the block output.
        strides: Stride of the first convolution (and of the projection
            shortcut when one is needed).

    Returns:
        The tensor ``relu(main_path(x) + shortcut(x))``.
    """
    # keras_gcnn stacks the group axis into the channel axis, so a GConv2D
    # with F filters emits F * |H| channels (4 for C4, 8 for D4, 1 for Z2).
    # Comparing the raw channel count with ``filters[1]`` would therefore
    # never detect a matching shape for C4/D4, and could skip the projection
    # when the shapes actually differ. Unknown groups fall back to factor 1.
    group_order = {'Z2': 1, 'C4': 4, 'D4': 8}
    out_channels = filters[1] * group_order.get(h_output, 1)

    shortcut = x

    # Main path: conv -> BN -> ReLU -> conv -> BN.
    x = GConv2D(filters[0], h_input=h_input, h_output=h_output, kernel_size=3, strides=strides, padding='same')(x)
    x = GBatchNorm(h=h_output)(x)
    x = tf.keras.layers.Activation('relu')(x)
    x = GConv2D(filters[1], h_input=h_output, h_output=h_output, kernel_size=3, padding='same')(x)
    x = GBatchNorm(h=h_output)(x)

    # Project the shortcut with a 1x1 group convolution only when the spatial
    # size or the channel count of the main path differs from the input.
    if strides != 1 or shortcut.shape[-1] != out_channels:
        shortcut = GConv2D(filters[1], h_input=h_input, h_output=h_output, kernel_size=1, strides=strides, padding='same')(shortcut)
        shortcut = GBatchNorm(h=h_output)(shortcut)

    x = tf.keras.layers.Add()([x, shortcut])
    x = tf.keras.layers.Activation('relu')(x)
    return x
model = ResNet18((128, 128, 3), 10)

# Generate a random test image.
img = np.random.randn(128, 128, 3)

# Run a forward pass on a batch of three views of the image: the original,
# a 90-degree rotation, and a left-right flip rotated by 180 degrees.
res = model.predict(
    np.stack([img, np.rot90(img), np.rot90(np.fliplr(img), 2)]),
)
# print(res.shape)
# print(res[0])
# print(res[1])
# NOTE(review): the disabled checks below rotate/flip the outputs as if they
# were spatial feature maps, but this model ends in GlobalAveragePooling2D +
# Dense, so each res[i] is a flat vector. Adapt them before re-enabling.
# Test that activations are the same:
# assert np.allclose(res[0], np.rot90(res[1], 3), rtol=1e-3, atol=1e-3)
# assert np.allclose(res[0], np.flipud(res[2]), rtol=1e-3, atol=1e-3)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment