from typing import Optional, Tuple, Type

import tensorflow as tf
from tensorflow.keras import Model, Sequential
from tensorflow.keras.initializers import GlorotNormal
from tensorflow.keras.layers import (BatchNormalization, Conv3D, Dense, Dropout,
                                     GlobalAveragePooling3D, ReLU)

class BasicStem(Sequential):
    """
    (2+1)D stem: spatial Conv3D + BN + ReLU + Dropout,
    followed by a temporal Conv3D + BN + ReLU + Dropout.
    """

    def __init__(self,
                 input_shape=(16, 224, 224, 2),
                 stem_channels=(16, 32),
                 k_space=7,
                 k_time=3,
                 dropout_rate=.1,
                 **kwargs
                 ):
        mid_channels, out_channels = stem_channels
        super().__init__(layers=[
            # Spatial (1 x k_space x k_space) convolution, stride 2 in space
            Conv3D(
                filters=mid_channels,
                kernel_size=(1, k_space, k_space),
                strides=(1, 2, 2),
                padding='valid',
                use_bias=False,
                kernel_initializer=GlorotNormal(),
                input_shape=input_shape
            ),
            BatchNormalization(),
            ReLU(),
            Dropout(dropout_rate),
            # Temporal (k_time x 1 x 1) convolution
            Conv3D(
                filters=out_channels,
                kernel_size=(k_time, 1, 1),
                padding='valid',
                use_bias=False,
                kernel_initializer=GlorotNormal()
            ),
            BatchNormalization(),
            ReLU(),
            Dropout(dropout_rate)
        ])
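
# Shape sanity check for the stem (a sketch, not part of the original gist; assumes
# the default input of 16 frames of 224x224 two-channel images):
# stem = BasicStem()
# print(stem(tf.zeros((1, 16, 224, 224, 2))).shape)
# -> (1, 14, 109, 109, 32): the stride-2 'valid' 7x7 conv halves height/width,
#    and the 'valid' k_time=3 conv shortens the clip from 16 to 14 frames.
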

class Conv2Plus1D(Sequential):
    """
    (2+1)D convolution: a spatial (1 x k_space x k_space) Conv3D followed by a
    temporal (k_time x 1 x 1) Conv3D, each with BN + ReLU + Dropout.
    """

    def __init__(self,
                 mid_channels,
                 out_channels,
                 stride=1,
                 padding='same',
                 k_space=3,
                 k_time=3,
                 dropout_rate=.1
                 ):
        super().__init__(layers=[
            # Spatial convolution; the stride downsamples height and width
            Conv3D(
                filters=mid_channels,
                kernel_size=(1, k_space, k_space),
                strides=(1, stride, stride),
                padding=padding,
                use_bias=False,
                kernel_initializer=GlorotNormal()
            ),
            BatchNormalization(),
            ReLU(),
            Dropout(dropout_rate),
            # Temporal convolution; the stride downsamples the time dimension
            Conv3D(
                filters=out_channels,
                kernel_size=(k_time, 1, 1),
                strides=(stride, 1, 1),
                padding=padding,
                use_bias=False,
                kernel_initializer=GlorotNormal()
            ),
            BatchNormalization(),
            ReLU(),
            Dropout(dropout_rate)
        ])
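
# Stride behaviour of the factorized convolution (a sketch, not part of the
# original gist): with stride=2 and 'same' padding, the spatial conv halves
# height/width and the temporal conv halves the number of frames.
# conv = Conv2Plus1D(mid_channels=115, out_channels=64, stride=2)
# print(conv(tf.zeros((1, 16, 56, 56, 32))).shape)  # -> (1, 8, 28, 28, 64)
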

class ResNetBlock(Model):
    """
    Residual block made of two Conv2Plus1D layers, with a 1x1x1 convolution on
    the shortcut whenever the channel count or the stride changes.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 kernel_size=3
                 ):
        super().__init__()
        if in_channels == out_channels and stride == 1:
            # Identity shortcut
            self.residual_transformation = None
        else:
            # Project the shortcut to the new channel count / resolution
            self.residual_transformation = Conv3D(
                filters=out_channels, kernel_size=(1, 1, 1),
                strides=(stride, stride, stride),
                padding='same',
                use_bias=False,
                kernel_initializer=GlorotNormal()
            )
        # Hidden width chosen so the (2+1)D factorization has roughly the same
        # parameter count as a full 3x3x3 convolution (the R(2+1)D heuristic)
        mid_channels = (in_channels * out_channels * 3 * 3 * 3) // (in_channels * 3 * 3 + 3 * out_channels)
        self.body = Sequential(layers=[
            Conv2Plus1D(mid_channels=mid_channels, out_channels=out_channels, stride=stride,
                        k_space=kernel_size, k_time=3),
            Conv2Plus1D(mid_channels=mid_channels, out_channels=out_channels,
                        k_space=kernel_size, k_time=3),
        ])

    def call(self, x, **kwargs):
        residual = self.residual_transformation(x) if self.residual_transformation is not None else x
        return tf.nn.relu(self.body(x) + residual)
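
# Worked example of the mid_channels heuristic above (a sketch, not part of the
# original gist): for in_channels=32, out_channels=64 the hidden width is
# (32 * 64 * 27) // (32 * 9 + 3 * 64) = 55296 // 480 = 115, so the two factorized
# convolutions use roughly as many parameters as one dense 3x3x3 convolution.
# block = ResNetBlock(in_channels=32, out_channels=64, stride=1)
# print(block(tf.zeros((1, 14, 109, 109, 32))).shape)  # -> (1, 14, 109, 109, 64)
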

class ResNet2plus1D(Sequential):
    """
    (2+1)D ResNet: stem -> stack of ResNetBlocks -> global average pooling -> dense head.
    """

    def __init__(self,
                 input_shape: Tuple[int, ...] = (16, 224, 224, 2),
                 stem: Type[Model] = BasicStem,
                 stem_channels: Tuple[int, ...] = (16, 32),
                 layer_channels: Tuple[int, ...] = (64, 128, 256),
                 layer_blocks: Optional[Tuple[int, ...]] = None,
                 layer_strides: Optional[Tuple[int, ...]] = None,
                 kernel_sizes: Optional[Tuple[int, ...]] = None,
                 out_features: int = 3,
                 **kwargs):
        super().__init__()
        self.stem = stem(input_shape, stem_channels)
        layers = []
        if layer_strides is None:
            layer_strides = [1] + [2] * (len(layer_channels) - 1)
        if layer_blocks is None:
            layer_blocks = [2] * len(layer_channels)
        if kernel_sizes is None:
            # Make the first spatial kernel slightly larger: 5x5 instead of 3x3
            kernel_sizes = [5] + [3] * (len(layer_channels) - 1)
        in_channels = stem_channels[-1]
        for channels, stride, blocks, kernel_size in zip(layer_channels, layer_strides, layer_blocks, kernel_sizes):
            # Downsampling block (stride=2 by default) that also increases the channel count
            layers.append(
                ResNetBlock(in_channels, channels, stride=stride, kernel_size=kernel_size)
            )
            # Remaining blocks keep the channel count fixed: 3x3x3 kernels, stride 1
            for _ in range(1, blocks):
                layers.append(
                    ResNetBlock(channels, channels)
                )
            in_channels = channels
        self.backbone = Sequential(layers)
        # Pool to (batch, num_channels)
        self.pool = GlobalAveragePooling3D()
        self.fc = Dense(out_features)
        self.add(self.stem)
        self.add(self.backbone)
        self.add(self.pool)
        self.add(self.fc)
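

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original gist): build the model with its
    # default configuration and run a dummy batch of two 16-frame, 224x224,
    # two-channel clips through it.
    model = ResNet2plus1D()
    logits = model(tf.zeros((2, 16, 224, 224, 2)))
    print(logits.shape)  # (2, 3) -- one logit per class for out_features=3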