Last active
May 22, 2020 00:21
-
-
Save tomginsberg/cd7cda679ac92ae5fa8dea16c118e071 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from typing import Tuple, Optional | |
import tensorflow as tf | |
from tensorflow.keras import Sequential, Model | |
from tensorflow.keras.layers import Conv3D, BatchNormalization, ReLU, Dropout, Dense, GlobalAveragePooling3D | |
from tensorflow.keras.initializers import GlorotNormal | |
class BasicStem(Sequential):
    """Network stem: a spatial convolution followed by a temporal convolution.

    Layer order as built below:
    Conv3D with a ``(1, k_space, k_space)`` kernel and strides ``(1, 2, 2)``
    -> BatchNormalization -> ReLU -> Dropout
    -> Conv3D with a ``(k_time, 1, 1)`` kernel -> ReLU -> Dropout.

    NOTE(review): the previous docstring advertised a BatchNormalization after
    the second convolution too, but the layer list never contained one. The
    architecture is kept as implemented so the weight layout does not change;
    confirm whether the omission is intentional.

    Args:
        input_shape: Forwarded as ``input_shape`` to the first Conv3D
            (per-sample shape, without the batch axis).
        stem_channels: Pair ``(mid_channels, out_channels)`` giving the filter
            counts of the first and second convolution respectively.
        k_space: Kernel extent on the last two kernel axes of the first
            convolution.
        k_time: Kernel extent on the first kernel axis of the second
            convolution.
        dropout_rate: Rate used by both Dropout layers.
        **kwargs: Accepted for signature compatibility; not used.
    """

    def __init__(self,
                 input_shape=(16, 224, 224, 2),
                 stem_channels=(16, 32),
                 k_space=7,
                 k_time=3,
                 dropout_rate=.1,
                 **kwargs
                 ):
        mid_channels, out_channels = stem_channels
        super().__init__(layers=[
            Conv3D(
                filters=mid_channels,
                kernel_size=(1, k_space, k_space),
                strides=(1, 2, 2),
                padding='valid',
                use_bias=False,
                # Pass an initializer *instance*, which is the documented form.
                kernel_initializer=GlorotNormal(),
                input_shape=input_shape
            ),
            BatchNormalization(),
            ReLU(),
            Dropout(dropout_rate),
            Conv3D(
                filters=out_channels,
                kernel_size=(k_time, 1, 1),
                padding='valid',
                use_bias=False,
                kernel_initializer=GlorotNormal()
            ),
            ReLU(),
            # Was hard-coded to .1, silently ignoring `dropout_rate`.
            Dropout(dropout_rate)
        ])
class Conv2Plus1D(Sequential):
    """Factorised "(2+1)D" convolution: one spatial conv, then one temporal conv.

    Layer order as built below:
    Conv3D with kernel ``(1, k_space, k_space)`` and strides
    ``(1, stride, stride)`` -> BatchNormalization -> ReLU -> Dropout
    -> Conv3D with kernel ``(k_time, 1, 1)`` and strides ``(stride, 1, 1)``
    -> BatchNormalization -> ReLU -> Dropout.

    Args:
        mid_channels: Number of filters of the first convolution.
        out_channels: Number of filters of the second convolution.
        stride: Stride applied on the last two axes by the first convolution
            and on the first axis by the second one.
        padding: Padding mode passed to both convolutions.
        k_space: Kernel extent on the last two kernel axes of the first
            convolution.
        k_time: Kernel extent on the first kernel axis of the second
            convolution.
        dropout_rate: Rate used by both Dropout layers.
    """

    def __init__(self,
                 mid_channels,
                 out_channels,
                 stride=1,
                 padding='same',
                 k_space=3,
                 k_time=3,
                 dropout_rate=.1
                 ):
        super().__init__(layers=[
            Conv3D(
                filters=mid_channels,
                kernel_size=(1, k_space, k_space),
                strides=(1, stride, stride),
                padding=padding,
                use_bias=False,
                # Pass an initializer *instance*, which is the documented form.
                kernel_initializer=GlorotNormal()
            ),
            BatchNormalization(),
            ReLU(),
            Dropout(dropout_rate),
            Conv3D(
                filters=out_channels,
                kernel_size=(k_time, 1, 1),
                strides=(stride, 1, 1),
                padding=padding,
                use_bias=False,
                kernel_initializer=GlorotNormal()
            ),
            BatchNormalization(),
            ReLU(),
            Dropout(dropout_rate)
        ])
