Exporting Video-Swin-Transformer to onnx for TensorRT 7.x
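The script below is a self-contained copy of Video Swin Transformer (Swin-T backbone plus an I3D classification head), rewritten so that every shape-dependent reshape, pad and attention mask is computed once on the first forward pass and then reused as a constant, and so that the strided-slice patch merging is expressed as a fixed-weight grouped convolution; this keeps the traced graph free of dynamic-shape operators that tend not to survive export at ONNX opset 9 for the TensorRT 7.x parser. Running it as a script loads the pretrained checkpoint, writes a reference input/output pair to debug/x.bin and debug/y.bin for validating the TensorRT engine later, and exports debug/swin.onnx.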
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np
import math

def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None, scale_by_keep=True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    # if (mean < a - 2 * std) or (mean > b + 2 * std):
    #     warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
    #                   "The distribution of values may be incorrect.",
    #                   stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)
        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)
        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()
        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)
        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor

def get_window_size(x_size, window_size, shift_size=None):
    # Shrink the window (and zero the shift) along any dimension where the
    # feature map is no larger than the window itself.
    use_window_size = list(window_size)
    if shift_size is not None:
        use_shift_size = list(shift_size)
    for i in range(len(x_size)):
        if x_size[i] <= window_size[i]:
            use_window_size[i] = x_size[i]
            if shift_size is not None:
                use_shift_size[i] = 0
    if shift_size is None:
        return tuple(use_window_size)
    else:
        return tuple(use_window_size), tuple(use_shift_size)

class Mlp(nn.Module):
    """ Multilayer perceptron."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class WindowPartition(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, window_size=None):
        """
        Args:
            x: (B, D, H, W, C)
            window_size (tuple[int]): window size
        Returns:
            windows: (B*num_windows, window_size[0]*window_size[1]*window_size[2], C)
        """
        B, D, H, W, C = x.shape
        # The reshape targets are cached on the first call so the exported graph
        # contains constant shapes instead of shape-dependent operators.
        if not hasattr(self, 'x_shape'):
            self.x_shape = [
                B,
                D // window_size[0], window_size[0],
                H // window_size[1], window_size[1],
                W // window_size[2], window_size[2],
                C
            ]
        if not hasattr(self, 'out_shape'):
            self.out_shape = [-1, window_size[0] * window_size[1] * window_size[2], C]
        x = x.view(self.x_shape)
        return x.permute(0, 1, 3, 5, 2, 4, 6, 7).reshape(self.out_shape)

class WindowReverse(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, windows, window_size=None, B=None, D=None, H=None, W=None):
        """
        Args:
            windows: (B*num_windows, window_size[0], window_size[1], window_size[2], C)
            window_size (tuple[int]): Window size
            B (int): Batch size
            D (int): Temporal length
            H (int): Height of image
            W (int): Width of image
        Returns:
            x: (B, D, H, W, C)
        """
        # As in WindowPartition, the reshape targets are cached on the first call.
        if not hasattr(self, 'w_shape'):
            self.w_shape = [
                B,
                D // window_size[0],
                H // window_size[1],
                W // window_size[2],
                window_size[0],
                window_size[1],
                window_size[2],
                -1]
        if not hasattr(self, 'out_shape'):
            self.out_shape = [B, D, H, W, -1]
        x = windows.view(self.w_shape)
        return x.permute(0, 1, 4, 2, 5, 3, 6, 7).reshape(self.out_shape)

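# Illustration (not part of the original gist): WindowPartition followed by
# WindowReverse is an exact round trip, which is how SwinTransformerBlock3D below
# uses the two modules. Fresh instances are created here because both modules
# cache the shapes they see on their first call.
def _window_round_trip_check():
    B, D, H, W, C = 1, 8, 14, 14, 32
    window_size = (8, 7, 7)
    x = torch.rand(B, D, H, W, C)
    windows = WindowPartition()(x, window_size)          # (B*num_windows, Wd*Wh*Ww, C)
    windows = windows.view((-1,) + window_size + (C,))   # (B*num_windows, Wd, Wh, Ww, C)
    restored = WindowReverse()(windows, window_size, B, D, H, W)
    assert torch.equal(restored, x)
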
class PatchMerging(nn.Module):
    """ Patch Merging Layer

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, D, H, W, C).
        """
        B, D, H, W, C = x.shape
        if not hasattr(self, 'pad_wh'):
            self.pad_wh = [W % 2, H % 2]
            self.half_wh = [W // 2, H // 2]
        if self.pad_wh[0] or self.pad_wh[1]:
            x = F.pad(x, (0, 0, 0, self.pad_wh[0], 0, self.pad_wh[1]))
        # The reference implementation gathers the four 2x2 neighbours with strided
        # slicing and concatenates them:
        # x0 = x[:, :, 0::2, 0::2, :]  # B D H/2 W/2 C
        # x1 = x[:, :, 1::2, 0::2, :]  # B D H/2 W/2 C
        # x2 = x[:, :, 0::2, 1::2, :]  # B D H/2 W/2 C
        # x3 = x[:, :, 1::2, 1::2, :]  # B D H/2 W/2 C
        # y = torch.cat([x0, x1, x2, x3], -1)  # B D H/2 W/2 4*C
        # Here the same gather is expressed as a fixed-weight grouped Conv2d with
        # stride 2, which exports more cleanly than strided slices.
        if not hasattr(self, 'conv_w'):
            conv_w = torch.stack([
                torch.Tensor([[1., 0.], [0., 0.]]),
                torch.Tensor([[0., 0.], [1., 0.]]),
                torch.Tensor([[0., 1.], [0., 0.]]),
                torch.Tensor([[0., 0.], [0., 1.]])
            ])
            conv_w = conv_w.unsqueeze(1).repeat(C, 1, 1, 1).to(x.device)
            conv_w.requires_grad = False
            self.register_buffer('conv_w', conv_w)
        x = x.reshape(B * D, H, W, C).permute(0, 3, 1, 2)
        x = F.conv2d(x, weight=self.conv_w, bias=None, stride=2, padding=0, groups=C)
        x = x.reshape([B * D, C, 4, self.half_wh[1], self.half_wh[0]])
        x = x.permute(0, 3, 4, 2, 1)
        x = x.reshape([B, D, self.half_wh[1], self.half_wh[0], C * 4])
        # assert (x - y).sum().item() < 1e-4
        # x = y
        return self.reduction(self.norm(x))

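# Illustration (not part of the original gist): the fixed-weight grouped Conv2d in
# PatchMerging reproduces the commented-out strided-slice + concat formulation,
# presumably to keep strided slices out of the exported graph; this sketch checks
# that the two paths agree.
def _patch_merging_equivalence_check():
    torch.manual_seed(0)
    B, D, H, W, C = 1, 2, 8, 8, 16
    x = torch.rand(B, D, H, W, C)
    merging = PatchMerging(dim=C).eval()
    with torch.no_grad():
        out_conv = merging(x)
        x0 = x[:, :, 0::2, 0::2, :]
        x1 = x[:, :, 1::2, 0::2, :]
        x2 = x[:, :, 0::2, 1::2, :]
        x3 = x[:, :, 1::2, 1::2, :]
        out_slice = merging.reduction(merging.norm(torch.cat([x0, x1, x2, x3], -1)))
    assert torch.allclose(out_conv, out_slice, atol=1e-5)
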
class PatchEmbed3D(nn.Module):
    """ Video to Patch Embedding.

    Args:
        patch_size (int): Patch token size. Default: (2,4,4).
        in_chans (int): Number of input video channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """
    def __init__(self, patch_size=(2, 4, 4), in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        """Forward function."""
        # padding: pad D, H, W up to multiples of the patch size; the amounts are
        # cached on the first call so the exported graph uses constants.
        _, _, D, H, W = x.size()
        if not hasattr(self, 'padW'):
            self.padW = W % self.patch_size[2]
        if not hasattr(self, 'padH'):
            self.padH = H % self.patch_size[1]
        if not hasattr(self, 'padD'):
            self.padD = D % self.patch_size[0]
        if self.padW != 0:
            x = F.pad(x, (0, self.patch_size[2] - self.padW))
        if self.padH != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[1] - self.padH))
        if self.padD != 0:
            x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - self.padD))
        x = self.proj(x)  # B C D Wh Ww
        if self.norm is not None:
            D, Wh, Ww = x.size(2), x.size(3), x.size(4)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
        return x

# cache each stage's results
class ComputeMask(nn.Module):
    def __init__(self):
        super().__init__()
        self.window_partition = WindowPartition()

    def forward(self, D, H, W, window_size, shift_size, device):
        # The attention mask for shifted windows depends only on the padded
        # feature-map size, so it is computed once and cached.
        if not hasattr(self, 'attn_mask'):
            img_mask = torch.zeros((1, D, H, W, 1), device=device)  # 1 Dp Hp Wp 1
            cnt = 0
            for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None):
                for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None):
                    for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2], None):
                        img_mask[:, d, h, w, :] = cnt
                        cnt += 1
            mask_windows = self.window_partition(img_mask, window_size)  # nW, ws[0]*ws[1]*ws[2], 1
            mask_windows = mask_windows.squeeze(-1)  # nW, ws[0]*ws[1]*ws[2]
            self.attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            self.attn_mask = self.attn_mask.masked_fill(self.attn_mask != 0, float(-100.0)).masked_fill(self.attn_mask == 0, float(0.0))
        return self.attn_mask

class WindowAttention3D(nn.Module):
    """ Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The temporal length, height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: False
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """
    def __init__(self, dim, window_size, num_heads, qkv_bias=False,
                 qk_scale=None, attn_drop=0., proj_drop=0., in_shape_nc=[]):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wd, Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads))  # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_d = torch.arange(self.window_size[0])
        coords_h = torch.arange(self.window_size[1])
        coords_w = torch.arange(self.window_size[2])
        coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w))  # 3, Wd, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 3, Wd*Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 3, Wd*Wh*Ww, Wd*Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wd*Wh*Ww, Wd*Wh*Ww, 3
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 2] += self.window_size[2] - 1
        relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
        relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1)
        relative_position_index = relative_coords.sum(-1)  # Wd*Wh*Ww, Wd*Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """ Forward function.

        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, N, N) or None
        """
        B_, N, C = x.shape
        indices = self.relative_position_index[:N, :N].reshape(-1)
        relative_position_bias = self.relative_position_bias_table[indices].reshape(N, N, -1)  # Wd*Wh*Ww, Wd*Wh*Ww, nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wd*Wh*Ww, Wd*Wh*Ww
        if not hasattr(self, 'qkv_shape'):
            self.qkv_shape = [B_, N, 3, self.num_heads, C // self.num_heads]
        qkv = self.qkv(x).reshape(self.qkv_shape).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # B_, nH, N, C

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        attn = attn + relative_position_bias.unsqueeze(0)  # B_, nH, N, N

        if mask is not None:
            nW = mask.shape[0]
            if not hasattr(self, 'attn_shape'):
                self.attn_shape = [B_ // nW, nW, self.num_heads, N, N]
            attn = attn.view(self.attn_shape) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(-1, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class SwinTransformerBlock3D(nn.Module):
    """ Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (tuple[int]): Window size.
        shift_size (tuple[int]): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """
    def __init__(self, dim, num_heads, window_size=(2, 7, 7), shift_size=(0, 0, 0),
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.use_checkpoint = use_checkpoint

        assert 0 <= self.shift_size[0] < self.window_size[0], "shift_size must be in [0, window_size)"
        assert 0 <= self.shift_size[1] < self.window_size[1], "shift_size must be in [0, window_size)"
        assert 0 <= self.shift_size[2] < self.window_size[2], "shift_size must be in [0, window_size)"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention3D(dim, window_size=self.window_size, num_heads=num_heads,
                                      qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.window_partition = WindowPartition()
        self.window_reverse = WindowReverse()

    def forward_part1(self, x, mask_matrix):
        B, D, H, W, C = x.shape
        if not hasattr(self, 'window_size_n'):
            self.window_size_n, self.shift_size_n = get_window_size((D, H, W), self.window_size, self.shift_size)
        x = self.norm1(x)
        # pad feature maps to multiples of window size (pad amounts cached on first call)
        pad_l = pad_t = pad_d0 = 0
        if not hasattr(self, 'pad_dbr'):
            self.pad_dbr = [
                (self.window_size_n[0] - D % self.window_size_n[0]) % self.window_size_n[0],
                (self.window_size_n[1] - H % self.window_size_n[1]) % self.window_size_n[1],
                (self.window_size_n[2] - W % self.window_size_n[2]) % self.window_size_n[2]
            ]
        x = F.pad(x, (0, 0, pad_l, self.pad_dbr[2], pad_t, self.pad_dbr[1], pad_d0, self.pad_dbr[0]))
        _, Dp, Hp, Wp, _ = x.shape
        # cyclic shift
        if any(i > 0 for i in self.shift_size_n):
            shifted_x = torch.roll(x, shifts=(-self.shift_size_n[0], -self.shift_size_n[1], -self.shift_size_n[2]), dims=(1, 2, 3))
            attn_mask = mask_matrix
        else:
            shifted_x = x
            attn_mask = None
        # partition windows
        x_windows = self.window_partition(shifted_x, self.window_size_n)  # B*nW, Wd*Wh*Ww, C
        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # B*nW, Wd*Wh*Ww, C
        # merge windows
        attn_windows_shape = (-1,) + self.window_size_n + (C,)
        attn_windows = attn_windows.view(attn_windows_shape)
        shifted_x = self.window_reverse(attn_windows, self.window_size_n, B, Dp, Hp, Wp)  # B D' H' W' C
        # reverse cyclic shift
        if any(i > 0 for i in self.shift_size_n):
            x = torch.roll(shifted_x, shifts=(self.shift_size_n[0], self.shift_size_n[1], self.shift_size_n[2]), dims=(1, 2, 3))
        else:
            x = shifted_x
        if self.pad_dbr[0] > 0 or self.pad_dbr[1] > 0 or self.pad_dbr[2] > 0:
            x = x[:, :D, :H, :W, :].contiguous()
        return x

    def forward_part2(self, x):
        return self.drop_path(self.mlp(self.norm2(x)))

    def forward(self, x, mask_matrix):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, D, H, W, C).
            mask_matrix: Attention mask for cyclic shift.
        """
        shortcut = x
        if self.use_checkpoint:
            x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix)
        else:
            x = self.forward_part1(x, mask_matrix)
        x = shortcut + self.drop_path(x)

        if self.use_checkpoint:
            x = x + checkpoint.checkpoint(self.forward_part2, x)
        else:
            x = x + self.forward_part2(x)

        return x

class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of feature channels.
        depth (int): Depth of this stage.
        num_heads (int): Number of attention heads.
        window_size (tuple[int]): Local window size. Default: (1,7,7).
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: False
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
    """
    def __init__(self,
                 dim,
                 depth,
                 num_heads,
                 window_size=(1, 7, 7),
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False,
                 in_shape_dhw=[]):
        super().__init__()
        self.window_size = window_size
        self.shift_size = tuple(i // 2 for i in window_size)
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock3D(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=(0, 0, 0) if (i % 2 == 0) else self.shift_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
                use_checkpoint=use_checkpoint,
            )
            for i in range(depth)])

        self.downsample = downsample
        if self.downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        self.compute_mask = ComputeMask()

    def forward(self, x):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, C, D, H, W).
        """
        # calculate attention mask for SW-MSA (computed once and cached as a buffer)
        if not hasattr(self, 'attn_mask'):
            _, _, D, H, W = x.shape
            window_size, shift_size = get_window_size((D, H, W), self.window_size, self.shift_size)
            Dp = int(np.ceil(D / window_size[0])) * window_size[0]
            Hp = int(np.ceil(H / window_size[1])) * window_size[1]
            Wp = int(np.ceil(W / window_size[2])) * window_size[2]
            attn_mask = self.compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device)
            self.register_buffer("attn_mask", attn_mask)
        x = torch.permute(x, [0, 2, 3, 4, 1])  # x = rearrange(x, 'b c d h w -> b d h w c')
        for blk in self.blocks:
            x = blk(x, self.attn_mask)
        if self.downsample is not None:
            x = self.downsample(x)
        x = torch.permute(x, [0, 4, 1, 2, 3])  # x = rearrange(x, 'b d h w c -> b c d h w')
        return x

class SwinTransformer3D(nn.Module):
    """ Swin Transformer backbone.
    A PyTorch impl of: `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
    https://arxiv.org/pdf/2103.14030

    Args:
        patch_size (int | tuple(int)): Patch size. Default: (2,4,4).
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        depths (tuple[int]): Depths of each Swin Transformer stage.
        num_heads (tuple[int]): Number of attention heads of each stage.
        window_size (tuple[int]): Window size. Default: (8,7,7).
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
        drop_rate (float): Dropout rate.
        attn_drop_rate (float): Attention dropout rate. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
        norm_layer: Normalization layer. Default: nn.LayerNorm.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
    """
    def __init__(self,
                 patch_size=(2, 4, 4),
                 in_chans=3,
                 embed_dim=96,
                 depths=[2, 2, 6, 2],
                 num_heads=[3, 6, 12, 24],
                 window_size=(8, 7, 7),
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.2,
                 norm_layer=nn.LayerNorm,
                 patch_norm=True,
                 frozen_stages=-1,
                 use_checkpoint=False):
        super().__init__()
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.frozen_stages = frozen_stages
        self.window_size = window_size
        self.patch_size = patch_size

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed3D(
            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if i_layer < self.num_layers - 1 else None,
                use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.norm = norm_layer(self.num_features)

        self._freeze_stages()

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        if self.frozen_stages >= 1:
            self.pos_drop.eval()
            for i in range(0, self.frozen_stages):
                m = self.layers[i]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def train(self, mode=True):
        """Convert the model into training mode while keeping frozen stages frozen."""
        super(SwinTransformer3D, self).train(mode)
        self._freeze_stages()

    def forward(self, x):
        x = self.patch_embed(x)
        x = self.pos_drop(x)
        for layer in self.layers:
            x = layer(x.contiguous())
        x = torch.permute(x, [0, 2, 3, 4, 1])  # x = rearrange(x, 'n c d h w -> n d h w c')
        x = self.norm(x)
        x = torch.permute(x, [0, 4, 1, 2, 3])  # x = rearrange(x, 'n d h w c -> n c d h w')
        return x

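# Shape note (not part of the original gist): for the export below (input
# 4 x 3 x 8 x 224 x 224, patch size (2,4,4), embed_dim 96, depths [2,2,6,2]),
# the backbone yields (4, 96, 4, 56, 56) after patch embedding and
# (4, 768, 4, 7, 7) after the last stage, matching I3DHead's in_channels=768.
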
class I3DHead(nn.Module):
    """ Classification head with global average pooling, as used for I3D-style video models. """
    def __init__(self,
                 num_classes=400,
                 in_channels=768,
                 spatial_type='avg',
                 dropout_ratio=0.5):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.spatial_type = spatial_type
        self.dropout_ratio = dropout_ratio
        if self.dropout_ratio != 0:
            self.dropout = nn.Dropout(p=self.dropout_ratio)
        else:
            self.dropout = None
        self.fc_cls = nn.Linear(self.in_channels, self.num_classes)

        if self.spatial_type == 'avg':
            # use `nn.AdaptiveAvgPool3d` to adaptively match the in_channels.
            self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        else:
            self.avg_pool = None

    def forward(self, x):
        if self.avg_pool is not None:
            x = self.avg_pool(x)
        if self.dropout is not None:
            x = self.dropout(x)
        x = x.view(x.shape[0], -1)
        cls_score = self.fc_cls(x)
        return cls_score

class VideoSwinTransformer(nn.Module):
    def __init__(self, model_name: str, num_classes: int):
        super().__init__()
        self.backbone = SwinTransformer3D()
        self.cls_head = I3DHead(num_classes=num_classes)
        # Copy pretrained weights by name; parameters whose shape differs are
        # re-initialized instead.
        pretrained_filename = f'models/pretrained/{model_name}.pth'
        pretrain_params = torch.load(pretrained_filename)['state_dict']
        with torch.no_grad():
            for name, param in self.named_parameters():
                if name in pretrain_params:
                    if param.shape == pretrain_params[name].shape:
                        param.copy_(pretrain_params[name])
                    elif 'weight' in name:
                        nn.init.xavier_uniform_(param)
                        print(f'parameter "{name}" is initialized by xavier_uniform_')
                    elif 'bias' in name:
                        nn.init.zeros_(param)
                        print(f'parameter "{name}" is initialized by zeros_')
                    else:
                        raise RuntimeError(f'initialization of {name} failed')

    def forward(self, x):  # B C D Wh Ww
        x = self.backbone(x)
        return self.cls_head(x)

if __name__ == "__main__":
    model = VideoSwinTransformer('vswt-tiny', 400)
    model.train(False)

    torch.manual_seed(1)
    x = torch.rand([4, 3, 8, 224, 224], dtype=torch.float32)
    # Run twice: the first forward pass fills the shape caches, the second (and the
    # export below) then runs with constant shapes.
    y = model(x)
    y = model(x)

    from array import array
    with open('debug/x.bin', 'wb') as x_file:
        float_array = array('f', x.flatten().tolist())
        float_array.tofile(x_file)
    with open('debug/y.bin', 'wb') as y_file:
        float_array = array('f', y.flatten().tolist())
        float_array.tofile(y_file)

    print(x.view(-1)[:10])
    y = y.reshape([16, -1])
    y = y[:, 0].tolist()
    print('\n'.join([str(a) for a in y]))

    def export_onnx(model, input_shapes, output_names, onnx_filename):
        model.eval()
        input_names = list(input_shapes.keys())
        sample_inputs = []
        for name in input_names:
            sample_inputs.append(torch.zeros(input_shapes[name], requires_grad=False))
        sample_inputs = tuple(sample_inputs)
        dynamic_axes = {}
        for name in input_names:
            dynamic_axes[name] = {0: 'batch'}
        for name in output_names:
            dynamic_axes[name] = {0: 'batch'}
        torch.onnx.export(
            model=model.cpu(),
            args=sample_inputs,
            f=onnx_filename,
            export_params=True,
            verbose=False,
            opset_version=9,
            input_names=input_names,
            output_names=output_names,
            do_constant_folding=True,
            dynamic_axes=dynamic_axes)

    export_onnx(model, {'images': [4, 3, 8, 224, 224]}, ['output'], 'debug/swin.onnx')
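
# Optional sanity check (not part of the original gist): compare the exported ONNX
# model against the PyTorch model with ONNX Runtime. Assumes `onnxruntime` is
# installed and the export above has been run; the batch size must stay 4 because
# the cached shapes were recorded with a batch of 4.
def _verify_onnx(model, onnx_filename='debug/swin.onnx'):
    import onnxruntime as ort
    sess = ort.InferenceSession(onnx_filename)
    x = torch.rand([4, 3, 8, 224, 224], dtype=torch.float32)
    with torch.no_grad():
        y_torch = model(x).numpy()
    (y_onnx,) = sess.run(['output'], {'images': x.numpy()})
    print('max abs diff:', float(np.abs(y_torch - y_onnx).max()))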