Question about the typing in torch.jit.script
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
class hGRUKernel(nn.Module):
    '''
    initializer: the function used to initialize the conv weights
    gain: xavier gain
    '''
    def __init__(self, channel_size, time_step,
                 non_linear=F.relu,
                 kernel_size=7, padding=3,
                 initializer=init.xavier_normal_,
                 gain=1):
        super(hGRUKernel, self).__init__()
        g_size = (1, channel_size, 1, 1)
        c_size = channel_size
        self.initializer = initializer
        ## non-linearity
        self.activ = non_linear
        ## 1x1 convs for the two gates
        self.u1_conv = nn.Conv2d(c_size, c_size, 1, stride=1, padding=0,
                                 bias=False)
        self.u2_conv = nn.Conv2d(c_size, c_size, 1, stride=1, padding=0,
                                 bias=False)
        ## the shared horizontal kernel W
        self.w_conv = nn.Conv2d(c_size, c_size, kernel_size,
                                stride=1, padding=padding,
                                bias=False)
        ## alpha and miu (pass the gain argument through instead of
        ## hardcoding gain=1, so the constructor parameter takes effect)
        self.alpha = torch.nn.Parameter(
            initializer(torch.empty(g_size), gain=gain)
        )
        ## start miu at zeros rather than the uninitialized memory
        ## torch.empty would leave behind
        self.miu = torch.nn.Parameter(torch.zeros(g_size))
        ## one BatchNorm per use per time step (4 uses per step)
        self.BN_list = nn.ModuleList(
            [nn.BatchNorm2d(c_size) for _ in range(4 * time_step)]
        )
        ## gains
        self.tau = torch.nn.Parameter(
            initializer(torch.empty(g_size), gain=gain)
        )
        self.beta = torch.nn.Parameter(
            initializer(torch.empty(g_size), gain=gain)
        )
        self.gamma = torch.nn.Parameter(
            initializer(torch.empty(g_size), gain=gain)
        )
        ## initialize the conv weights
        self._initialize()

    def _initialize(self):
        self.initializer(self.u1_conv.weight)
        self.initializer(self.u2_conv.weight)
        self.initializer(self.w_conv.weight)

    def forward(self, forward_feat: torch.Tensor,
                H_t_minus: torch.Tensor, t: int):
        ## first gate and candidate, computed from the previous hidden state
        G_t_1 = torch.sigmoid(self.BN_list[t * 4](self.u1_conv(H_t_minus)))
        C_t_1 = self.BN_list[t * 4 + 1](self.w_conv(G_t_1 * H_t_minus))
        H_t_1 = self.activ(
            forward_feat - (C_t_1 *
                            (self.alpha * H_t_minus - self.miu))
        )
        ## second gate and candidate, computed from the intermediate state
        G_t_2 = torch.sigmoid(
            self.BN_list[t * 4 + 2](self.u2_conv(H_t_1))
        )
        C_t_2 = self.BN_list[t * 4 + 3](self.w_conv(H_t_1))
        H_t_2_prime = self.activ(
            self.tau * H_t_1 + self.beta * C_t_2 +
            self.gamma * H_t_1 * C_t_2
        )
        ## gated blend of the previous hidden state and the new candidate
        output_hidden = H_t_minus * (1 - G_t_2) + G_t_2 * H_t_2_prime
        return output_hidden
### Should the input hidden state be sampled within some range?
class hGRUModule(torch.jit.ScriptModule):
    __constants__ = ['_t_n']

    def __init__(self, hGRU_kernel, time_step):
        super(hGRUModule, self).__init__()
        self._t_n = time_step
        self.hGRU_k = hGRU_kernel

    @torch.jit.script_method
    def forward(self, forward_feat, hidden):
        ## unroll the kernel for a fixed number of time steps
        for t in range(self._t_n):
            hidden = self.hGRU_k(forward_feat, hidden, t)
        return hidden
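
For context, here is a minimal sketch of how the two classes might be exercised together. The channel count, unroll length, input shape, and zero-initialized hidden state are illustrative assumptions, not part of the gist, and whether the plain nn.Module kernel compiles cleanly when invoked from the script_method is exactly the kind of torch.jit typing question the gist title raises.

## Illustrative values, not from the gist: 8-channel 32x32 feature
## maps in a batch of 2, unrolled for 6 time steps.
channel_size, time_step = 8, 6
kernel = hGRUKernel(channel_size, time_step)
model = hGRUModule(kernel, time_step)

feat = torch.randn(2, channel_size, 32, 32)
## The comment above asks whether the initial hidden state should be
## sampled within some range; zeros are used here as a simple
## placeholder choice.
hidden = torch.zeros_like(feat)

out = model(feat, hidden)
print(out.shape)  # expected: torch.Size([2, 8, 32, 32])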