class ResNetBlock(Model):
    """Residual block made of two ``Conv2Plus1D`` units plus a skip connection.

    The output is ``relu(body(x) + residual)``. ``residual`` is ``x`` itself
    when ``in_channels == out_channels`` and ``stride == 1``; otherwise it is
    ``x`` passed through a bias-free 1x1x1 Conv3D with ``out_channels`` filters
    and stride ``stride`` on all three axes.

    Args:
        in_channels: Channel count of the block input; used to decide whether
            a projection is needed and to size the intermediate channel count.
        out_channels: Channel count produced by the block.
        stride: Stride given to the first ``Conv2Plus1D`` and to the
            projection on the skip path.
        kernel_size: Passed as ``k_space`` to both ``Conv2Plus1D`` units
            (``k_time`` is fixed to 3).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 kernel_size=3
                 ):
        super().__init__()
        if in_channels == out_channels and stride == 1:
            # Shapes already agree: plain identity skip connection.
            self.residual_transformation = None
        else:
            self.residual_transformation = Conv3D(
                filters=out_channels, kernel_size=(1, 1, 1),
                strides=(stride, stride, stride),
                padding='same',
                use_bias=False,
                # Pass an initializer *instance*, which is the documented form.
                kernel_initializer=GlorotNormal()
            )
        # NOTE(review): this formula hard-codes a kernel extent of 3 and does
        # not take `kernel_size` into account; left unchanged so existing
        # layer widths are preserved. The same value is used for both units.
        mid_channels = (in_channels * out_channels * 3 * 3 * 3) // (in_channels * 3 * 3 + 3 * out_channels)
        self.body = Sequential(layers=[
            Conv2Plus1D(mid_channels=mid_channels, out_channels=out_channels, stride=stride,
                        k_space=kernel_size, k_time=3),
            Conv2Plus1D(mid_channels=mid_channels, out_channels=out_channels,
                        k_space=kernel_size, k_time=3),
        ])

    def call(self, x, training=None, **kwargs):
        """Apply the block to ``x``.

        Args:
            x: Input tensor.
            training: Forwarded to ``self.body`` so its BatchNormalization and
                Dropout layers receive the training/inference mode.
            **kwargs: Accepted for signature compatibility; not used.

        Returns:
            ``tf.nn.relu(self.body(x) + residual)``.
        """
        if self.residual_transformation is None:
            residual = x
        else:
            residual = self.residual_transformation(x)
        # Go through __call__ (not .call) so Keras performs its usual
        # build/bookkeeping for the sub-model and sees the training flag.
        return tf.nn.relu(self.body(x, training=training) + residual)
class ResNet2plus1D(Sequential):
    """Sequential model: stem -> stack of ``ResNetBlock`` -> global pool -> Dense.

    The sub-models are also exposed as the attributes ``stem``, ``backbone``,
    ``pool`` and ``fc``.

    Args:
        input_shape: Passed to the stem constructor.
        stem: Stem to use. A class is instantiated as
            ``stem(input_shape, stem_channels)``; anything else is used as-is
            as the first layer.
        stem_channels: Passed to the stem constructor; its last entry is taken
            as the channel count entering the first stage.
        layer_channels: Output channel count of each stage.
        layer_blocks: Number of ``ResNetBlock`` per stage. Defaults to 2 for
            every stage.
        layer_strides: Stride of the first block of each stage. Defaults to 1
            for the first stage and 2 for the others.
        kernel_sizes: ``kernel_size`` of the first block of each stage.
            Defaults to 5 for the first stage and 3 for the others.
        out_features: Units of the final Dense layer.
        **kwargs: Accepted for signature compatibility; not used.

    Raises:
        ValueError: If ``layer_blocks``, ``layer_strides`` or ``kernel_sizes``
            does not have exactly one entry per entry of ``layer_channels``.
    """

    def __init__(self,
                 input_shape: Tuple[int, ...] = (16, 224, 224, 2),
                 stem: Model = BasicStem,
                 stem_channels: Tuple[int, ...] = (16, 32),
                 layer_channels: Tuple[int, ...] = (64, 128, 256),
                 layer_blocks: Optional[Tuple[int, ...]] = None,
                 layer_strides: Optional[Tuple[int, ...]] = None,
                 kernel_sizes: Optional[Tuple[int, ...]] = None,
                 out_features: int = 3,
                 **kwargs):
        super().__init__()

        num_stages = len(layer_channels)
        if layer_strides is None:
            layer_strides = [1] + [2] * (num_stages - 1)
        if layer_blocks is None:
            layer_blocks = [2] * num_stages
        if kernel_sizes is None:
            # Make the first kernel slightly larger (hxwxd)=(5x5x3)
            kernel_sizes = [5] + [3] * (num_stages - 1)

        # zip() below would silently drop stages on a length mismatch.
        for label, values in (('layer_strides', layer_strides),
                              ('layer_blocks', layer_blocks),
                              ('kernel_sizes', kernel_sizes)):
            if len(values) != num_stages:
                raise ValueError(
                    '{} has {} entries but layer_channels has {}'.format(
                        label, len(values), num_stages))

        # Honour the `stem` argument (it used to be ignored in favour of a
        # hard-coded BasicStem). Classes are instantiated, instances reused.
        if isinstance(stem, type):
            self.stem = stem(input_shape, stem_channels)
        else:
            self.stem = stem

        layers = []
        in_channels = stem_channels[-1]
        for channels, stride, blocks, kernel_size in zip(layer_channels, layer_strides, layer_blocks, kernel_sizes):
            # add downsampling layer (stride=2 by default), increase channels
            layers.append(
                ResNetBlock(in_channels, channels, stride=stride, kernel_size=kernel_size)
            )
            # add fixed channel layers, 3x3x3 kernels, stride = 1
            layers.extend(ResNetBlock(channels, channels) for _ in range(1, blocks))
            # Updated once per stage so it is correct even when blocks == 1.
            in_channels = channels
        self.backbone = Sequential(layers)

        # pool to (batch x num_channels)
        self.pool = GlobalAveragePooling3D()
        self.fc = Dense(out_features)

        self.add(self.stem)
        self.add(self.backbone)
        self.add(self.pool)
        self.add(self.fc)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment