TensorRT wrong outputs
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

d_spectral_norm = True


def _downsample(x, avgpool_kernel):
    # Downsample via mean average pooling with the given kernel (e.g. (1, 1, 2))
    return nn.AvgPool3d(kernel_size=avgpool_kernel)(x)
class DisBlock0(nn.Module):
    """
    DisBlock for the encoder network, with minor tweaks on the activation and pooling layers.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        ksize=(1, 1, 3),
        avgpool_kernel=(1, 1, 2),
        activation=nn.LeakyReLU(),
    ):
        super(DisBlock0, self).__init__()
        self.activation = activation
        self.avgpool_kernel = avgpool_kernel
        pad = [int(ks) // 2 for ks in ksize]
        self.c1 = nn.Conv3d(in_channels, out_channels, kernel_size=ksize, padding=pad)
        self.c2 = nn.Conv3d(out_channels, out_channels, kernel_size=ksize, padding=pad)
        self.c_sc = nn.Conv3d(in_channels, out_channels, kernel_size=1, padding=0)
        if d_spectral_norm:
            self.c1 = nn.utils.spectral_norm(self.c1)
            self.c2 = nn.utils.spectral_norm(self.c2)
            self.c_sc = nn.utils.spectral_norm(self.c_sc)

    def residual(self, x):
        h = x
        h = self.c1(h)
        h = self.activation(h)
        h = self.c2(h)
        h = _downsample(h, self.avgpool_kernel)
        return h

    def shortcut(self, x):
        return self.c_sc(_downsample(x, self.avgpool_kernel))

    def forward(self, x):
        return self.residual(x) + self.shortcut(x)
class DisBlock1(nn.Module):
    """
    DisBlock for the encoder network, with minor tweaks on the activation and pooling layers.

    Args:
        hidden_channels: int, number of channels between the two Conv layers
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels=None,
        ksize=(1, 1, 3),
        avgpool_kernel=(1, 1, 2),
        activation=nn.LeakyReLU(),
        downsample=False,
    ):
        super(DisBlock1, self).__init__()
        self.activation = activation
        self.avgpool_kernel = avgpool_kernel
        self.downsample = downsample
        self.learnable_sc = (in_channels != out_channels) or downsample
        hidden_channels = in_channels if hidden_channels is None else hidden_channels
        pad = [int(ks) // 2 for ks in ksize]
        self.c1 = nn.Conv3d(in_channels, hidden_channels, kernel_size=ksize, padding=pad)
        self.c2 = nn.Conv3d(hidden_channels, out_channels, kernel_size=ksize, padding=pad)
        if d_spectral_norm:
            self.c1 = nn.utils.spectral_norm(self.c1)
            self.c2 = nn.utils.spectral_norm(self.c2)
        if self.learnable_sc:
            self.c_sc = nn.Conv3d(in_channels, out_channels, kernel_size=1, padding=0)
            if d_spectral_norm:
                self.c_sc = nn.utils.spectral_norm(self.c_sc)

    def residual(self, x):
        h = x
        h = self.activation(h)
        h = self.c1(h)
        h = self.activation(h)
        h = self.c2(h)
        if self.downsample:
            h = _downsample(h, self.avgpool_kernel)
        return h

    def shortcut(self, x):
        if self.learnable_sc:
            x = self.c_sc(x)
            if self.downsample:
                return _downsample(x, self.avgpool_kernel)
            else:
                return x
        else:
            return x

    def forward(self, x):
        return self.residual(x) + self.shortcut(x)
class Encoder(nn.Module):
    """
    Encoder network.
    Input: data samples
    Output: latent (mu and logvar)

    Args:
        ic: int, number of input channels for the first Conv layer
        zdim: int, length of the latent vector
        enc_ch: int, number of channels for the Conv layers
        ksize: tuple, kernel size for the Conv layers
        avgpool_kernel: tuple, kernel size for the AvgPool layers
    """

    def __init__(
        self,
        ic: int,
        zdim: int,
        enc_ch: int = 64,
        ksize: tuple[int, ...] = (1, 1, 3),
        avgpool_kernel: tuple[int, ...] = (1, 1, 2),
        activation: nn.Module = nn.LeakyReLU(),
        downsample: Optional[list[bool]] = [True, False, False],
        num_separate_unconstr_feats: Optional[int] = 0,
    ):
        super(Encoder, self).__init__()
        if downsample is not None:
            assert len(downsample) == 3
        else:
            downsample = [True, False, False]
        assert zdim > 0, f"zdim must be positive, got {zdim}!"
        assert isinstance(zdim, int), f"zdim must be an integer, got {zdim}!"
        self.zdim = zdim
        self.ch = enc_ch
        self.activation = activation
        # NOTE: DisBlock0 always downsamples with its AvgPool kernel
        self.block1 = DisBlock0(ic, self.ch, ksize=ksize, avgpool_kernel=avgpool_kernel)
        self.block2 = DisBlock1(
            self.ch,
            self.ch,
            ksize=ksize,
            avgpool_kernel=avgpool_kernel,
            activation=activation,
            downsample=downsample[0],
        )
        self.block3 = DisBlock1(
            self.ch,
            self.ch,
            ksize=ksize,
            avgpool_kernel=avgpool_kernel,
            activation=activation,
            downsample=downsample[1],  # was True
        )
        # NOTE: assigning here breaks compatibility with previously saved weights (just load with strict=False)
        # self.first_block = nn.Sequential(*[self.block1, self.block2, self.block3])
        self.block4 = DisBlock1(
            self.ch,
            self.ch,
            ksize=ksize,
            avgpool_kernel=avgpool_kernel,
            activation=activation,
            downsample=downsample[2],  # was True
        )
        # we may want a separate linear layer to specially predict unconstrained (scale) feats
        if num_separate_unconstr_feats is None:
            # if not specified, set to 0
            num_separate_unconstr_feats = 0
        assert isinstance(num_separate_unconstr_feats, int)
        self.num_sep_unconstr_feats = num_separate_unconstr_feats
        # final linear layers after the convolution layers
        # l5 predicts the latent mu, l6 predicts the latent logvar
        self.l5 = nn.Linear(self.ch, self.zdim, bias=False)
        self.l6 = nn.Linear(self.ch, self.zdim, bias=False)
        self.l5 = nn.utils.spectral_norm(self.l5)
        self.l6 = nn.utils.spectral_norm(self.l6)
        # alternative behavior: separate linear layers to specially predict unconstrained (scale) feats
        if self.num_sep_unconstr_feats > 0:
            self.l5_unconstr = nn.Linear(self.ch, self.num_sep_unconstr_feats, bias=False)
            self.l6_unconstr = nn.Linear(self.ch, self.num_sep_unconstr_feats, bias=False)
            self.l5_unconstr = nn.utils.spectral_norm(self.l5_unconstr)
            self.l6_unconstr = nn.utils.spectral_norm(self.l6_unconstr)

    def forward(self, x: torch.Tensor, original_batch_size: int = 0, trt_compat_mode: bool = False):
        h = x
        h = self.block1(h)
        h = self.block2(h)
        h = self.block3(h)
        h = self.block4(h)
        h = self.activation(h)
        # ^ shape of h: [bsz, enc_ch, il, cl, z_downsampled], e.g. if downsampled 4x, then 12 --> 3
        # v global sum pooling
        h = h.sum(2).sum(2).sum(2)  # repeatedly sum over il, cl, z_downsampled
        # ^ shape of h: [bsz, enc_ch]; NB: for sparse models, bsz is actually bsz * num_latents with zdim = 1
        # for TensorRT, invariant vs. unconstrained features are handled in the outer model's forward()
        if trt_compat_mode:
            return h
        # alternative behavior, NOTE: NOT compatible with TensorRT
        # we want to predict unconstrained feats separately from invariant feats,
        # so we first split the batch into 2 parts:
        # part 1) original_batch_size * num_inv_feats samples + part 2) original_batch_size samples
        h_inv, h_unconstr = h[:-original_batch_size], h[-original_batch_size:]
        mu_inv, logvar_inv = self.l5(h_inv), self.l6(h_inv)
        # ^ shape of mu_inv: [orig_bsz * num_inv_feats, 1]
        mu_unconstr, logvar_unconstr = self.l5_unconstr(h_unconstr), self.l6_unconstr(h_unconstr)
        # ^ shape of mu_unconstr: [orig_bsz, num_sep_unconstr_feats]
        # reshape from [orig_bsz * num_inv_feats, 1] --> [orig_bsz, num_inv_feats]
        mu_inv = mu_inv.view(original_batch_size, -1)
        logvar_inv = logvar_inv.view(original_batch_size, -1)
        # finally, combine the 2 latent vectors --> [orig_bsz, num_inv_feats + num_sep_unconstr_feats]
        mu = torch.cat((mu_inv, mu_unconstr), dim=1)
        logvar = torch.cat((logvar_inv, logvar_unconstr), dim=1)
        return mu, logvar
class FinalEncoder(nn.Module):
    def __init__(
        self,
        zdim: int = 4,
        input_dim: tuple[int, ...] = (1, 1, 1, 40),
        start: int = 0,
        end: int = 40,
        enc_ch: int = 64,
        ksize: tuple[int, ...] = (1, 1, 3),
        avgpool_kernel: tuple[int, ...] = (1, 1, 2),
        activation: nn.Module = nn.LeakyReLU(),
        row_normalize: bool = False,
        downsample: list[bool] = [True, False, False],
        num_inv_feats: int = 3,
        trt_compat_mode: bool = True,
    ):
        super(FinalEncoder, self).__init__()
        self.trt_compat_mode = trt_compat_mode
        self.zdim = zdim
        self.ic = input_dim[-2]
        self.input_dim = input_dim[-1]
        # valid angles and padding
        # sample_net_input_end[-1] - sample_net_input_orig[-1]
        self.w_dim = end - start
        self.start = start
        self.end = end
        self.padding = nn.ConstantPad1d((start, self.input_dim - end), 0)
        # this particular architecture only makes sense for invariant learning;
        # for non-invariant learning, we can just use the original SpikeEncoder.
        assert num_inv_feats > 0, f"num_inv_feats must be positive, got {num_inv_feats}!"
        assert isinstance(num_inv_feats, int)
        self.num_inv_feats = num_inv_feats
        self.num_unconstr_feats = self.zdim - self.num_inv_feats
        self.inv_end_idx: int = self.zdim if self.num_inv_feats == 0 else self.num_inv_feats
        # zdim is still fixed to 1; this is for the invariant features
        # however, for unconstrained features, zdim = self.num_unconstr_feats
        # because we set num_separate_unconstr_feats
        self.encoder = Encoder(
            ic=input_dim[0],
            zdim=1,
            enc_ch=enc_ch,
            ksize=ksize,
            avgpool_kernel=avgpool_kernel,
            activation=activation,
            downsample=downsample,
            num_separate_unconstr_feats=self.num_unconstr_feats,
        )
        # sparsity non-learnable parameters
        self.row_normalize = row_normalize
        self.lambda0 = 10
        self.lambda1 = 0.1
        self.a = 1.0
        self.b = zdim
        # sparsity learnable parameters
        # Ws can be updated via backprop, not only by the sparse loss but also by all other losses
        self.Ws = nn.Parameter(torch.randn((self.num_inv_feats, self.w_dim), dtype=torch.float))
        # p_star & thetas are updated manually, not via backprop
        self.p_star = nn.Parameter(
            (0.5 * torch.ones(self.num_inv_feats, self.w_dim, dtype=torch.float)),
            requires_grad=False,
        )
        self.thetas = nn.Parameter(torch.rand(self.w_dim, dtype=torch.float), requires_grad=False)

    def forward(self, x: torch.Tensor):
        bs = x.shape[0]
        # pad the learnt matrix Ws on the left and right,
        # expanding the last dim from self.w_dim to self.input_dim
        mask = self.padding(self.get_encoder_mask())
        # "mask" the input tensor x using the learnt matrix Ws (zero-padded)
        # mul here is not matrix-mul, but elementwise mul with broadcast
        # broadcasts to output shape: [bs, num_inv_feats, ic, input_dim]
        masked_input = torch.mul(
            x.view(bs, 1, self.ic, self.input_dim),
            mask.view(1, self.num_inv_feats, 1, self.input_dim),
        )
        # TensorRT does not allow bsz (1st dim) to change,
        # so we split masked_input into num_inv_feats groups (of size bsz each) and run forward on each group
        if self.trt_compat_mode:
            inv_mus = []
            inv_logvars = []
            for i in range(self.num_inv_feats):
                masked_input = masked_input.clone()
                curr_input = torch.index_select(
                    masked_input, dim=1, index=torch.tensor([i], device=masked_input.device)
                )  # dtype=torch.long
                curr_input = curr_input.view(bs, 1, 1, self.ic, self.input_dim)  # .contiguous()
                h = self.encoder(curr_input, trt_compat_mode=True)
                inv_mu, inv_logvar = self.encoder.l5(h), self.encoder.l6(h)
                inv_mus.append(inv_mu)
                inv_logvars.append(inv_logvar)
            inv_mus = torch.cat(inv_mus, dim=1)  # [bsz, 1] * num_inv_feats --> [bsz, num_inv_feats]
            inv_logvars = torch.cat(inv_logvars, dim=1)
            # for unconstrained features, we don't mask the inputs
            unmasked_input = x.view(bs, 1, 1, self.ic, self.input_dim) / self.w_dim
            h = self.encoder(unmasked_input, trt_compat_mode=True)
            unconstr_mu, unconstr_logvar = self.encoder.l5_unconstr(h), self.encoder.l6_unconstr(h)
            # combine all features
            mu = torch.cat((inv_mus, unconstr_mu), dim=1)  # output shape: [bsz, num_latents]
            logvar = torch.cat((inv_logvars, unconstr_logvar), dim=1)
            return mu, logvar
        # for each original sample, there are num_inv_feats masked copies of it
        masked_input = masked_input.view(bs * self.num_inv_feats, 1, 1, self.ic, self.input_dim)
        unmasked_input = x.view(bs, 1, 1, self.ic, self.input_dim) / self.w_dim
        # combine both inputs along the batch dim --> [(bsz * self.num_inv_feats) + bsz, 1, 1, ic, input_dim]
        combined_input = torch.cat((masked_input, unmasked_input), dim=0)
        mu, logvar = self.encoder(combined_input, original_batch_size=bs, trt_compat_mode=False)
        return mu, logvar

    def get_encoder_mask(self):
        return F.normalize(self.Ws.abs(), p=1, dim=-1)
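

# A minimal sanity-check sketch (my own addition, not part of the original model code):
# it runs the TensorRT-compatible per-feature loop and the batched reference path on the
# same weights in plain PyTorch and compares the outputs. The constructor arguments follow
# the defaults above; the batch size and tolerance are illustrative assumptions.
if __name__ == "__main__":
    torch.manual_seed(0)
    enc = FinalEncoder(zdim=4, input_dim=(1, 1, 1, 40), num_inv_feats=3, trt_compat_mode=True).eval()
    x = torch.randn(2, 1, 1, 40)  # [bsz, 1, ic, input_dim]
    with torch.no_grad():
        mu_trt, logvar_trt = enc(x)  # per-feature loop (TensorRT-friendly path)
        enc.trt_compat_mode = False
        mu_ref, logvar_ref = enc(x)  # single batched pass (reference path)
    # in eager PyTorch both paths should agree; a mismatch that only appears after
    # export would point at the TensorRT engine rather than the model definition
    print(torch.allclose(mu_trt, mu_ref, atol=1e-5), (mu_trt - mu_ref).abs().max().item())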