Last active May 8, 2023 08:10
InceptionTimePlus from tsai modified to be causal
Modified from tsai's `InceptionTimePlus` (`tsai.models.InceptionTimePlus`).
I've added my favourite modifications:
- causal padding in the convolutions
- causal padding in the max pool
- coord=True by default
- dilation, so it can work over longer context lengths
- default kernel sizes of [3, 13, 39] instead of the default [13, 19, 39] or whatever. This ensures there is also a small kernel, which can help performance.
Note that this uses dilation, as in WaveNet. For a TCN with dilation base b, kernel size k (with k ≥ b) and n residual blocks, the total receptive field r is
$r = 1 + 2(k-1)\frac{b^n - 1}{b - 1}$
This is only a guide: `calc_receptive_field` below computes the receptive field for the dilation schedule this code actually uses.
from tsai.models.InceptionTimePlus import Conv, Module, noop, Integral, nn, is_listy, SimpleSelfAttention, Concat, SqueezeExciteBlock, Norm, BN1d, delegates, ConvBlock, Add, np, random, ifnone, OrderedDict, Flatten, SigmoidRange, LinBnDrop, GACP1d, GAP1d, named_partial, F, torch, CausalConv1d, Noop
# Causal convolution helpers built on tsai's ConvBlock via named_partial.
# `Conv` rebinds the `Conv` name imported from tsai above: a ConvBlock with
# causal padding and norm=None, act=None.
Conv = named_partial('Conv', ConvBlock,
                     padding='causal',
                     norm=None,
                     act=None)
# `CausalConvBlock` only fixes the padding mode; norm/act stay at whatever
# ConvBlock's own defaults are.
CausalConvBlock = named_partial('CausalConv', ConvBlock, padding='causal')
class CausalMaxPool1d(torch.nn.MaxPool1d):
    """1-D max pooling that only looks at the current and earlier time steps.

    The last axis of the input is left-padded by ``(ks - 1) * dilation``
    before pooling, so with ``stride=1`` the output has the same length as
    the input and output step ``t`` only depends on inputs at steps ``<= t``.

    Args:
        ks: pooling window size.
        stride: pooling stride.
        padding: accepted for signature compatibility with ``nn.MaxPool1d``
            but ignored; causal left padding is always used instead.
        dilation: spacing between the elements of the pooling window.
    """

    def __init__(self, ks, stride=1, padding=0, dilation=1):
        super().__init__(kernel_size=ks, stride=stride, padding=0, dilation=dilation)
        self.__padding = (ks - 1) * dilation

    def forward(self, input):
        # Pad with -inf rather than F.pad's default of 0: a zero pad would win
        # the max over all-negative windows at the start of the sequence and
        # emit a value that never appeared in the input.
        padded = F.pad(input, (self.__padding, 0), value=float('-inf'))
        return super().forward(padded)
