# TensorRT wrong outputs
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

d_spectral_norm = True


def _downsample(x, avgpool_kernel):
    # Downsample via mean average pooling (e.g. a (1, 1, 2) kernel halves the last dim)
    return nn.AvgPool3d(kernel_size=avgpool_kernel)(x)


class DisBlock0(nn.Module):
    """
    DisBlock for the encoder network, with minor tweaks to the activation and pooling layers.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        ksize=(1, 1, 3),
        avgpool_kernel=(1, 1, 2),
        activation=nn.LeakyReLU(),
    ):
        super(DisBlock0, self).__init__()
        self.activation = activation
        self.avgpool_kernel = avgpool_kernel
        pad = [int(ks) // 2 for ks in ksize]
        self.c1 = nn.Conv3d(in_channels, out_channels, kernel_size=ksize, padding=pad)
        self.c2 = nn.Conv3d(out_channels, out_channels, kernel_size=ksize, padding=pad)
        self.c_sc = nn.Conv3d(in_channels, out_channels, kernel_size=1, padding=0)
        if d_spectral_norm:
            self.c1 = nn.utils.spectral_norm(self.c1)
            self.c2 = nn.utils.spectral_norm(self.c2)
            self.c_sc = nn.utils.spectral_norm(self.c_sc)

    def residual(self, x):
        h = x
        h = self.c1(h)
        h = self.activation(h)
        h = self.c2(h)
        h = _downsample(h, self.avgpool_kernel)
        return h

    def shortcut(self, x):
        return self.c_sc(_downsample(x, self.avgpool_kernel))

    def forward(self, x):
        return self.residual(x) + self.shortcut(x)
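

# --- Hypothetical sanity check, not part of the original gist ---
# A minimal sketch, assuming a 5D input of shape [batch, in_channels, il, cl, z].
# With the defaults ksize=(1, 1, 3) and avgpool_kernel=(1, 1, 2), DisBlock0 keeps
# il and cl unchanged and halves the last (z) dimension. The function name and
# tensor sizes below are illustrative assumptions.
def _check_disblock0():
    block = DisBlock0(in_channels=1, out_channels=64)
    x = torch.randn(2, 1, 1, 1, 40)
    out = block(x)
    assert out.shape == (2, 64, 1, 1, 20)
    return out

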
class DisBlock1(nn.Module):
    """
    DisBlock for the encoder network, with minor tweaks to the activation and pooling layers.

    Args:
        hidden_channels: int, number of channels between the two Conv layers
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels=None,
        ksize=(1, 1, 3),
        avgpool_kernel=(1, 1, 2),
        activation=nn.LeakyReLU(),
        downsample=False,
    ):
        super(DisBlock1, self).__init__()
        self.activation = activation
        self.avgpool_kernel = avgpool_kernel
        self.downsample = downsample
        self.learnable_sc = (in_channels != out_channels) or downsample
        hidden_channels = in_channels if hidden_channels is None else hidden_channels
        pad = [int(ks) // 2 for ks in ksize]
        self.c1 = nn.Conv3d(in_channels, hidden_channels, kernel_size=ksize, padding=pad)
        self.c2 = nn.Conv3d(hidden_channels, out_channels, kernel_size=ksize, padding=pad)
        if d_spectral_norm:
            self.c1 = nn.utils.spectral_norm(self.c1)
            self.c2 = nn.utils.spectral_norm(self.c2)
        if self.learnable_sc:
            self.c_sc = nn.Conv3d(in_channels, out_channels, kernel_size=1, padding=0)
            if d_spectral_norm:
                self.c_sc = nn.utils.spectral_norm(self.c_sc)

    def residual(self, x):
        h = x
        h = self.activation(h)
        h = self.c1(h)
        h = self.activation(h)
        h = self.c2(h)
        if self.downsample:
            h = _downsample(h, self.avgpool_kernel)
        return h

    def shortcut(self, x):
        if self.learnable_sc:
            x = self.c_sc(x)
            if self.downsample:
                return _downsample(x, self.avgpool_kernel)
            else:
                return x
        else:
            return x

    def forward(self, x):
        return self.residual(x) + self.shortcut(x)
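

# --- Hypothetical sanity check, not part of the original gist ---
# A minimal sketch showing that DisBlock1 only shrinks the last (z) dimension
# when downsample=True, and preserves the input shape otherwise. The function
# name and tensor sizes below are illustrative assumptions.
def _check_disblock1():
    x = torch.randn(2, 64, 1, 1, 20)
    down_block = DisBlock1(64, 64, downsample=True)
    keep_block = DisBlock1(64, 64, downsample=False)
    assert down_block(x).shape == (2, 64, 1, 1, 10)
    assert keep_block(x).shape == (2, 64, 1, 1, 20)

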
class Encoder(nn.Module):
    """
    Encoder network.
    Input: data samples
    Output: latent (mu and logvar)

    Args:
        ic: int, number of input channels
        zdim: int, length of the latent vector
        enc_ch: int, number of channels for the Conv layers
        ksize: tuple, kernel size for the Conv layers
        avgpool_kernel: tuple, kernel size for the Avg pooling layers
    """

    def __init__(
        self,
        ic: int,
        zdim: int,
        enc_ch: int = 64,
        ksize: tuple[int, ...] = (1, 1, 3),
        avgpool_kernel: tuple[int, ...] = (1, 1, 2),
        activation: nn.Module = nn.LeakyReLU(),
        downsample: Optional[list[bool]] = [True, False, False],
        num_separate_unconstr_feats: Optional[int] = 0,
    ):
        super(Encoder, self).__init__()
        if downsample is not None:
            assert len(downsample) == 3
        else:
            downsample = [True, False, False]
        assert zdim > 0, f"zdim must be positive, got {zdim}!"
        assert isinstance(zdim, int), f"zdim must be an integer, got {zdim}!"
        self.zdim = zdim
        self.ch = enc_ch
        self.activation = activation
        # NOTE: DisBlock0 will always downsample with its AvgPooling kernel
        self.block1 = DisBlock0(ic, self.ch, ksize=ksize, avgpool_kernel=avgpool_kernel)
        self.block2 = DisBlock1(
            self.ch,
            self.ch,
            ksize=ksize,
            avgpool_kernel=avgpool_kernel,
            activation=activation,
            downsample=downsample[0],
        )
        self.block3 = DisBlock1(
            self.ch,
            self.ch,
            ksize=ksize,
            avgpool_kernel=avgpool_kernel,
            activation=activation,
            downsample=downsample[1],  # was True
        )
        # NOTE: assigning here breaks compat with previously saved weights (just load with strict=False)
        # self.first_block = nn.Sequential(*[self.block1, self.block2, self.block3])
        self.block4 = DisBlock1(
            self.ch,
            self.ch,
            ksize=ksize,
            avgpool_kernel=avgpool_kernel,
            activation=activation,
            downsample=downsample[2],  # was True
        )
        # we may want a separate linear layer to specially predict unconstrained (scale) feats
        if num_separate_unconstr_feats is None:
            # if not specified, set to 0
            num_separate_unconstr_feats = 0
        assert isinstance(num_separate_unconstr_feats, int)
        self.num_sep_unconstr_feats = num_separate_unconstr_feats
        # final linear layers after the convolution layers
        # l5 predicts the latent mu, l6 predicts the latent logvar
        self.l5 = nn.Linear(self.ch, self.zdim, bias=False)
        self.l6 = nn.Linear(self.ch, self.zdim, bias=False)
        self.l5 = nn.utils.spectral_norm(self.l5)
        self.l6 = nn.utils.spectral_norm(self.l6)
        # alternative behavior: separate linear layers to specially predict unconstrained (scale) feats
        if self.num_sep_unconstr_feats > 0:
            self.l5_unconstr = nn.Linear(self.ch, self.num_sep_unconstr_feats, bias=False)
            self.l6_unconstr = nn.Linear(self.ch, self.num_sep_unconstr_feats, bias=False)
            self.l5_unconstr = nn.utils.spectral_norm(self.l5_unconstr)
            self.l6_unconstr = nn.utils.spectral_norm(self.l6_unconstr)

    def forward(self, x: torch.Tensor, original_batch_size: int = 0, trt_compat_mode: bool = False):
        h = x
        h = self.block1(h)
        h = self.block2(h)
        h = self.block3(h)
        h = self.block4(h)
        h = self.activation(h)
        # ^ shape of h: [bsz, enc_ch, il, cl, z_downsampled], e.g. if 4x downsampled, then 12 --> 3
        # v global sum pooling
        h = h.sum(2).sum(2).sum(2)  # repeatedly sum over il, cl, z_downsampled
        # ^ shape of h: [bsz, enc_ch]; NB: for sparse models, bsz is actually bsz * num_latents with zdim = 1
        # for TensorRT, we handle invariant vs. unconstrained features in the outer model's forward()
        if trt_compat_mode:
            return h
        # alternative behavior, NOTE: NOT compatible with TensorRT
        # we want to predict unconstrained feats separately from invariant feats
        # first, we need to split the batch into 2 parts
        # our batch consists of: part 1) original_batch_size * num_inv_feats + part 2) original_batch_size
        h_inv, h_unconstr = h[:-original_batch_size], h[-original_batch_size:]
        mu_inv, logvar_inv = self.l5(h_inv), self.l6(h_inv)
        # ^ shape of mu_inv is: [orig_bsz * num_inv_feats, 1]
        mu_unconstr, logvar_unconstr = self.l5_unconstr(h_unconstr), self.l6_unconstr(h_unconstr)
        # ^ shape of mu_unconstr is: [orig_bsz, num_sep_unconstr_feats]
        # reshape from [orig_bsz * num_inv_feats, 1] --> [orig_bsz, num_inv_feats]
        mu_inv = mu_inv.view(original_batch_size, -1)
        logvar_inv = logvar_inv.view(original_batch_size, -1)
        # finally, combine the 2 latent vectors --> [orig_bsz, num_inv_feats + num_sep_unconstr_feats]
        mu = torch.cat((mu_inv, mu_unconstr), dim=1)
        logvar = torch.cat((logvar_inv, logvar_unconstr), dim=1)
        return mu, logvar
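

# --- Hypothetical sanity check, not part of the original gist ---
# A minimal sketch of the Encoder's TensorRT-compatible path: with
# trt_compat_mode=True, forward() returns only the globally pooled features of
# shape [bsz, enc_ch], and the caller applies l5/l6 (or the *_unconstr layers)
# itself, as FinalEncoder.forward() does below. The function name and sizes
# are illustrative assumptions.
def _check_encoder_trt_mode():
    enc = Encoder(ic=1, zdim=1, enc_ch=64, num_separate_unconstr_feats=1)
    x = torch.randn(2, 1, 1, 1, 40)
    h = enc(x, trt_compat_mode=True)
    assert h.shape == (2, 64)
    mu, logvar = enc.l5(h), enc.l6(h)
    assert mu.shape == (2, 1) and logvar.shape == (2, 1)

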
class FinalEncoder(nn.Module):
    def __init__(
        self,
        zdim: int = 4,
        input_dim: tuple[int, ...] = (1, 1, 1, 40),
        start: int = 0,
        end: int = 40,
        enc_ch: int = 64,
        ksize: tuple[int, ...] = (1, 1, 3),
        avgpool_kernel: tuple[int, ...] = (1, 1, 2),
        activation: nn.Module = nn.LeakyReLU(),
        row_normalize: bool = False,
        downsample: list[bool] = [True, False, False],
        num_inv_feats: int = 3,
        trt_compat_mode: bool = True,
    ):
        super(FinalEncoder, self).__init__()
        self.trt_compat_mode = trt_compat_mode
        self.zdim = zdim
        self.ic = input_dim[-2]
        self.input_dim = input_dim[-1]
        # valid angles and padding
        # sample_net_input_end[-1] - sample_net_input_orig[-1]
        self.w_dim = end - start
        self.start = start
        self.end = end
        self.padding = nn.ConstantPad1d((start, self.input_dim - end), 0)
        # this particular architecture only makes sense for invariant learning
        # for non-invariant learning, we can just use the original SpikeEncoder.
        assert num_inv_feats > 0, f"num_inv_feats must be positive, got {num_inv_feats}!"
        assert isinstance(num_inv_feats, int)
        self.num_inv_feats = num_inv_feats
        self.num_unconstr_feats = self.zdim - self.num_inv_feats
        self.inv_end_idx: int = self.zdim if self.num_inv_feats == 0 else self.num_inv_feats
        # zdim is still fixed to 1; this is for invariant features
        # however, for unconstrained features, zdim = self.num_unconstr_feats
        # because we set num_separate_unconstr_feats
        self.encoder = Encoder(
            ic=input_dim[0],
            zdim=1,
            enc_ch=enc_ch,
            ksize=ksize,
            avgpool_kernel=avgpool_kernel,
            activation=activation,
            downsample=downsample,
            num_separate_unconstr_feats=self.num_unconstr_feats,
        )
        # sparsity non-learnable parameters
        self.row_normalize = row_normalize
        self.lambda0 = 10
        self.lambda1 = 0.1
        self.a = 1.0
        self.b = zdim
        # sparsity learnable parameters
        # Ws can be updated via backprop, not only by the sparse loss but also by all other losses
        self.Ws = nn.Parameter(torch.randn((self.num_inv_feats, self.w_dim), dtype=torch.float))
        # p_star & thetas are updated manually, not via backprop
        self.p_star = nn.Parameter(
            (0.5 * torch.ones(self.num_inv_feats, self.w_dim, dtype=torch.float)),
            requires_grad=False,
        )
        self.thetas = nn.Parameter(torch.rand(self.w_dim, dtype=torch.float), requires_grad=False)

    def forward(self, x: torch.Tensor):
        bs = x.shape[0]
        # add padding left and right to the learnt matrix Ws
        # expands the last dim from self.w_dim to self.input_dim
        mask = self.padding(self.get_encoder_mask())
        # "mask" the input tensor x using the learnt matrix Ws (zero-padded)
        # mul here is not matrix-mul, but elementwise mul with broadcast
        # expands to output shape: [bs, num_inv_feats, ic, input_dim]
        masked_input = torch.mul(
            x.view(bs, 1, self.ic, self.input_dim),
            mask.view(1, self.num_inv_feats, 1, self.input_dim),
        )
        # TensorRT does not allow bsz (1st dim) to change,
        # so we split masked_input into num_inv_feats groups (of size bsz each) and run forward on each group
        if self.trt_compat_mode:
            inv_mus = []
            inv_logvars = []
            for i in range(self.num_inv_feats):
                masked_input = masked_input.clone()
                curr_input = torch.index_select(masked_input, dim=1, index=torch.tensor([i], device=masked_input.device))  # dtype=torch.long
                curr_input = curr_input.view(bs, 1, 1, self.ic, self.input_dim)  # .contiguous()
                h = self.encoder(curr_input, trt_compat_mode=True)
                inv_mu, inv_logvar = self.encoder.l5(h), self.encoder.l6(h)
                inv_mus.append(inv_mu)
                inv_logvars.append(inv_logvar)
            inv_mus = torch.cat(inv_mus, dim=1)  # [bsz, 1] * num_inv_feats --> [bsz, num_inv_feats]
            inv_logvars = torch.cat(inv_logvars, dim=1)
            # for unconstrained features, we don't mask inputs
            unmasked_input = x.view(bs, 1, 1, self.ic, self.input_dim) / self.w_dim
            h = self.encoder(unmasked_input, trt_compat_mode=True)
            unconstr_mu, unconstr_logvar = self.encoder.l5_unconstr(h), self.encoder.l6_unconstr(h)
            # combine all features
            mu = torch.cat((inv_mus, unconstr_mu), dim=1)  # output shape: [bsz, num_latents]
            logvar = torch.cat((inv_logvars, unconstr_logvar), dim=1)
            return mu, logvar
        # for each original sample, there are num_inv_feats exact copies of it
        masked_input = masked_input.view(bs * self.num_inv_feats, 1, 1, self.ic, self.input_dim)
        unmasked_input = x.view(bs, 1, 1, self.ic, self.input_dim) / self.w_dim
        # combine both inputs along the batch dim --> [(bsz * self.num_inv_feats) + bsz, 1, 1, ic, input_dim]
        combined_input = torch.cat((masked_input, unmasked_input), dim=0)
        mu, logvar = self.encoder(combined_input, original_batch_size=bs, trt_compat_mode=False)
        return mu, logvar

    def get_encoder_mask(self):
        return F.normalize(self.Ws.abs(), p=1, dim=-1)
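

# --- Hypothetical usage / export sketch, not part of the original gist ---
# A minimal sketch of running FinalEncoder on a dummy batch and (optionally)
# exporting it to ONNX as a first step toward a TensorRT engine, e.g. via
# `trtexec --onnx=final_encoder.onnx`. The dummy input size, file name, and
# opset version are assumptions; the original gist does not show its export
# code, so the author's actual setup may differ.
if __name__ == "__main__":
    model = FinalEncoder(zdim=4, input_dim=(1, 1, 1, 40), start=0, end=40, num_inv_feats=3)
    model.eval()
    dummy = torch.randn(2, 1, 1, 40)  # [bsz, 1, ic, input_dim]
    with torch.no_grad():
        mu, logvar = model(dummy)
    print(mu.shape, logvar.shape)  # expected: torch.Size([2, 4]) each
    # torch.onnx.export(
    #     model, dummy, "final_encoder.onnx", opset_version=17,
    #     input_names=["x"], output_names=["mu", "logvar"],
    # )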