@KeAWang
Last active October 23, 2023 23:52
Temporal Convolutional Network in PyTorch (https://arxiv.org/abs/1803.01271)
import torch
from typing import List


def receptive_field(kernel_size: int, dilation: int):
    return 1 + (kernel_size - 1) * dilation


class Seq2SeqConv1d(torch.nn.Module):
    """Pads the input so that the conv output has the same length as the input,
    i.e. N x Cin x T -> N x Cout x T.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        dilation: int = 1,
        groups: int = 1,
        causal: bool = False,
    ):
        super().__init__()
        self.receptive_field = receptive_field(kernel_size, dilation)
        if causal:
            # Causal: pad both sides by (kernel_size - 1) * dilation and chomp the
            # same amount from the right, so out[..., t] only sees x[..., : t + 1].
            padding = self.receptive_field - 1
            self.num_chomp = padding
        else:
            padding = self.receptive_field // 2  # each side
            # output_size = input_size + 2 * padding - (receptive_field - 1), so an
            # even receptive field gives output_size = input_size + 1 and we truncate.
            self.num_chomp = 1 if self.receptive_field % 2 == 0 else 0
        self.conv = torch.nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )

    def forward(self, x):
        out = self.conv(x)
        if self.num_chomp > 0:
            out = out[..., : -self.num_chomp]
        return out
class Seq2SeqConv1dBlock(torch.nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        dilation: int,
        causal: bool = False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.activation = torch.nn.GELU()
        self.conv1 = Seq2SeqConv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            causal=causal,
        )
        self.conv2 = Seq2SeqConv1d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            causal=causal,
        )
        self.projector = torch.nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1
        )
        # TODO: should we initialize weights to N(0, 0.01) like in TCN?
        # TODO: do we need layer norm?
        # TODO: do we need dropout?

    def forward(self, x):
        res = self.projector(x)
        # pre-activation residual block
        out = self.activation(x)
        out = self.conv1(out)
        out = self.activation(out)
        out = self.conv2(out)
        out = out + res
        return out
class TemporalConvNet(torch.nn.Module):
    """
    Expects N x num_timesteps x input_size.
    Outputs N x hidden_sizes[-1].
    """

    def __init__(
        self,
        num_timesteps: int,
        input_size: int,
        hidden_sizes: List[int],
        kernel_size: int,
        dilation_factor: int = 1,
        causal: bool = True,
    ):
        super().__init__()
        blocks = [
            Seq2SeqConv1dBlock(
                in_channels=hidden_sizes[i - 1] if i > 0 else input_size,
                out_channels=hidden_sizes[i],
                kernel_size=kernel_size,
                dilation=int(dilation_factor ** i),
                causal=causal,
            )
            for i in range(len(hidden_sizes))
        ]
        self.blocks = torch.nn.ModuleList(blocks)
        self.linear_comb = Seq2SeqConv1d(
            in_channels=num_timesteps, out_channels=1, kernel_size=1, causal=False
        )
        proj_size = hidden_sizes[-1]
        self.projection = torch.nn.Sequential(
            torch.nn.Linear(proj_size, proj_size),
            torch.nn.GELU(),
            torch.nn.Linear(proj_size, proj_size),
        )
        self.num_timesteps = num_timesteps
        self.input_size = input_size
        self.hidden_sizes = hidden_sizes

    def forward(self, x):
        assert x.shape[1:] == (self.num_timesteps, self.input_size)
        x = x.transpose(1, 2)  # N x T x D -> N x D x T
        for block in self.blocks:
            x = block(x)
        x = x.transpose(1, 2)  # N x D x T -> N x T x D
        x = self.linear_comb(x)  # N x T x D -> N x 1 x D
        x = x.squeeze(1)  # N x 1 x D -> N x D
        assert x.shape[1:] == (self.hidden_sizes[-1],)
        return x
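
A minimal usage sketch of the module above; the batch size, sequence length, and hyperparameters here are illustrative assumptions, not values from the gist:

import torch

# hypothetical setup: batches of 32 timesteps with 8-dimensional features
tcn = TemporalConvNet(
    num_timesteps=32,
    input_size=8,
    hidden_sizes=[16, 32, 64],
    kernel_size=3,
    dilation_factor=2,
    causal=True,
)
x = torch.randn(4, 32, 8)  # N x T x D
features = tcn(x)          # N x hidden_sizes[-1], i.e. 4 x 64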
KeAWang commented Jan 27, 2022

Some notes on generalizing causal convs: pytorch/pytorch#1333
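
For reference, an equivalent way to get a causal convolution without a chomp is to pad only the left side of the time axis before the conv. A minimal sketch (the class name is hypothetical, not part of the gist):

import torch
import torch.nn.functional as F

class LeftPadCausalConv1d(torch.nn.Module):
    """Pad only the left side by (kernel_size - 1) * dilation, so no chomp is needed."""
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1):
        super().__init__()
        self.left_pad = (kernel_size - 1) * dilation
        self.conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation)

    def forward(self, x):
        x = F.pad(x, (self.left_pad, 0))  # (left, right) padding of the time dimension
        return self.conv(x)  # output length equals input length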

KeAWang commented Oct 23, 2023

Note to self: there was a bug in the causal mode from using the wrong padding; a causal convolution needs padding of (kernel_size - 1) * dilation chomped from the right, not receptive_field // 2.
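
A quick way to sanity-check the causal mode is to perturb only the future part of the input and confirm that earlier outputs do not change. A minimal sketch, assuming the Seq2SeqConv1d defined above:

import torch

conv = Seq2SeqConv1d(in_channels=3, out_channels=4, kernel_size=3, dilation=2, causal=True)
x = torch.randn(1, 3, 20)
x_perturbed = x.clone()
x_perturbed[..., 10:] += 1.0  # change only the "future" part of the sequence

with torch.no_grad():
    y = conv(x)
    y_perturbed = conv(x_perturbed)

# outputs up to t = 9 must be identical if the convolution is causal
assert torch.allclose(y[..., :10], y_perturbed[..., :10])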
