# A shallow WideResNet implemented using Flax
from functools import partial
#from typing import Callable, Optional, Sequence, Tuple
from typing import (Any, Callable, Dict, Iterable, Mapping, Optional, Sequence,
                    Tuple, Union)

import jax
from jax import lax
import jax.numpy as jnp
from flax import linen as nn

#from .common import ConvBlock, ModuleDef, Sequential
#from .splat import SplAtConv2d

ModuleDef = Callable[..., Callable]
# InitFn = Callable[[PRNGKey, Shape, DType], Array]
InitFn = Callable[[Any, Iterable[int], Any], Any]

STAGE_SIZES = {
    18: [2, 2, 2, 2],
    34: [3, 4, 6, 3],
    50: [3, 4, 6, 3],
    101: [3, 4, 23, 3],
    152: [3, 8, 36, 3],
    200: [3, 24, 36, 3],
    269: [3, 30, 48, 8],
}


def rsoftmax(x, radix, cardinality):
    # (batch_size, features) -> (batch_size, features)
    batch = x.shape[0]
    if radix > 1:
        x = x.reshape((batch, cardinality, radix, -1)).swapaxes(1, 2)
        return nn.softmax(x, axis=1).reshape((batch, -1))
    else:
        return nn.sigmoid(x)
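# Shape sketch for rsoftmax (illustrative values, not from the original gist):
# with radix=2, cardinality=1 and x of shape (4, 8), the reshape gives
# (4, 1, 2, 4), swapaxes moves radix to axis 1 -> (4, 2, 1, 4), the softmax
# normalizes across the radix axis, and the result flattens back to (4, 8):
#
#   _x = jnp.ones((4, 8))
#   assert rsoftmax(_x, radix=2, cardinality=1).shape == (4, 8)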
class ConvBlock(nn.Module):
    n_filters: int
    kernel_size: Tuple[int, int] = (3, 3)
    strides: Tuple[int, int] = (1, 1)
    activation: Callable = nn.relu
    padding: Union[str, Iterable[Tuple[int, int]]] = ((0, 0), (0, 0))
    is_last: bool = False
    groups: int = 1
    kernel_init: InitFn = nn.initializers.kaiming_normal()
    bias_init: InitFn = nn.initializers.zeros
    conv_cls: ModuleDef = nn.Conv
    norm_cls: Optional[ModuleDef] = partial(nn.BatchNorm, momentum=1.0)
    force_conv_bias: bool = False

    @nn.compact
    def __call__(self, x):
        x = self.conv_cls(
            self.n_filters,
            self.kernel_size,
            self.strides,
            use_bias=(not self.norm_cls or self.force_conv_bias),
            padding=self.padding,
            feature_group_count=self.groups,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )(x)
        if self.norm_cls:
            scale_init = (nn.initializers.zeros
                          if self.is_last else nn.initializers.ones)
            x = self.norm_cls(use_running_average=True, scale_init=scale_init)(x)
        if not self.is_last:
            x = self.activation(x)
        return x
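# Usage sketch (hypothetical shapes, not from the original gist). ConvBlock is
# conv -> optional norm -> activation; with the defaults above, BatchNorm runs
# with momentum=1.0 and use_running_average=True, so its statistics stay
# frozen at initialization, which suits NTK-style analysis:
#
#   _block = ConvBlock(16, padding=((1, 1), (1, 1)))
#   _vars = _block.init(jax.random.PRNGKey(0), jnp.ones((1, 8, 8, 3)))
#   _y = _block.apply(_vars, jnp.ones((1, 8, 8, 3)))  # shape (1, 8, 8, 16)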
#def wide_basic_ntk(in_planes, planes, dropout_rate, stride=1, mode='train'):
#    bn1 = stax.serial(
#        stax.LayerNorm(axis=(0, 1, 2), eps=1e-5),
#        stax.Dense(in_planes, 1., 0., parameterization='standard')
#    )
#    conv1 = stax.Conv(planes, (3, 3), padding='SAME', parameterization='standard')
#    #dropout = stax.Dropout(rate=dropout_rate, mode=mode)
#
#    bn2 = stax.serial(
#        stax.LayerNorm(axis=(0, 1, 2), eps=1e-5),
#        stax.Dense(planes, 1., 0., parameterization='standard')
#    )
#    conv2 = stax.Conv(planes, (3, 3), strides=(stride, stride),
#                      padding='SAME', parameterization='standard')
#
#    main = stax.serial(
#        bn1,
#        stax.Relu(), conv1,
#        #dropout,
#        bn2,
#        stax.Relu(), conv2
#    )
#
#    shortcut = stax.Identity()
#    if stride != 1 or in_planes != planes:
#        shortcut = stax.Conv(planes, (1, 1), strides=(stride, stride),
#                             padding='SAME', parameterization='standard')
#
#    net = stax.serial(
#        stax.FanOut(2), stax.parallel(main, shortcut), stax.FanInSum())
#
#    return net
class SplAtConv2d(nn.Module):
    channels: int
    kernel_size: Tuple[int, int]
    strides: Tuple[int, int] = (1, 1)
    padding: Union[str, Iterable[Tuple[int, int]]] = ((0, 0), (0, 0))
    groups: int = 1
    radix: int = 2
    reduction_factor: int = 4
    conv_block_cls: ModuleDef = ConvBlock
    # Note: this default is evaluated at class-definition time, so it picks up
    # the class-level default of `groups` (1), not the per-instance value.
    cardinality: int = groups
    # Match extra bias here:
    # github.com/zhanghang1989/ResNeSt/blob/master/resnest/torch/splat.py#L39
    match_reference: bool = False
    @nn.compact
    def __call__(self, x):
        inter_channels = max(x.shape[-1] * self.radix // self.reduction_factor, 32)

        conv_block = self.conv_block_cls(self.channels * self.radix,
                                         kernel_size=self.kernel_size,
                                         strides=self.strides,
                                         groups=self.groups * self.radix,
                                         padding=self.padding)
        conv_cls = conv_block.conv_cls  # type: ignore
        x = conv_block(x)

        if self.radix > 1:
            # torch split takes split_size: int(rchannel//self.radix)
            # jnp split takes num sections: self.radix
            split = jnp.split(x, self.radix, axis=-1)
            gap = sum(split)
        else:
            gap = x

        gap = gap.mean((1, 2), keepdims=True)  # type: ignore # global average pool

        # Remove force_conv_bias after resolving
        # github.com/zhanghang1989/ResNeSt/issues/125
        gap = self.conv_block_cls(inter_channels,
                                  kernel_size=(1, 1),
                                  groups=self.cardinality,
                                  force_conv_bias=self.match_reference)(gap)

        attn = conv_cls(self.channels * self.radix,
                        kernel_size=(1, 1),
                        feature_group_count=self.cardinality)(gap)  # n x 1 x 1 x c

        attn = attn.reshape((x.shape[0], -1))
        attn = rsoftmax(attn, self.radix, self.cardinality)
        attn = attn.reshape((x.shape[0], 1, 1, -1))

        if self.radix > 1:
            attns = jnp.split(attn, self.radix, axis=-1)
            out = sum(a * s for a, s in zip(attns, split))
        else:
            out = attn * x

        return out
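# How the split-attention above fits together (a descriptive note, following
# the ResNeSt design): the grouped conv produces radix * channels feature
# maps; the radix groups are summed and globally pooled; a 1x1 bottleneck conv
# and a second 1x1 conv produce per-channel attention logits; rsoftmax
# normalizes them across the radix groups; and the output is the
# attention-weighted sum of the groups.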
class ResNetStem(nn.Module):
    conv_block_cls: ModuleDef = ConvBlock

    @nn.compact
    def __call__(self, x):
        return self.conv_block_cls(64,
                                   kernel_size=(7, 7),
                                   strides=(2, 2),
                                   padding=[(3, 3), (3, 3)])(x)


class ResNetDStem(nn.Module):
    conv_block_cls: ModuleDef = ConvBlock
    stem_width: int = 32

    # If True, n_filters for first conv is (input_channels + 1) * 8
    adaptive_first_width: bool = False

    @nn.compact
    def __call__(self, x):
        cls = partial(self.conv_block_cls, kernel_size=(3, 3), padding=((1, 1), (1, 1)))
        first_width = (8 * (x.shape[-1] + 1)
                       if self.adaptive_first_width else self.stem_width)
        x = cls(first_width, strides=(2, 2))(x)
        x = cls(self.stem_width, strides=(1, 1))(x)
        x = cls(self.stem_width * 2, strides=(1, 1))(x)
        return x


class ResNetSkipConnection(nn.Module):
    strides: Tuple[int, int]
    conv_block_cls: ModuleDef = ConvBlock

    @nn.compact
    def __call__(self, x, out_shape):
        if x.shape != out_shape:
            x = self.conv_block_cls(out_shape[-1],
                                    kernel_size=(1, 1),
                                    strides=self.strides,
                                    activation=lambda y: y)(x)
        return x


class ResNetDSkipConnection(nn.Module):
    strides: Tuple[int, int]
    conv_block_cls: ModuleDef = ConvBlock

    @nn.compact
    def __call__(self, x, out_shape):
        if self.strides != (1, 1):
            x = nn.avg_pool(x, (2, 2), strides=(2, 2), padding=((0, 0), (0, 0)))
        if x.shape[-1] != out_shape[-1]:
            x = self.conv_block_cls(out_shape[-1], (1, 1), activation=lambda y: y)(x)
        return x
class ResNeStSkipConnection(ResNetDSkipConnection):
    # Inheritance ensures our variables dict has the right names.
    pass
class ResNetBlock(nn.Module):
    n_hidden: int
    strides: Tuple[int, int] = (1, 1)
    activation: Callable = nn.relu
    conv_block_cls: ModuleDef = ConvBlock
    skip_cls: ModuleDef = ResNetSkipConnection

    @nn.compact
    def __call__(self, x):
        skip_cls = partial(self.skip_cls, conv_block_cls=self.conv_block_cls)
        y = self.conv_block_cls(self.n_hidden,
                                padding=[(1, 1), (1, 1)],
                                strides=self.strides)(x)
        y = self.conv_block_cls(self.n_hidden, padding=[(1, 1), (1, 1)],
                                is_last=True)(y)
        return self.activation(y + skip_cls(self.strides)(x, y.shape))


class ResNetBottleneckBlock(nn.Module):
    n_hidden: int
    strides: Tuple[int, int] = (1, 1)
    expansion: int = 4
    groups: int = 1  # cardinality
    base_width: int = 64
    activation: Callable = nn.relu
    conv_block_cls: ModuleDef = ConvBlock
    skip_cls: ModuleDef = ResNetSkipConnection

    @nn.compact
    def __call__(self, x):
        skip_cls = partial(self.skip_cls, conv_block_cls=self.conv_block_cls)
        group_width = int(self.n_hidden * (self.base_width / 64.)) * self.groups

        # Downsampling strides in 3x3 conv instead of 1x1 conv, which improves accuracy.
        # This variant is called ResNet V1.5 (matches torchvision).
        y = self.conv_block_cls(group_width, kernel_size=(1, 1))(x)
        y = self.conv_block_cls(group_width,
                                strides=self.strides,
                                groups=self.groups,
                                padding=((1, 1), (1, 1)))(y)
        y = self.conv_block_cls(self.n_hidden * self.expansion,
                                kernel_size=(1, 1),
                                is_last=True)(y)
        return self.activation(y + skip_cls(self.strides)(x, y.shape))


class ResNetDBlock(ResNetBlock):
    skip_cls: ModuleDef = ResNetDSkipConnection


class ResNetDBottleneckBlock(ResNetBottleneckBlock):
    skip_cls: ModuleDef = ResNetDSkipConnection


class ResNeStBottleneckBlock(ResNetBottleneckBlock):
    skip_cls: ModuleDef = ResNeStSkipConnection
    avg_pool_first: bool = False
    radix: int = 2
    splat_cls: ModuleDef = SplAtConv2d

    @nn.compact
    def __call__(self, x):
        assert self.radix == 2  # TODO: implement radix != 2

        skip_cls = partial(self.skip_cls, conv_block_cls=self.conv_block_cls)
        group_width = int(self.n_hidden * (self.base_width / 64.)) * self.groups

        y = self.conv_block_cls(group_width, kernel_size=(1, 1))(x)

        if self.strides != (1, 1) and self.avg_pool_first:
            y = nn.avg_pool(y, (3, 3), strides=self.strides, padding=[(1, 1), (1, 1)])

        y = self.splat_cls(group_width,
                           kernel_size=(3, 3),
                           strides=(1, 1),
                           padding=[(1, 1), (1, 1)],
                           groups=self.groups,
                           radix=self.radix)(y)

        if self.strides != (1, 1) and not self.avg_pool_first:
            y = nn.avg_pool(y, (3, 3), strides=self.strides, padding=[(1, 1), (1, 1)])

        y = self.conv_block_cls(self.n_hidden * self.expansion,
                                kernel_size=(1, 1),
                                is_last=True)(y)
        return self.activation(y + skip_cls(self.strides)(x, y.shape))
class FrozenBatchNorm(nn.BatchNorm):
    # BatchNorm with frozen statistics: module variables cannot be mutated
    # in-place inside a linen __call__, so freezing is implemented by always
    # normalizing with the stored running averages, which are read but never
    # updated and receive no gradients.
    @nn.compact
    def __call__(self, x, use_running_average: Optional[bool] = None):
        del use_running_average  # statistics are always frozen
        return super().__call__(x, use_running_average=True)
class Sequential(nn.Module):
    layers: Sequence[Union[nn.Module, Callable[[jnp.ndarray], jnp.ndarray]]]

    @nn.compact
    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
def ResNet(
    block_cls: ModuleDef,
    *,
    stage_sizes: Sequence[int],
    n_classes: int,
    hidden_sizes: Sequence[int] = (64, 128, 256, 512),
    conv_cls: ModuleDef = nn.Conv,
    norm_cls: Optional[ModuleDef] = partial(nn.BatchNorm, momentum=1.0),
    conv_block_cls: ModuleDef = ConvBlock,
    stem_cls: ModuleDef = ResNetStem,
    pool_fn: Callable = partial(nn.max_pool,
                                window_shape=(3, 3),
                                strides=(2, 2),
                                padding=((1, 1), (1, 1))),
    model: Optional[ModuleDef] = None,
) -> Sequential:
    conv_block_cls = partial(conv_block_cls, conv_cls=conv_cls, norm_cls=norm_cls)
    stem_cls = partial(stem_cls, conv_block_cls=conv_block_cls)
    block_cls = partial(block_cls, conv_block_cls=conv_block_cls)

    layers = [stem_cls(), pool_fn]
    for i, (hsize, n_blocks) in enumerate(zip(hidden_sizes, stage_sizes)):
        for b in range(n_blocks):
            strides = (1, 1) if i == 0 or b != 0 else (2, 2)
            layers.append(block_cls(n_hidden=hsize, strides=strides))
    layers.append(partial(jnp.mean, axis=(1, 2)))  # global average pool
    layers.append(nn.Dense(n_classes))
    return Sequential(layers)
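# Construction sketch (hypothetical usage; shapes are illustrative): build a
# ResNet-18 classifier from the pieces above and run a forward pass.
#
#   _resnet18 = ResNet(ResNetBlock, stage_sizes=STAGE_SIZES[18], n_classes=10)
#   _vars = _resnet18.init(jax.random.PRNGKey(0), jnp.ones((1, 224, 224, 3)))
#   _logits = _resnet18.apply(_vars, jnp.ones((1, 224, 224, 3)))  # (1, 10)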
class WideBasicNtk(nn.Module):
    in_planes: int
    planes: int
    stride: int

    @nn.compact
    def __call__(self, x):
        shortcut = x
        if self.stride != 1 or self.in_planes != self.planes:
            shortcut = nn.Conv(self.planes, (1, 1),
                               strides=(self.stride, self.stride),
                               padding='SAME')(x)
        x = nn.BatchNorm(use_running_average=True, momentum=1.0)(x)
        x = nn.relu(x)
        x = nn.Conv(self.planes, (3, 3), strides=(1, 1), padding='SAME')(x)
        x = nn.BatchNorm(use_running_average=True, momentum=1.0)(x)
        x = nn.relu(x)
        x = nn.Conv(self.planes, (3, 3), strides=(self.stride, self.stride),
                    padding='SAME')(x)
        # shortcut is always defined (identity or 1x1 conv), so add it directly.
        return x + shortcut
def WideResNetNTK(
        num_layers: int,
        widen_factor: int,
        depth: int,
        num_classes: int,
        num_input_channels: int = 3,
        dropout_rate: float = 0.3) -> Sequential:  # dropout_rate is currently unused
    in_planes = 16
    assert (depth - 4) % 6 == 0, 'wide-resnet depth should be 6n+4'
    assert num_layers in (1, 2, 3), 'num_layers must be 1, 2, or 3'
    n = (depth - 4) // 6  # number of blocks per stage
    k = widen_factor
    def _wide_layer(block, in_planes, planes, num_blocks, stride):
        strides = [stride] + [1] * (int(num_blocks) - 1)
        layers = []
        for stride in strides:
            layers.append(block(in_planes, planes, stride))
            in_planes = planes
        return Sequential(layers), in_planes
    print('| wide-resnet %dx%d' % (depth, k))
    nstages = [16, 16 * k, 32 * k, 64 * k]

    conv1 = nn.Conv(nstages[0], (3, 3), strides=(1, 1), padding='SAME')
    if num_layers >= 1:
        layer1, in_planes = _wide_layer(WideBasicNtk, in_planes, nstages[1],
                                        n, stride=1)
    if num_layers >= 2:
        layer2, in_planes = _wide_layer(WideBasicNtk, in_planes, nstages[2],
                                        n, stride=2)
    if num_layers == 3:
        layer3, in_planes = _wide_layer(WideBasicNtk, in_planes, nstages[3],
                                        n, stride=2)
    bn1 = nn.BatchNorm(use_running_average=True, momentum=1.0)
    linear = nn.Dense(num_classes)
    avg_pool = partial(jnp.mean, axis=(1, 2))

    if num_layers == 1:
        net = Sequential([
            conv1,
            layer1,
            bn1,
            nn.relu,
            avg_pool,
            linear
        ])
    elif num_layers == 2:
        net = Sequential([
            conv1,
            layer1, layer2,
            bn1,
            nn.relu,
            avg_pool,
            linear
        ])
    elif num_layers == 3:
        net = Sequential([
            conv1,
            layer1, layer2, layer3,
            bn1,
            nn.relu,
            avg_pool,
            linear
        ])
    return net
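

# Smoke test (a minimal sketch; the input shape and hyperparameters are
# illustrative, not from the original gist). With depth=16 each stage has
# (16 - 4) // 6 = 2 blocks.
if __name__ == '__main__':
    model = WideResNetNTK(num_layers=3, widen_factor=2, depth=16,
                          num_classes=10)
    x = jnp.ones((2, 32, 32, 3))  # NHWC CIFAR-sized batch
    variables = model.init(jax.random.PRNGKey(0), x)
    logits = model.apply(variables, x)
    print(logits.shape)  # (2, 10)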