Skip to content

Instantly share code, notes, and snippets.

@lazuxd
Created March 5, 2020 21:36
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lazuxd/d7aaba284123bf3340e723701e381e6e to your computer and use it in GitHub Desktop.
Building a ResNet in Keras
from tensorflow import Tensor
from tensorflow.keras.layers import Input, Conv2D, ReLU, BatchNormalization,\
Add, AveragePooling2D, Flatten, Dense
from tensorflow.keras.models import Model
def relu_bn(inputs: Tensor) -> Tensor:
    """Apply a ReLU activation followed by batch normalization.

    The ordering is activation first, then ``BatchNormalization`` on the
    activated values. A fresh pair of layers is created on every call.

    Args:
        inputs: Output tensor of the preceding layer.

    Returns:
        The batch-normalized ReLU activations of ``inputs``.
    """
    activated = ReLU()(inputs)
    return BatchNormalization()(activated)
def residual_block(x: Tensor, downsample: bool, filters: int, kernel_size: int = 3) -> Tensor:
    """Build one two-convolution residual block on top of ``x``.

    Main path: Conv2D(kernel_size) -> relu_bn -> Conv2D(kernel_size).
    Shortcut: identity, or a 1x1 Conv2D projection when the main path
    changes the spatial size (``downsample``) or the channel count.
    The two paths are summed with ``Add`` and passed through ``relu_bn``.

    Args:
        x: Input feature map. Assumes the channel axis is last
            (channels_last); TODO confirm if a different data_format is used.
        downsample: If True, the first main-path convolution and the
            shortcut projection use stride 2. With "same" padding this
            reduces height and width to ceil(H/2) x ceil(W/2).
        filters: Number of output channels of the block.
        kernel_size: Kernel size of the two main-path convolutions.

    Returns:
        The block output, with ``filters`` channels.
    """
    stride = 2 if downsample else 1

    y = Conv2D(kernel_size=kernel_size,
               strides=stride,
               filters=filters,
               padding="same")(x)
    y = relu_bn(y)
    y = Conv2D(kernel_size=kernel_size,
               strides=1,
               filters=filters,
               padding="same")(y)

    # Add() needs x and y to have identical shapes. Downsampling changes
    # H/W, and a differing input channel count changes the last axis. In
    # either case, project the shortcut with a 1x1 convolution. Before this
    # fix, a channel mismatch without downsampling made Add() raise.
    if downsample or x.shape[-1] != filters:
        x = Conv2D(kernel_size=1,
                   strides=stride,
                   filters=filters,
                   padding="same")(x)

    out = Add()([x, y])
    out = relu_bn(out)
    return out
def create_res_net():
    """Build and compile a ResNet-style classifier for 32x32x3 inputs.

    Layout:
      * BatchNormalization on the raw input, then a 3x3 stem Conv2D with
        64 filters followed by relu_bn.
      * Four stages of residual blocks containing 2, 5, 5 and 2 blocks.
        The first block of every stage after the first downsamples, and
        each stage uses twice the filters of the previous one
        (64, 128, 256, 512).
      * A 4x4 AveragePooling2D, Flatten, and a 10-unit softmax Dense head.

    Returns:
        A Keras ``Model`` compiled with the 'adam' optimizer,
        'sparse_categorical_crossentropy' loss and the 'accuracy' metric.
    """
    stage_sizes = [2, 5, 5, 2]
    filters = 64

    inputs = Input(shape=(32, 32, 3))
    net = BatchNormalization()(inputs)
    net = Conv2D(kernel_size=3, strides=1, filters=filters, padding="same")(net)
    net = relu_bn(net)

    for stage, blocks_in_stage in enumerate(stage_sizes):
        for block in range(blocks_in_stage):
            # Only the opening block of stages 2..4 downsamples; its 1x1
            # shortcut projection also brings x up to the new filter count.
            shrink = stage > 0 and block == 0
            net = residual_block(net, downsample=shrink, filters=filters)
        filters *= 2

    # Three stride-2 stages take 32x32 down to 4x4, so a 4x4 pool yields 1x1.
    net = AveragePooling2D(4)(net)
    net = Flatten()(net)
    outputs = Dense(10, activation='softmax')(net)

    model = Model(inputs, outputs)
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment