@mohamad-amin · Created January 3, 2022
A shallow WideResNet implemented using Flax
from functools import partial
#from typing import Callable, Optional, Sequence, Tuple
from typing import (Any, Callable, Dict, Iterable, Mapping, Optional, Sequence,
                    Tuple, Union)

import jax
from jax import lax
import jax.numpy as jnp
from flax import linen as nn

#from .common import ConvBlock, ModuleDef, Sequential
#from .splat import SplAtConv2d

ModuleDef = Callable[..., Callable]
# InitFn = Callable[[PRNGKey, Shape, DType], Array]
InitFn = Callable[[Any, Iterable[int], Any], Any]

STAGE_SIZES = {
    18: [2, 2, 2, 2],
    34: [3, 4, 6, 3],
    50: [3, 4, 6, 3],
    101: [3, 4, 23, 3],
    152: [3, 8, 36, 3],
    200: [3, 24, 36, 3],
    269: [3, 30, 48, 8],
}

def rsoftmax(x, radix, cardinality):
    # (batch_size, features) -> (batch_size, features)
    batch = x.shape[0]
    if radix > 1:
        x = x.reshape((batch, cardinality, radix, -1)).swapaxes(1, 2)
        return nn.softmax(x, axis=1).reshape((batch, -1))
    else:
        return nn.sigmoid(x)

class ConvBlock(nn.Module):
    n_filters: int
    kernel_size: Tuple[int, int] = (3, 3)
    strides: Tuple[int, int] = (1, 1)
    activation: Callable = nn.relu
    padding: Union[str, Iterable[Tuple[int, int]]] = ((0, 0), (0, 0))
    is_last: bool = False
    groups: int = 1
    kernel_init: InitFn = nn.initializers.kaiming_normal()
    bias_init: InitFn = nn.initializers.zeros
    conv_cls: ModuleDef = nn.Conv
    norm_cls: Optional[ModuleDef] = partial(nn.BatchNorm, momentum=1.0)
    force_conv_bias: bool = False

    @nn.compact
    def __call__(self, x):
        x = self.conv_cls(
            self.n_filters,
            self.kernel_size,
            self.strides,
            use_bias=(not self.norm_cls or self.force_conv_bias),
            padding=self.padding,
            feature_group_count=self.groups,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )(x)
        if self.norm_cls:
            scale_init = (nn.initializers.zeros
                          if self.is_last else nn.initializers.ones)
            x = self.norm_cls(use_running_average=True, scale_init=scale_init)(x)
        if not self.is_last:
            x = self.activation(x)
        return x
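
# A minimal shape check (added for illustration, not part of the original gist):
# ConvBlock chains conv -> batch norm (running-average mode) -> activation.
# The sizes below and the helper name are arbitrary examples.
def _example_conv_block():
    block = ConvBlock(64, kernel_size=(3, 3), strides=(2, 2),
                      padding=((1, 1), (1, 1)))
    x = jnp.ones((1, 32, 32, 3))
    variables = block.init(jax.random.PRNGKey(0), x)
    return block.apply(variables, x).shape  # (1, 16, 16, 64)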

#def wide_basic_ntk(in_planes, planes, dropout_rate, stride=1, mode='train'):
#    bn1 = stax.serial(
#        stax.LayerNorm(axis=(0, 1, 2), eps=1e-5),
#        stax.Dense(in_planes, 1., 0., parameterization='standard')
#    )
#    conv1 = stax.Conv(planes, (3, 3), padding='SAME', parameterization='standard')
#    #dropout = stax.Dropout(rate=dropout_rate, mode=mode)
#
#    bn2 = stax.serial(
#        stax.LayerNorm(axis=(0, 1, 2), eps=1e-5),
#        stax.Dense(planes, 1., 0., parameterization='standard')
#    )
#    conv2 = stax.Conv(planes, (3, 3), strides=(stride, stride),
#                      padding='SAME', parameterization='standard')
#
#    main = stax.serial(
#        bn1,
#        stax.Relu(), conv1,
#        #dropout,
#        bn2,
#        stax.Relu(), conv2
#    )
#
#    shortcut = stax.Identity()
#    if stride != 1 or in_planes != planes:
#        shortcut = stax.Conv(planes, (1, 1), strides=(stride, stride),
#                             padding='SAME', parameterization='standard')
#
#    net = stax.serial(
#        stax.FanOut(2), stax.parallel(main, shortcut), stax.FanInSum())
#
#    return net

class SplAtConv2d(nn.Module):
    channels: int
    kernel_size: Tuple[int, int]
    strides: Tuple[int, int] = (1, 1)
    padding: Union[str, Iterable[Tuple[int, int]]] = ((0, 0), (0, 0))
    groups: int = 1
    radix: int = 2
    reduction_factor: int = 4
    conv_block_cls: ModuleDef = ConvBlock
    cardinality: int = groups
    # Match extra bias here:
    # github.com/zhanghang1989/ResNeSt/blob/master/resnest/torch/splat.py#L39
    match_reference: bool = False

    @nn.compact
    def __call__(self, x):
        inter_channels = max(x.shape[-1] * self.radix // self.reduction_factor, 32)

        conv_block = self.conv_block_cls(self.channels * self.radix,
                                         kernel_size=self.kernel_size,
                                         strides=self.strides,
                                         groups=self.groups * self.radix,
                                         padding=self.padding)
        conv_cls = conv_block.conv_cls  # type: ignore
        x = conv_block(x)

        if self.radix > 1:
            # torch split takes split_size: int(rchannel//self.radix)
            # jnp split takes num sections: self.radix
            split = jnp.split(x, self.radix, axis=-1)
            gap = sum(split)
        else:
            gap = x

        gap = gap.mean((1, 2), keepdims=True)  # type: ignore # global average pool

        # Remove force_conv_bias after resolving
        # github.com/zhanghang1989/ResNeSt/issues/125
        gap = self.conv_block_cls(inter_channels,
                                  kernel_size=(1, 1),
                                  groups=self.cardinality,
                                  force_conv_bias=self.match_reference)(gap)

        attn = conv_cls(self.channels * self.radix,
                        kernel_size=(1, 1),
                        feature_group_count=self.cardinality)(gap)  # n x 1 x 1 x c

        attn = attn.reshape((x.shape[0], -1))
        attn = rsoftmax(attn, self.radix, self.cardinality)
        attn = attn.reshape((x.shape[0], 1, 1, -1))

        if self.radix > 1:
            attns = jnp.split(attn, self.radix, axis=-1)
            out = sum(a * s for a, s in zip(attns, split))
        else:
            out = attn * x

        return out
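
# A hypothetical shape check (added, not in the original gist): split-attention
# convolution with radix=2 keeps the spatial size and returns `channels`
# features. Channel counts and the helper name are illustrative only.
def _example_splat():
    splat = SplAtConv2d(channels=64, kernel_size=(3, 3),
                        padding=((1, 1), (1, 1)), radix=2)
    x = jnp.ones((1, 8, 8, 64))
    variables = splat.init(jax.random.PRNGKey(0), x)
    return splat.apply(variables, x).shape  # (1, 8, 8, 64)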

class ResNetStem(nn.Module):
    conv_block_cls: ModuleDef = ConvBlock

    @nn.compact
    def __call__(self, x):
        return self.conv_block_cls(64,
                                   kernel_size=(7, 7),
                                   strides=(2, 2),
                                   padding=[(3, 3), (3, 3)])(x)

class ResNetDStem(nn.Module):
    conv_block_cls: ModuleDef = ConvBlock
    stem_width: int = 32
    # If True, n_filters for first conv is (input_channels + 1) * 8
    adaptive_first_width: bool = False

    @nn.compact
    def __call__(self, x):
        cls = partial(self.conv_block_cls, kernel_size=(3, 3), padding=((1, 1), (1, 1)))
        first_width = (8 * (x.shape[-1] + 1)
                       if self.adaptive_first_width else self.stem_width)
        x = cls(first_width, strides=(2, 2))(x)
        x = cls(self.stem_width, strides=(1, 1))(x)
        x = cls(self.stem_width * 2, strides=(1, 1))(x)
        return x

class ResNetSkipConnection(nn.Module):
    strides: Tuple[int, int]
    conv_block_cls: ModuleDef = ConvBlock

    @nn.compact
    def __call__(self, x, out_shape):
        if x.shape != out_shape:
            x = self.conv_block_cls(out_shape[-1],
                                    kernel_size=(1, 1),
                                    strides=self.strides,
                                    activation=lambda y: y)(x)
        return x


class ResNetDSkipConnection(nn.Module):
    strides: Tuple[int, int]
    conv_block_cls: ModuleDef = ConvBlock

    @nn.compact
    def __call__(self, x, out_shape):
        if self.strides != (1, 1):
            x = nn.avg_pool(x, (2, 2), strides=(2, 2), padding=((0, 0), (0, 0)))
        if x.shape[-1] != out_shape[-1]:
            x = self.conv_block_cls(out_shape[-1], (1, 1), activation=lambda y: y)(x)
        return x

class ResNeStSkipConnection(ResNetDSkipConnection):
    # Inherit to ensure our variables dict has the right names.
    pass

class ResNetBlock(nn.Module):
    n_hidden: int
    strides: Tuple[int, int] = (1, 1)
    activation: Callable = nn.relu
    conv_block_cls: ModuleDef = ConvBlock
    skip_cls: ModuleDef = ResNetSkipConnection

    @nn.compact
    def __call__(self, x):
        skip_cls = partial(self.skip_cls, conv_block_cls=self.conv_block_cls)
        y = self.conv_block_cls(self.n_hidden,
                                padding=[(1, 1), (1, 1)],
                                strides=self.strides)(x)
        y = self.conv_block_cls(self.n_hidden, padding=[(1, 1), (1, 1)],
                                is_last=True)(y)
        return self.activation(y + skip_cls(self.strides)(x, y.shape))

class ResNetBottleneckBlock(nn.Module):
    n_hidden: int
    strides: Tuple[int, int] = (1, 1)
    expansion: int = 4
    groups: int = 1  # cardinality
    base_width: int = 64
    activation: Callable = nn.relu
    conv_block_cls: ModuleDef = ConvBlock
    skip_cls: ModuleDef = ResNetSkipConnection

    @nn.compact
    def __call__(self, x):
        skip_cls = partial(self.skip_cls, conv_block_cls=self.conv_block_cls)
        group_width = int(self.n_hidden * (self.base_width / 64.)) * self.groups

        # Downsampling strides in 3x3 conv instead of 1x1 conv, which improves accuracy.
        # This variant is called ResNet V1.5 (matches torchvision).
        y = self.conv_block_cls(group_width, kernel_size=(1, 1))(x)
        y = self.conv_block_cls(group_width,
                                strides=self.strides,
                                groups=self.groups,
                                padding=((1, 1), (1, 1)))(y)
        y = self.conv_block_cls(self.n_hidden * self.expansion,
                                kernel_size=(1, 1),
                                is_last=True)(y)
        return self.activation(y + skip_cls(self.strides)(x, y.shape))

class ResNetDBlock(ResNetBlock):
    skip_cls: ModuleDef = ResNetDSkipConnection


class ResNetDBottleneckBlock(ResNetBottleneckBlock):
    skip_cls: ModuleDef = ResNetDSkipConnection

class ResNeStBottleneckBlock(ResNetBottleneckBlock):
    skip_cls: ModuleDef = ResNeStSkipConnection
    avg_pool_first: bool = False
    radix: int = 2
    splat_cls: ModuleDef = SplAtConv2d

    @nn.compact
    def __call__(self, x):
        assert self.radix == 2  # TODO: implement radix != 2

        skip_cls = partial(self.skip_cls, conv_block_cls=self.conv_block_cls)
        group_width = int(self.n_hidden * (self.base_width / 64.)) * self.groups

        y = self.conv_block_cls(group_width, kernel_size=(1, 1))(x)

        if self.strides != (1, 1) and self.avg_pool_first:
            y = nn.avg_pool(y, (3, 3), strides=self.strides, padding=[(1, 1), (1, 1)])

        y = self.splat_cls(group_width,
                           kernel_size=(3, 3),
                           strides=(1, 1),
                           padding=[(1, 1), (1, 1)],
                           groups=self.groups,
                           radix=self.radix)(y)

        if self.strides != (1, 1) and not self.avg_pool_first:
            y = nn.avg_pool(y, (3, 3), strides=self.strides, padding=[(1, 1), (1, 1)])

        y = self.conv_block_cls(self.n_hidden * self.expansion,
                                kernel_size=(1, 1),
                                is_last=True)(y)
        return self.activation(y + skip_cls(self.strides)(x, y.shape))

class FrozenBatchNorm(nn.BatchNorm):
    # BatchNorm that normalizes with frozen running statistics: the
    # 'batch_stats' values are wrapped in stop_gradient and never updated.
    @nn.compact
    def __call__(self, x, use_running_average: Optional[bool] = None):
        shape = (x.shape[-1],)
        mean = self.variable('batch_stats', 'mean', lambda s: jnp.zeros(s, jnp.float32), shape)
        var = self.variable('batch_stats', 'var', lambda s: jnp.ones(s, jnp.float32), shape)
        y = (x - lax.stop_gradient(mean.value)) * lax.rsqrt(lax.stop_gradient(var.value) + self.epsilon)
        return y * self.param('scale', self.scale_init, shape, jnp.float32) + self.param(
            'bias', self.bias_init, shape, jnp.float32)

class Sequential(nn.Module):
    layers: Sequence[Union[nn.Module, Callable[[jnp.ndarray], jnp.ndarray]]]

    @nn.compact
    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

def ResNet(
    block_cls: ModuleDef,
    *,
    stage_sizes: Sequence[int],
    n_classes: int,
    hidden_sizes: Sequence[int] = (64, 128, 256, 512),
    conv_cls: ModuleDef = nn.Conv,
    norm_cls: Optional[ModuleDef] = partial(nn.BatchNorm, momentum=1.0),
    conv_block_cls: ModuleDef = ConvBlock,
    stem_cls: ModuleDef = ResNetStem,
    pool_fn: Callable = partial(nn.max_pool,
                                window_shape=(3, 3),
                                strides=(2, 2),
                                padding=((1, 1), (1, 1))),
    model: Optional[ModuleDef] = None,
) -> Sequential:
    conv_block_cls = partial(conv_block_cls, conv_cls=conv_cls, norm_cls=norm_cls)
    stem_cls = partial(stem_cls, conv_block_cls=conv_block_cls)
    block_cls = partial(block_cls, conv_block_cls=conv_block_cls)

    layers = [stem_cls(), pool_fn]
    for i, (hsize, n_blocks) in enumerate(zip(hidden_sizes, stage_sizes)):
        for b in range(n_blocks):
            strides = (1, 1) if i == 0 or b != 0 else (2, 2)
            layers.append(block_cls(n_hidden=hsize, strides=strides))

    layers.append(partial(jnp.mean, axis=(1, 2)))  # global average pool
    layers.append(nn.Dense(n_classes))
    return Sequential(layers)
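
# A minimal usage sketch (added for illustration, not part of the original
# gist): build a ResNet-18-style classifier from the factory above. The input
# size, class count, and helper name are arbitrary examples.
def _example_resnet18(n_classes: int = 10):
    model = ResNet(ResNetBlock,
                   stage_sizes=STAGE_SIZES[18],
                   n_classes=n_classes)
    x = jnp.ones((1, 64, 64, 3))
    variables = model.init(jax.random.PRNGKey(0), x)  # 'params' + 'batch_stats'
    logits = model.apply(variables, x)
    return logits.shape  # (1, n_classes)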

class WideBasicNtk(nn.Module):
    in_planes: int
    planes: int
    stride: int

    @nn.compact
    def __call__(self, x):
        # 1x1 projection shortcut when the shape changes, identity otherwise.
        shortcut = x
        if self.stride != 1 or self.in_planes != self.planes:
            shortcut = nn.Conv(self.planes, (1, 1),
                               strides=(self.stride, self.stride), padding='SAME')(x)
        x = nn.BatchNorm(use_running_average=True, momentum=1.0)(x)
        x = nn.relu(x)
        x = nn.Conv(self.planes, (3, 3), strides=(1, 1), padding='SAME')(x)
        x = nn.BatchNorm(use_running_average=True, momentum=1.0)(x)
        x = nn.relu(x)
        x = nn.Conv(self.planes, (3, 3), strides=(self.stride, self.stride), padding='SAME')(x)
        return x + shortcut

def WideResNetNTK(
    num_layers: int,
    widen_factor: int,
    depth: int,
    num_classes: int,
    num_input_channels: int = 3,
    dropout_rate: float = 0.3) -> Sequential:
    in_planes = 16
    assert (depth - 4) % 6 == 0, 'wide-resnet depth should be 6n+4'
    n = (depth - 4) // 6
    k = widen_factor

    def _wide_layer(block, in_planes, planes, num_blocks, stride):
        strides = [stride] + [1] * (int(num_blocks) - 1)
        layers = []
        for stride in strides:
            layers.append(block(in_planes, planes, stride))
            in_planes = planes
        return Sequential(layers), in_planes

    print('| wide-resnet %dx%d' % (depth, k))
    nstages = [16, 16 * k, 32 * k, 64 * k]

    conv1 = nn.Conv(nstages[0], (3, 3), strides=(1, 1), padding='SAME')
    if num_layers >= 1:
        layer1, in_planes = _wide_layer(WideBasicNtk, in_planes, nstages[1],
                                        n, stride=1)
    if num_layers >= 2:
        layer2, in_planes = _wide_layer(WideBasicNtk, in_planes, nstages[2],
                                        n, stride=2)
    if num_layers == 3:
        layer3, in_planes = _wide_layer(WideBasicNtk, in_planes, nstages[3],
                                        n, stride=2)
    bn1 = nn.BatchNorm(use_running_average=True, momentum=1.0)
    linear = nn.Dense(num_classes)
    avg_pool = partial(jnp.mean, axis=(1, 2))

    if num_layers == 1:
        net = Sequential([
            conv1,
            layer1,
            bn1,
            nn.relu,
            avg_pool,
            linear,
        ])
    elif num_layers == 2:
        net = Sequential([
            conv1,
            layer1, layer2,
            bn1,
            nn.relu,
            avg_pool,
            linear,
        ])
    elif num_layers == 3:
        net = Sequential([
            conv1,
            layer1, layer2, layer3,
            bn1,
            nn.relu,
            avg_pool,
            linear,
        ])
    return net
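
# A minimal forward-pass sketch (added for illustration, not part of the
# original gist): the depth and widen_factor below are arbitrary examples;
# depth must satisfy (depth - 4) % 6 == 0.
if __name__ == '__main__':
    model = WideResNetNTK(num_layers=3, widen_factor=2, depth=10, num_classes=10)
    x = jnp.ones((2, 32, 32, 3))
    variables = model.init(jax.random.PRNGKey(0), x)
    logits = model.apply(variables, x)
    print(logits.shape)  # expected: (2, 10)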