Created November 20, 2017 05:23
-
-
Save eickenberg/e341e5a19b857675aa9f0ffaea37d7bf to your computer and use it in GitHub Desktop.
pfs beginning stub
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math

import torch
from torch import nn
from torch.nn import Parameter, functional as F
class Convolution2DEnergyTemporalBasis(nn.Module):
    """Spatial convolution followed by a parametric temporal basis (work in progress).

    Every input frame is convolved independently with a 2-D spatial
    convolution. Temporal kernels are built from learnable envelopes
    multiplied by cosine and sine carriers at learnable frequencies. The
    temporal convolution itself is still a TODO (see ``forward``).

    Args:
        n_input_channels: Number of channels of each input frame.
        n_filters_simple: Number of "simple" filters; together with
            ``n_filters_complex`` it sets the spatial conv output channels to
            ``2 * (n_filters_simple + n_filters_complex)``.
        n_filters_complex: Number of "complex" filters.
        n_filters_temporal: Stored for the temporal stage; not used yet.
        spatial_kernel_size: Kernel size passed to ``nn.Conv2d``.
        n_temporal_frequencies: Number of learnable carrier frequencies.
        n_temporal_envelopes: Number of learnable temporal envelopes.
        temporal_kernel_size: Number of time samples per temporal kernel.
    """

    def __init__(self, n_input_channels,
                 n_filters_simple,
                 n_filters_complex,
                 n_filters_temporal,
                 spatial_kernel_size,
                 n_temporal_frequencies,
                 n_temporal_envelopes,
                 temporal_kernel_size):
        super(Convolution2DEnergyTemporalBasis, self).__init__()
        self.n_input_channels = n_input_channels
        self.n_filters_simple = n_filters_simple
        self.n_filters_complex = n_filters_complex
        self.n_filters_temporal = n_filters_temporal
        self.spatial_kernel_size = spatial_kernel_size
        self.n_temporal_frequencies = n_temporal_frequencies
        self.n_temporal_envelopes = n_temporal_envelopes
        # The constructor argument is called *_size; internally it is the
        # kernel length (number of time samples).
        self.temporal_kernel_length = temporal_kernel_size
        self.build()

    def build(self):
        """Create the spatial convolution and the temporal-basis parameters."""
        self.spatial_conv = nn.Conv2d(
            self.n_input_channels,
            2 * (self.n_filters_simple + self.n_filters_complex),
            self.spatial_kernel_size)
        # The temporal kernel is the product of each envelope with sinusoids at
        # each frequency (an outer product over envelopes x frequencies).
        self.temporal_conv_kernel_frequencies = Parameter(
            torch.rand(self.n_temporal_frequencies))
        self.temporal_conv_kernel_envelopes = Parameter(
            torch.rand(self.n_temporal_envelopes,
                       self.temporal_kernel_length))
        # Sample times in [0, 1], endpoint included, exactly one per kernel
        # tap. linspace sidesteps the arange endpoint question (arange is
        # end-exclusive in current PyTorch and would yield one sample too few).
        # A buffer is not trained but follows the module across .to()/.cuda().
        self.register_buffer(
            't_range', torch.linspace(0., 1., self.temporal_kernel_length))

    def _temporal_kernel(self):
        """Return the temporal kernels as a ``(2 * le * lf, lt)`` tensor.

        The first ``le * lf`` rows are envelope-weighted cosines and the rest
        the matching sines; within each half, row ``e * lf + f`` combines
        envelope ``e`` with frequency ``f``. Parameters are not modified.
        """
        frequencies = self.temporal_conv_kernel_frequencies
        envelopes = self.temporal_conv_kernel_envelopes
        lf = frequencies.shape[0]
        le, lt = envelopes.shape
        # Phase per (frequency, time sample): (lf, 1) * (1, lt) -> (lf, lt).
        phase = 2 * math.pi * frequencies.view(lf, 1) * self.t_range.view(1, lt)
        # Broadcast envelopes (le, 1, lt) against carriers (1, lf, lt).
        env = envelopes.view(le, 1, lt)
        cosines = env * torch.cos(phase).view(1, lf, lt)
        sines = env * torch.sin(phase).view(1, lf, lt)
        return torch.cat((cosines.view(le * lf, lt),
                          sines.view(le * lf, lt)), dim=0)

    def forward(self, x):
        """Apply the spatial stage and build the temporal kernels (stub).

        Args:
            x: 5-D tensor ``(batch, time_len, n_in_channels, height, width)``.

        Returns:
            ``None`` for now -- the temporal convolution is not implemented yet.
        """
        batch, time_len, n_in_channels, height, width = x.shape
        # conv2d wants 4-D input: fold time into the batch axis. A view (not
        # the in-place resize_) leaves the caller's tensor untouched and keeps
        # the autograd graph intact.
        x_reshaped = x.contiguous().view(batch * time_len, n_in_channels,
                                         height, width)
        x_convolved = self.spatial_conv(x_reshaped)

        temporal_kernel = self._temporal_kernel()

        # With this temporal kernel of shape (2 * lf * le, lt) we can convolve
        # the spatial output. But first reshape it to 3-D by collapsing space.
        batch_time, n_out_channels, out_height, out_width = x_convolved.shape
        x_convolved_no_space = x_convolved.view(
            batch_time, n_out_channels, out_height * out_width)
        # TODO: use F.conv1d with temporal_kernel on x_convolved_no_space
        # to get the temporal aspects.
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.