class CausalInceptionModulePlus(Module):
    """Causal inception module (adapted from tsai's InceptionTimePlus).

    Pipeline: optional 1x1 bottleneck -> one ``Conv`` per kernel size in
    ``ks`` (branch ``i`` uses dilation ``dilation**i``), plus a parallel
    ``CausalMaxPool1d`` -> 1x1 ``Conv`` branch on the raw input. The branch
    outputs are joined with ``Concat`` (presumably channel-wise, giving
    ``nf * 4`` channels for three kernel sizes -- TODO confirm), then passed
    through norm, dropout, optional self-attention, activation and optional
    squeeze-excite.

    Args:
        ni: number of input channels.
        nf: number of filters per branch.
        ks: three kernel sizes, or an int ``k`` expanded to
            ``[k, k // 2, k // 4]``. Unless ``ks`` is already a length-3
            sequence, each size is rounded down to an odd number.
        bottleneck: use the 1x1 bottleneck (forced off when ``ni == nf``).
        padding: padding mode for the kernel-size branches.
        coord, separable, stride: forwarded to ``Conv``.
        dilation: base dilation; values below 1 are clamped to 1.
        conv_dropout: dropout probability after the norm (0 disables it).
        sa: add ``SimpleSelfAttention`` after dropout.
        se: squeeze-excite reduction (falsy disables it).
        norm: norm type for ``Norm``; ``None`` disables normalisation.
        zero_norm: forwarded to ``Norm``.
        bn_1st: unused; kept for signature compatibility.
        act: activation class, or ``None`` for no activation.
        act_kwargs: keyword arguments for ``act`` (``None`` means none).
    """

    def __init__(self, ni, nf, ks=(3, 13, 39), bottleneck=True, padding='causal', coord=True, separable=False, dilation=1,
                 stride=1, conv_dropout=0., sa=False, se=None, norm='Batch', zero_norm=False, bn_1st=True, act=nn.ReLU,
                 act_kwargs=None):
        act_kwargs = {} if act_kwargs is None else act_kwargs
        dilation = max(1, dilation)
        if not (is_listy(ks) and len(ks) == 3):
            if isinstance(ks, Integral): ks = [ks // (2**i) for i in range(3)]
            ks = [ksi if ksi % 2 != 0 else ksi - 1 for ksi in ks]  # keep kernel sizes odd
        bottleneck = False if ni == nf else bottleneck
        self.bottleneck = Conv(ni, nf, 1, coord=coord, bias=False) if bottleneck else noop
        branch_in = nf if bottleneck else ni
        self.convs = nn.ModuleList([
            Conv(branch_in, nf, k, padding=padding, coord=coord, separable=separable,
                 dilation=dilation**i, stride=stride, bias=False)
            for i, k in enumerate(ks)])
        # Pool causally so this branch, like the conv branches, never sees future steps.
        self.mp_conv = nn.Sequential(CausalMaxPool1d(3, stride=1), Conv(ni, nf, 1, coord=coord, bias=False))
        self.concat = Concat()
        self.norm = Norm(nf * 4, norm=norm, zero_norm=zero_norm) if norm is not None else noop
        self.conv_dropout = nn.Dropout(conv_dropout) if conv_dropout else noop
        self.sa = SimpleSelfAttention(nf * 4) if sa else noop
        self.act = act(**act_kwargs) if act else noop
        self.se = nn.Sequential(SqueezeExciteBlock(nf * 4, reduction=se), BN1d(nf * 4)) if se else noop

    def _init_cnn(self, m):
        """Recursively zero biases and Kaiming-init weights of conv/linear layers under ``m``.

        Not invoked by ``__init__``; call it explicitly if this init is wanted.
        """
        # Inspect the visited module ``m`` (not ``self``); otherwise the
        # recursion walks the children without ever initialising them.
        if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
            if getattr(m, 'bias', None) is not None: nn.init.constant_(m.bias, 0)
            nn.init.kaiming_normal_(m.weight)
        for child in m.children(): self._init_cnn(child)

    def forward(self, x):
        input_tensor = x
        x = self.bottleneck(x)
        x = self.concat([conv(x) for conv in self.convs] + [self.mp_conv(input_tensor)])
        x = self.norm(x)
        x = self.conv_dropout(x)
        x = self.sa(x)
        x = self.act(x)
        x = self.se(x)
        return x
class CausalInceptionBlockPlus(Module):
    """Stack of ``depth`` causal inception modules with residual shortcuts.

    Module ``d`` is built with base dilation ``max(1, d * dilation)``. Every
    third module (``d % 3 == 2``) closes a residual group: it gets no
    activation/SE, but gets optional self-attention and ``zero_norm``. If
    ``residual`` is set, its output is added to a shortcut of the group's
    input and passed through an activation. With ``norm=None`` (the default
    here) the modules skip normalisation and equal-width shortcuts are
    identities.

    Stochastic depth: while training, module ``d`` runs only if
    ``keep_prob[d] > random.random()``, where ``keep_prob`` is linearly spaced
    from 1 to ``stoch_depth`` (all ones when ``stoch_depth == 0``). In eval
    mode every module runs.

    Extra ``**kwargs`` (e.g. ``ks``) are forwarded to each module.
    """

    def __init__(self, ni, nf, residual=True, depth=6, coord=False, norm=None, zero_norm=False, act=nn.ReLU,
                 act_kwargs=None, sa=False, se=None, dilation=1, stoch_depth=1., **kwargs):
        act_kwargs = {} if act_kwargs is None else act_kwargs
        self.residual, self.depth = residual, depth
        self.inception, self.shortcut, self.act = nn.ModuleList(), nn.ModuleList(), nn.ModuleList()
        for d in range(depth):
            ends_group = d % 3 == 2
            self.inception.append(CausalInceptionModulePlus(
                ni if d == 0 else nf * 4, nf, coord=coord, norm=norm,
                zero_norm=zero_norm if ends_group else False,
                act=None if ends_group else act, act_kwargs=act_kwargs,
                sa=sa if ends_group else False,
                se=None if ends_group else se,
                dilation=max(1, d * dilation), **kwargs))
            if self.residual and ends_group:
                n_in, n_out = ni if d == 2 else nf * 4, nf * 4
                if n_in == n_out:
                    # A fresh module instance; ModuleList only accepts nn.Module objects.
                    shortcut = Norm(n_in, norm=norm) if norm is not None else nn.Identity()
                else:
                    # 1x1 conv to match channel counts (causal padding comes from the partial).
                    shortcut = CausalConvBlock(n_in, n_out, 1, coord=coord, bias=False, norm=norm, act=None)
                self.shortcut.append(shortcut)
                # One post-residual activation per group; forward indexes it with i // 3.
                self.act.append(act(**act_kwargs) if act else nn.Identity())
        self.add = Add()
        if stoch_depth != 0: keep_prob = np.linspace(1, stoch_depth, depth)
        else: keep_prob = np.array([1] * depth)
        self.keep_prob = keep_prob

    def forward(self, x):
        res = x
        for i in range(self.depth):
            # random.random() is drawn even in eval mode, preserving the RNG stream.
            if self.keep_prob[i] > random.random() or not self.training:
                x = self.inception[i](x)
            if self.residual and i % 3 == 2:
                res = x = self.act[i // 3](self.add(x, self.shortcut[i // 3](res)))
        return x
# Cell
class CausalInceptionTimePlus(nn.Sequential):
    """Causal InceptionTimePlus: a ``CausalInceptionBlockPlus`` backbone followed by a head.

    Args:
        c_in: number of input channels.
        c_out: number of outputs of the head.
        seq_len: sequence length (needed by ``flatten`` heads and ``custom_head``).
        nf: filters per inception branch; the backbone emits ``nf * 4`` channels.
        nb_filters: alias for ``nf``; takes precedence when given.
        flatten, concat_pool, fc_dropout, bn, y_range: forwarded to ``create_head``.
        custom_head: callable ``(head_nf, c_out, seq_len) -> module`` replacing the default head.
        **kwargs: forwarded to ``CausalInceptionBlockPlus`` (e.g. ``depth``,
            ``dilation``, ``ks``); also used to report the receptive field.

    After construction, ``self.receptive_field`` holds the array returned by
    ``calc_receptive_field`` (which also prints it).
    """

    def __init__(self, c_in, c_out, seq_len=None, nf=32, nb_filters=None,
                 flatten=False, concat_pool=False, fc_dropout=0., bn=False, y_range=None, custom_head=None, **kwargs):
        if nb_filters is not None: nf = nb_filters
        else: nf = ifnone(nf, nb_filters)  # for compatibility
        backbone = CausalInceptionBlockPlus(c_in, nf, **kwargs)
        self.head_nf = nf * 4
        self.c_out = c_out
        self.seq_len = seq_len
        if custom_head: head = custom_head(self.head_nf, c_out, seq_len)
        else: head = self.create_head(self.head_nf, c_out, seq_len, flatten=flatten, concat_pool=concat_pool,
                                      fc_dropout=fc_dropout, bn=bn, y_range=y_range)
        layers = OrderedDict([('backbone', nn.Sequential(backbone)), ('head', nn.Sequential(head))])
        super().__init__(layers)
        # Defaults mirror those of CausalInceptionModulePlus (ks) and CausalInceptionBlockPlus (depth, dilation).
        self.receptive_field = self.calc_receptive_field(kwargs.get('ks', (3, 13, 39)), kwargs.get('depth', 6),
                                                         kwargs.get('dilation', 1))

    @staticmethod
    def _normalize_ks(ks):
        """Expand/round kernel sizes the same way CausalInceptionModulePlus does for ints, lists and tuples."""
        if isinstance(ks, Integral):
            ks = [ks // (2**i) for i in range(3)]
        else:
            ks = list(ks)
            if len(ks) == 3:
                return ks
        return [k if k % 2 != 0 else k - 1 for k in ks]

    def calc_receptive_field(self, ks=(3, 13, 39), depth=6, dilation=1):
        """Print and return the receptive field of each kernel-size branch.

        Backbone module ``d`` uses base dilation ``max(1, d * dilation)`` and
        its branch ``i`` dilates by that base to the power ``i``. For stride-1
        convolutions, the receptive field of a stack is
        ``1 + sum over layers of (k - 1) * layer_dilation``. Entry ``i`` of the
        returned array is that value if every module used branch ``i``. With
        ascending kernel sizes, the last entry is the backbone's receptive field.
        """
        ks = np.asarray(self._normalize_ks(ks))
        layer_dilations = np.array([max(1, d * dilation) for d in range(depth)])
        # Shape (depth, n_branches): the dilation each branch uses in each module.
        branch_dilations = layer_dilations[:, None] ** np.arange(len(ks))
        rf = 1 + ((ks - 1) * branch_dilations).sum(0)
        print(f"receptive field {rf}={ks-1}*{branch_dilations}")
        return rf

    def create_head(self, nf, c_out, seq_len, flatten=False, concat_pool=False, fc_dropout=0., bn=False, y_range=None):
        """Build the head from ``nf`` backbone channels.

        Either flattens the features (``nf * seq_len`` inputs) or pools them
        with ``GACP1d`` (``concat_pool=True``, doubling the feature count) or
        ``GAP1d``. It then applies ``LinBnDrop``, plus ``SigmoidRange`` when
        ``y_range`` is given.
        """
        if flatten:
            nf *= seq_len
            layers = [Flatten()]
        else:
            if concat_pool: nf *= 2
            layers = [GACP1d(1) if concat_pool else GAP1d(1)]
        layers += [LinBnDrop(nf, c_out, bn=bn, p=fc_dropout)]
        if y_range: layers += [SigmoidRange(*y_range)]
        return nn.Sequential(*layers)
