@devymex
Created February 27, 2022 07:57
Exporting Video-Swin-Transformer to ONNX for TensorRT 7.x
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np
import math
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and scale_by_keep:
random_tensor.div_(keep_prob)
return x * random_tensor
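# e.g. with drop_prob=0.1 in training mode, each sample's residual branch is zeroed with
# probability 0.1 and scaled by 1/0.9 otherwise (scale_by_keep=True); with training=False
# the function returns x unchanged.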
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None, scale_by_keep=True):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
def forward(self, x):
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
# type: (Tensor, float, float, float, float) -> Tensor
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
# if (mean < a - 2 * std) or (mean > b + 2 * std):
# warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
# "The distribution of values may be incorrect.",
# stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def get_window_size(x_size, window_size, shift_size=None):
use_window_size = list(window_size)
if shift_size is not None:
use_shift_size = list(shift_size)
for i in range(len(x_size)):
if x_size[i] <= window_size[i]:
use_window_size[i] = x_size[i]
if shift_size is not None:
use_shift_size[i] = 0
if shift_size is None:
return tuple(use_window_size)
else:
return tuple(use_window_size), tuple(use_shift_size)
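# Example from this script: after patch embedding an 8-frame 224x224 clip has size (4, 56, 56), so
# get_window_size((4, 56, 56), (8, 7, 7), (4, 3, 3)) returns ((4, 7, 7), (0, 3, 3)):
# the temporal axis fits in a single window, so its window is clamped and its shift disabled.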
class Mlp(nn.Module):
""" Multilayer perceptron."""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class WindowPartition(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, window_size = None):
"""
Args:
x: (B, D, H, W, C)
window_size (tuple[int]): window size
Returns:
windows: (B*num_windows, window_size[0]*window_size[1]*window_size[2], C)
"""
B, D, H, W, C = x.shape
if not hasattr(self, 'x_shape'):
self.x_shape = [
B,
D // window_size[0], window_size[0],
H // window_size[1], window_size[1],
W // window_size[2], window_size[2],
C
]
if not hasattr(self, 'out_shape'):
self.out_shape = [-1, window_size[0] * window_size[1] * window_size[2], C]
x = x.view(self.x_shape)
return x.permute(0, 1, 3, 5, 2, 4, 6, 7).reshape(self.out_shape)
class WindowReverse(nn.Module):
def __init__(self):
super().__init__()
def forward(self, windows, window_size = None, B = None, D = None, H = None, W = None):
"""
Args:
windows: (B*num_windows, window_size[0], window_size[1], window_size[2], C)
window_size (tuple[int]): Window size
B (int): Batch size
D (int): Number of frames
H (int): Height of image
W (int): Width of image
Returns:
x: (B, D, H, W, C)
"""
if not hasattr(self, 'w_shape'):
self.w_shape = [
B,
D // window_size[0],
H // window_size[1],
W // window_size[2],
window_size[0],
window_size[1],
window_size[2],
-1]
if not hasattr(self, 'out_shape'):
self.out_shape = [B, D, H, W, -1]
x = windows.view(self.w_shape)
return x.permute(0, 1, 4, 2, 5, 3, 6, 7).reshape(self.out_shape)
class PatchMerging(nn.Module):
""" Patch Merging Layer
Args:
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
""" Forward function.
Args:
x: Input feature, tensor size (B, D, H, W, C).
"""
B, D, H, W, C = x.shape
if not hasattr(self, 'pad_wh'):
self.pad_wh = [W % 2, H % 2]
self.half_wh = [W // 2, H // 2]
if self.pad_wh[0] or self.pad_wh[1]:
x = F.pad(x, (0, 0, 0, self.pad_wh[0], 0, self.pad_wh[1]))
# x0 = x[:, :, 0::2, 0::2, :] # B D H/2 W/2 C
# x1 = x[:, :, 1::2, 0::2, :] # B D H/2 W/2 C
# x2 = x[:, :, 0::2, 1::2, :] # B D H/2 W/2 C
# x3 = x[:, :, 1::2, 1::2, :] # B D H/2 W/2 C
# y = torch.cat([x0, x1, x2, x3], -1) # B D H/2 W/2 4*C
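# The strided slicing above (kept for reference) is replaced below by an equivalent grouped
# 2x2/stride-2 convolution with fixed one-hot kernels. This is presumably done because a
# Slice with step 2 cannot be expressed at ONNX opset 9 (steps were only added in opset 10)
# and is awkward for the TensorRT 7 ONNX parser, while a plain Conv exports cleanly.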
if not hasattr(self, 'conv_w'):
conv_w = torch.stack([
torch.Tensor([[1., 0.], [0., 0.]]),
torch.Tensor([[0., 0.], [1., 0.]]),
torch.Tensor([[0., 1.], [0., 0.]]),
torch.Tensor([[0., 0.], [0., 1.]])
])
conv_w = conv_w.unsqueeze(1).repeat(C, 1, 1, 1).to(x.device)
conv_w.requires_grad = False
self.register_buffer('conv_w', conv_w)
x = x.reshape(B * D, H, W, C).permute(0, 3, 1, 2)
x = F.conv2d(x, weight=self.conv_w, bias=None, stride=2, padding=0, groups=C)
x = x.reshape([B * D, C, 4, self.half_wh[1], self.half_wh[0]])
x = x.permute(0, 3, 4, 2, 1)
x = x.reshape([B, D, self.half_wh[1], self.half_wh[0], C * 4])
#assert (x - y).sum().item() < 1e-4
# x = y
return self.reduction(self.norm(x))
class PatchEmbed3D(nn.Module):
""" Video to Patch Embedding.
Args:
patch_size (int): Patch token size. Default: (2,4,4).
in_chans (int): Number of input video channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, patch_size=(2,4,4), in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
self.patch_size = patch_size
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
"""Forward function."""
# padding
_, _, D, H, W = x.size()
if not hasattr(self, 'padW'):
self.padW = W % self.patch_size[2]
if not hasattr(self, 'padH'):
self.padH = H % self.patch_size[1]
if not hasattr(self, 'padD'):
self.padD = D % self.patch_size[0]
if self.padW != 0:
x = F.pad(x, (0, self.patch_size[2] - self.padW))
if self.padH != 0:
x = F.pad(x, (0, 0, 0, self.patch_size[1] - self.padH))
if self.padD != 0:
x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - self.padD))
x = self.proj(x) # B C D Wh Ww
if self.norm is not None:
D, Wh, Ww = x.size(2), x.size(3), x.size(4)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
return x
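# Shape example with the defaults (patch_size=(2, 4, 4), embed_dim=96):
# an input clip of (B, 3, 8, 224, 224) is projected to (B, 96, 4, 56, 56).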
# build the SW-MSA attention mask once per stage and cache it for later forward passes
class ComputeMask(nn.Module):
def __init__(self):
super().__init__()
self.window_partition = WindowPartition()
def forward(self, D, H, W, window_size, shift_size, device):
if not hasattr(self, 'attn_mask'):
img_mask = torch.zeros((1, D, H, W, 1), device=device) # 1 Dp Hp Wp 1
cnt = 0
for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0],None):
for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1],None):
for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2],None):
img_mask[:, d, h, w, :] = cnt
cnt += 1
mask_windows = self.window_partition(img_mask, window_size) # nW, ws[0]*ws[1]*ws[2], 1
mask_windows = mask_windows.squeeze(-1) # nW, ws[0]*ws[1]*ws[2]
self.attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
self.attn_mask = self.attn_mask.masked_fill(self.attn_mask != 0, float(-100.0)).masked_fill(self.attn_mask == 0, float(0.0))
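# attn_mask[w, i, j] is 0 when tokens i and j of window w come from the same pre-shift
# region and -100 otherwise, so the softmax in WindowAttention3D effectively blocks
# attention across the cyclic-shift seams.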
return self.attn_mask
class WindowAttention3D(nn.Module):
""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The temporal length, height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: False
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=False,
qk_scale=None, attn_drop=0., proj_drop=0., in_shape_nc = []):
super().__init__()
self.dim = dim
self.window_size = window_size # Wd, Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads)) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_d = torch.arange(self.window_size[0])
coords_h = torch.arange(self.window_size[1])
coords_w = torch.arange(self.window_size[2])
coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w)) # 3, Wd, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 3, Wd*Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 3, Wd*Wh*Ww, Wd*Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 2] += self.window_size[2] - 1
relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1)
relative_position_index = relative_coords.sum(-1) # Wd*Wh*Ww, Wd*Wh*Ww
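# After the shifts and strides above, every relative offset (dd, dh, dw) maps to a unique
# index in [0, (2*Wd-1)*(2*Wh-1)*(2*Ww-1) - 1], i.e. one row of relative_position_bias_table.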
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
""" Forward function.
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, N, N) or None
"""
B_, N, C = x.shape
indices = self.relative_position_index[:N, :N].reshape(-1)
relative_position_bias = self.relative_position_bias_table[indices].reshape(N, N, -1) # Wd*Wh*Ww,Wd*Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wd*Wh*Ww, Wd*Wh*Ww
if not hasattr(self, 'qkv_shape'):
self.qkv_shape = [B_, N, 3, self.num_heads, C // self.num_heads]
qkv = self.qkv(x).reshape(self.qkv_shape).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # B_, nH, N, C
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn + relative_position_bias.unsqueeze(0) # B_, nH, N, N
if mask is not None:
nW = mask.shape[0]
if not hasattr(self, 'attn_shape'):
self.attn_shape = [B_ // nW, nW, self.num_heads, N, N]
attn = attn.view(self.attn_shape) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(-1, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class SwinTransformerBlock3D(nn.Module):
""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (tuple[int]): Window size.
shift_size (tuple[int]): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, num_heads, window_size=(2,7,7), shift_size=(0,0,0),
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_checkpoint=False):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
self.use_checkpoint=use_checkpoint
assert 0 <= self.shift_size[0] < self.window_size[0], "shift_size must be in [0, window_size)"
assert 0 <= self.shift_size[1] < self.window_size[1], "shift_size must be in [0, window_size)"
assert 0 <= self.shift_size[2] < self.window_size[2], "shift_size must be in [0, window_size)"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention3D(dim, window_size=self.window_size, num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.window_partition = WindowPartition()
self.window_reverse = WindowReverse()
def forward_part1(self, x, mask_matrix):
B, D, H, W, C = x.shape
if not hasattr(self, 'window_size_n'):
self.window_size_n, self.shift_size_n = get_window_size((D, H, W), self.window_size, self.shift_size)
x = self.norm1(x)
# pad feature maps to multiples of window size
pad_l = pad_t = pad_d0 = 0
if not hasattr(self, 'pad_dbr'):
self.pad_dbr = [
(self.window_size_n[0] - D % self.window_size_n[0]) % self.window_size_n[0],
(self.window_size_n[1] - H % self.window_size_n[1]) % self.window_size_n[1],
(self.window_size_n[2] - W % self.window_size_n[2]) % self.window_size_n[2]
]
x = F.pad(x, (0, 0, pad_l, self.pad_dbr[2], pad_t, self.pad_dbr[1], pad_d0, self.pad_dbr[0]))
_, Dp, Hp, Wp, _ = x.shape
# cyclic shift
if any(i > 0 for i in self.shift_size_n):
shifted_x = torch.roll(x, shifts=(-self.shift_size_n[0], -self.shift_size_n[1], -self.shift_size_n[2]), dims=(1, 2, 3))
attn_mask = mask_matrix
else:
shifted_x = x
attn_mask = None
# partition windows
x_windows = self.window_partition(shifted_x, self.window_size_n) # B*nW, Wd*Wh*Ww, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=attn_mask) # B*nW, Wd*Wh*Ww, C
# merge windows
attn_windows_shape = (-1,) + self.window_size_n + (C,)
attn_windows = attn_windows.view(attn_windows_shape)
shifted_x = self.window_reverse(attn_windows, self.window_size_n, B, Dp, Hp, Wp) # B D' H' W' C
# reverse cyclic shift
if any(i > 0 for i in self.shift_size_n):
x = torch.roll(shifted_x, shifts=(self.shift_size_n[0], self.shift_size_n[1], self.shift_size_n[2]), dims=(1, 2, 3))
else:
x = shifted_x
if self.pad_dbr[0] > 0 or self.pad_dbr[1] > 0 or self.pad_dbr[2] > 0:
x = x[:, :D, :H, :W, :].contiguous()
return x
def forward_part2(self, x):
return self.drop_path(self.mlp(self.norm2(x)))
def forward(self, x, mask_matrix):
""" Forward function.
Args:
x: Input feature, tensor size (B, D, H, W, C).
mask_matrix: Attention mask for cyclic shift.
"""
shortcut = x
if self.use_checkpoint:
x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix)
else:
x = self.forward_part1(x, mask_matrix)
x = shortcut + self.drop_path(x)
if self.use_checkpoint:
x = x + checkpoint.checkpoint(self.forward_part2, x)
else:
x = x + self.forward_part2(x)
return x
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of feature channels
depth (int): Depths of this stage.
num_heads (int): Number of attention head.
window_size (tuple[int]): Local window size. Default: (1,7,7).
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: False
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
"""
def __init__(self,
dim,
depth,
num_heads,
window_size=(1,7,7),
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
norm_layer=nn.LayerNorm,
downsample=None,
use_checkpoint=False,
in_shape_dhw = []):
super().__init__()
self.window_size = window_size
self.shift_size = tuple(i // 2 for i in window_size)
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock3D(
dim=dim,
num_heads=num_heads,
window_size=window_size,
shift_size=(0,0,0) if (i % 2 == 0) else self.shift_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
use_checkpoint=use_checkpoint,
)
for i in range(depth)])
self.downsample = downsample
if self.downsample is not None:
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
self.compute_mask = ComputeMask()
def forward(self, x):
""" Forward function.
Args:
x: Input feature, tensor size (B, C, D, H, W).
"""
# calculate attention mask for SW-MSA
if not hasattr(self, 'attn_mask'):
_, _, D, H, W = x.shape
window_size, shift_size = get_window_size((D,H,W), self.window_size, self.shift_size)
Dp = int(np.ceil(D / window_size[0])) * window_size[0]
Hp = int(np.ceil(H / window_size[1])) * window_size[1]
Wp = int(np.ceil(W / window_size[2])) * window_size[2]
attn_mask = self.compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device)
self.register_buffer("attn_mask", attn_mask)
x = torch.permute(x, [0, 2, 3, 4, 1]) # x = rearrange(x, 'b c d h w -> b d h w c')
for blk in self.blocks:
x = blk(x, self.attn_mask)
if self.downsample is not None:
x = self.downsample(x)
x = torch.permute(x, [0, 4, 1, 2, 3]) # x = rearrange(x, 'b d h w c -> b c d h w')
return x
class SwinTransformer3D(nn.Module):
""" Swin Transformer backbone.
A PyTorch impl of: `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
Args:
patch_size (tuple[int]): Patch size. Default: (2,4,4).
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
depths (tuple[int]): Depths of each Swin Transformer stage.
num_heads (tuple[int]): Number of attention head of each stage.
window_size (tuple[int]): Window size. Default: (8,7,7).
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
drop_rate (float): Dropout rate.
attn_drop_rate (float): Attention dropout rate. Default: 0.
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
norm_layer: Normalization layer. Default: nn.LayerNorm.
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters.
"""
def __init__(self,
patch_size=(2,4,4),
in_chans=3,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=(8,7,7),
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.2,
norm_layer=nn.LayerNorm,
patch_norm=True,
frozen_stages=-1,
use_checkpoint=False):
super().__init__()
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.patch_norm = patch_norm
self.frozen_stages = frozen_stages
self.window_size = window_size
self.patch_size = patch_size
# split image into non-overlapping patches
self.patch_embed = PatchEmbed3D(
patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(
dim=int(embed_dim * 2**i_layer),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if i_layer<self.num_layers-1 else None,
use_checkpoint=use_checkpoint)
self.layers.append(layer)
self.num_features = int(embed_dim * 2**(self.num_layers-1))
self.norm = norm_layer(self.num_features)
self._freeze_stages()
def _freeze_stages(self):
if self.frozen_stages >= 0:
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
if self.frozen_stages >= 1:
self.pos_drop.eval()
for i in range(0, self.frozen_stages):
m = self.layers[i]
m.eval()
for param in m.parameters():
param.requires_grad = False
def train(self, mode=True):
"""Convert the model into training mode while keep layers freezed."""
super(SwinTransformer3D, self).train(mode)
self._freeze_stages()
def forward(self, x):
x = self.patch_embed(x)
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x.contiguous())
x = torch.permute(x, [0, 2, 3, 4, 1]) # x = rearrange(x, 'n c d h w -> n d h w c')
x = self.norm(x)
x = torch.permute(x, [0, 4, 1, 2, 3]) # x = rearrange(x, 'n d h w c -> n c d h w')
return x
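# Shape flow with the defaults: (B, 3, 8, 224, 224) -> patch_embed -> (B, 96, 4, 56, 56)
# -> stage 1 -> (B, 192, 4, 28, 28) -> stage 2 -> (B, 384, 4, 14, 14)
# -> stage 3 -> (B, 768, 4, 7, 7) -> stage 4 (no downsample) -> (B, 768, 4, 7, 7).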
class I3DHead(nn.Module):
def __init__(self,
num_classes = 400,
in_channels = 768,
spatial_type='avg',
dropout_ratio=0.5):
super().__init__()
self.num_classes = num_classes
self.in_channels = in_channels
self.spatial_type = spatial_type
self.dropout_ratio = dropout_ratio
if self.dropout_ratio != 0:
self.dropout = nn.Dropout(p=self.dropout_ratio)
else:
self.dropout = None
self.fc_cls = nn.Linear(self.in_channels, self.num_classes)
if self.spatial_type == 'avg':
# use `nn.AdaptiveAvgPool3d` to adaptively match the in_channels.
self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
else:
self.avg_pool = None
def forward(self, x):
if self.avg_pool is not None:
x = self.avg_pool(x)
if self.dropout is not None:
x = self.dropout(x)
x = x.view(x.shape[0], -1)
cls_score = self.fc_cls(x)
return cls_score
class VideoSwinTransformer(nn.Module):
def __init__(self, model_name: str, num_classes: int):
super().__init__()
self.backbone = SwinTransformer3D()
self.cls_head = I3DHead(num_classes=num_classes)
pretrained_filename = f'models/pretrained/{model_name}.pth'
pretrain_params = torch.load(pretrained_filename)['state_dict']
with torch.no_grad():
for name, param in self.named_parameters():
if name in pretrain_params:
if param.shape == pretrain_params[name].shape:
param.copy_(pretrain_params[name])
elif 'weight' in name:
nn.init.xavier_uniform_(param)
print(f'parameter "{name}" is initialized by xavier_uniform_')
elif 'bias' in name:
nn.init.zeros_(param)
print(f'parameter "{name}" is initialized by zeros_')
else:
raise RuntimeError(f'initialization of parameter "{name}" failed')
def forward(self, x): # B C D Wh Ww
x = self.backbone(x)
return self.cls_head(x)
if __name__ == "__main__":
model = VideoSwinTransformer('vswt-tiny', 400)
model.train(False)
torch.manual_seed(1)
x = torch.rand([4, 3, 8, 224, 224], dtype=torch.float32)
y = model(x)
y = model(x)
from array import array
with open('debug/x.bin', 'wb') as x_file:
float_array = array('f', x.flatten().tolist())
float_array.tofile(x_file)
with open('debug/y.bin', 'wb') as y_file:
float_array = array('f', y.flatten().tolist())
float_array.tofile(y_file)
print(x.view(-1)[:10])
y = y.reshape([16, -1])
y = y[:, 0].tolist()
print('\n'.join([str(a) for a in y]))
def export_onnx(model, input_shapes, output_names, onnx_filename):
model.eval()
input_names = list(input_shapes.keys())
sample_inputs = []
for name in input_names:
sample_inputs.append(torch.zeros(input_shapes[name], requires_grad=False))
sample_inputs = tuple(sample_inputs)
dynamic_axes = {}
for name in input_names:
dynamic_axes[name] = {0: 'batch'}
for name in output_names:
dynamic_axes[name] = {0: 'batch'}
torch.onnx.export(
model=model.cpu(),
args=sample_inputs,
f=onnx_filename,
export_params=True,
verbose=False,
opset_version=9,
input_names=input_names,
output_names=output_names,
do_constant_folding=True,
dynamic_axes=dynamic_axes)
export_onnx(model, {'images': [4, 3, 8, 224, 224]}, ['output'], 'debug/swin.onnx')
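# Optional post-export sanity check: a minimal sketch, not part of the original gist. It
# assumes the `onnxruntime` package is installed and compares the exported graph's output
# against a fresh PyTorch forward pass on the same random input.
if __name__ == "__main__":
    try:
        import onnxruntime as ort
        sess = ort.InferenceSession('debug/swin.onnx', providers=['CPUExecutionProvider'])
        onnx_out = sess.run(['output'], {'images': x.numpy()})[0]
        torch_out = model(x).detach().numpy()
        print('max abs diff (onnxruntime vs PyTorch):', np.abs(onnx_out - torch_out).max())
    except ImportError:
        print('onnxruntime is not installed; skipping the ONNX sanity check')
# To build a TensorRT 7 engine from the exported file, something like
# `trtexec --onnx=debug/swin.onnx --explicitBatch` should work (exact flags may vary
# with the TensorRT version).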