pfs beginning stub
import numpy as np

import torch
from torch import nn
from torch.nn import Parameter, functional as F

class Convolution2DEnergyTemporalBasis(nn.Module):
    def __init__(self, n_input_channels,
                 n_filters_simple,
                 n_filters_complex,
                 n_filters_temporal,
                 spatial_kernel_size,
                 n_temporal_frequencies,
                 n_temporal_envelopes,
                 temporal_kernel_size):
        super(Convolution2DEnergyTemporalBasis, self).__init__()
        self.n_input_channels = n_input_channels
        self.n_filters_simple = n_filters_simple
        self.n_filters_complex = n_filters_complex
        self.n_filters_temporal = n_filters_temporal
        self.spatial_kernel_size = spatial_kernel_size
        self.n_temporal_frequencies = n_temporal_frequencies
        self.n_temporal_envelopes = n_temporal_envelopes
        self.temporal_kernel_length = temporal_kernel_size
        self.build()

    def build(self):
        self.spatial_conv = nn.Conv2d(self.n_input_channels,
                                      2 * (self.n_filters_simple +
                                           self.n_filters_complex),
                                      self.spatial_kernel_size)
        # temporal kernel is outer product of envelopes and sinusoids with frequencies
        self.temporal_conv_kernel_frequencies = Parameter(
            torch.rand(self.n_temporal_frequencies))
        self.temporal_conv_kernel_envelopes = Parameter(
            torch.rand(self.n_temporal_envelopes,
                       self.temporal_kernel_length))
        # time axis reused by sin and cos at every forward pass; linspace
        # includes both endpoints (avoiding the inclusive/exclusive ambiguity
        # of torch.arange/torch.range), and register_buffer lets it move with
        # the module across devices
        self.register_buffer('t_range',
                             torch.linspace(0., 1., self.temporal_kernel_length))

    def forward(self, x):
        # Input is 5-dimensional (batch, time_len, n_in_channels, height, width)
        # conv2d wants 4D input, so collapse batch, time into batch * time
        batch, time_len, n_in_channels, height, width = x.shape
        x_reshaped = x.reshape(batch * time_len, n_in_channels, height, width)
        x_convolved = self.spatial_conv(x_reshaped)
        # for the temporal convolution we first need to create the kernels
        lf = len(self.temporal_conv_kernel_frequencies)
        lt = len(self.t_range)
        le = len(self.temporal_conv_kernel_envelopes)
        # envelopes modulate the amplitude of the sinusoids, so they multiply
        # outside the cos/sin; broadcasting yields shape (le, lf, lt)
        phases = (2 * np.pi *
                  self.temporal_conv_kernel_frequencies.view(1, lf, 1) *
                  self.t_range.view(1, 1, lt))
        envelopes = self.temporal_conv_kernel_envelopes.view(le, 1, lt)
        cosines = envelopes * torch.cos(phases)
        sines = envelopes * torch.sin(phases)
        temporal_kernel = torch.cat((cosines.view(le * lf, lt),
                                     sines.view(le * lf, lt)), dim=0)
        # With this temporal kernel of shape (2 * le * lf, lt) we can convolve the
        # spatial output. But first we need to reshape it to 3D by collapsing space
        batch_time, n_out_channels, out_height, out_width = x_convolved.shape
        x_convolved_no_space = x_convolved.view(batch_time, n_out_channels,
                                                out_height * out_width)
        # TODO: use F.conv1d to get the temporal aspects
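        # --- Hedged sketch, not part of the original stub: one possible way to
        # --- finish the temporal step with F.conv1d. It assumes every spatial
        # --- output channel at every spatial position is convolved independently
        # --- with each of the 2 * le * lf temporal basis functions, with no
        # --- temporal padding; both choices (and the names x_time_series,
        # --- temporal_weight) are assumptions, not the author's stated design.
        x_time_series = (x_convolved_no_space
                         .view(batch, time_len, n_out_channels,
                               out_height * out_width)
                         .permute(0, 3, 2, 1)
                         .contiguous()
                         .view(batch * out_height * out_width * n_out_channels,
                               1, time_len))
        temporal_weight = temporal_kernel.unsqueeze(1)  # (2 * le * lf, 1, lt)
        x_temporal = F.conv1d(x_time_series, temporal_weight)
        # shape: (batch * H * W * n_out_channels, 2 * le * lf, time_len - lt + 1)
        return x_temporal


# Minimal usage sketch with assumed (hypothetical) parameter values, just to
# exercise the shapes above; none of these numbers come from the original gist.
if __name__ == "__main__":
    module = Convolution2DEnergyTemporalBasis(
        n_input_channels=3, n_filters_simple=4, n_filters_complex=4,
        n_filters_temporal=2, spatial_kernel_size=5,
        n_temporal_frequencies=3, n_temporal_envelopes=2,
        temporal_kernel_size=7)
    video = torch.rand(2, 16, 3, 32, 32)  # (batch, time, channels, height, width)
    out = module(video)
    print(out.shape)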