Skip to content

Instantly share code, notes, and snippets.

@santisy
Created November 26, 2018 21:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save santisy/2014c5ffd1618e0feddf352a456c6951 to your computer and use it in GitHub Desktop.
Save santisy/2014c5ffd1618e0feddf352a456c6951 to your computer and use it in GitHub Desktop.
Question about the typing in torch.jit.script
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
class hGRUKernel(nn.Module):
    """Single hGRU (horizontal GRU) update step.

    Applies a two-stage gated update to a hidden state given a
    feed-forward feature map, using a dedicated set of BatchNorm
    layers for each time step.

    Args:
        channel_size: number of channels C of the feature/hidden maps.
        time_step: number of recurrent steps T; 4*T BatchNorm layers are
            created (four normalized tensors per step).
        non_linear: activation applied to both candidate states.
        kernel_size: spatial size of the shared horizontal kernel W.
        padding: padding for W (preserves spatial size when
            padding == kernel_size // 2, as with the 7/3 defaults).
        initializer: the one used to initialize the weights.
        gain: gain forwarded to ``initializer`` (e.g. xavier gain).
    """

    def __init__(self, channel_size, time_step,
                 non_linear=F.relu,
                 kernel_size=7, padding=3,
                 initializer=init.xavier_normal_,
                 gain=1):
        super(hGRUKernel, self).__init__()
        g_size = (1, channel_size, 1, 1)
        c_size = channel_size
        self.initializer = initializer
        ## non_linear
        self.activ = non_linear
        ## 1x1 convs producing the two gate pre-activations
        self.u1_conv = nn.Conv2d(c_size, c_size, 1, stride=1, padding=0,
                                 bias=False)
        self.u2_conv = nn.Conv2d(c_size, c_size, 1, stride=1, padding=0,
                                 bias=False)
        ## the W (shared horizontal-connection kernel)
        self.w_conv = nn.Conv2d(c_size, c_size, kernel_size,
                                stride=1, padding=padding,
                                bias=False)
        ## alpha and miu
        ## FIX: forward the user-supplied `gain` (was hard-coded to 1,
        ## silently ignoring the constructor argument; default unchanged).
        self.alpha = torch.nn.Parameter(
            initializer(torch.empty(g_size), gain=gain)
        )
        ## FIX: miu was a Parameter over uninitialized torch.empty memory;
        ## start from zeros so the initial inhibition shift is defined.
        self.miu = torch.nn.Parameter(torch.zeros(g_size))
        ## one BatchNorm per normalized tensor per time step (4 per step)
        self.BN_list = nn.ModuleList(
            [nn.BatchNorm2d(c_size) for _ in range(4 * time_step)]
        )
        ## gains of the output candidate
        self.tau = torch.nn.Parameter(
            initializer(torch.empty(g_size), gain=gain)
        )
        self.beta = torch.nn.Parameter(
            initializer(torch.empty(g_size), gain=gain)
        )
        self.gamma = torch.nn.Parameter(
            initializer(torch.empty(g_size), gain=gain)
        )
        ## initialize the conv weights
        self._init_weights()

    def _init_weights(self):
        """Initialize all convolution weights with ``self.initializer``."""
        self.initializer(self.u1_conv.weight)
        self.initializer(self.u2_conv.weight)
        self.initializer(self.w_conv.weight)

    def forward(self, forward_feat: torch.Tensor,
                H_t_minus: torch.Tensor, t: int):
        """Run one hGRU step.

        Args:
            forward_feat: feed-forward drive, shape (N, C, H, W).
            H_t_minus: previous hidden state, same shape as forward_feat.
            t: current step index; selects this step's BatchNorm layers.

        Returns:
            The updated hidden state, same shape as the inputs.
        """
        # Stage 1: gated inhibition of the previous hidden state.
        # torch.sigmoid replaces the deprecated F.sigmoid.
        G_t_1 = torch.sigmoid(self.BN_list[t * 4](self.u1_conv(H_t_minus)))
        C_t_1 = self.BN_list[t * 4 + 1](self.w_conv(G_t_1 * H_t_minus))
        H_t_1 = self.activ(
            forward_feat - C_t_1 * (self.alpha * H_t_minus - self.miu)
        )
        # Stage 2: gated excitation producing the output candidate.
        G_t_2 = torch.sigmoid(
            self.BN_list[t * 4 + 2](self.u2_conv(H_t_1))
        )
        C_t_2 = self.BN_list[t * 4 + 3](self.w_conv(H_t_1))
        H_t_2_prime = self.activ(
            self.tau * H_t_1 + self.beta * C_t_2 +
            self.gamma * H_t_1 * C_t_2
        )
        # Mix previous state and candidate through the second gate.
        output_hidden = H_t_minus * (1 - G_t_2) + G_t_2 * H_t_2_prime
        return output_hidden
### Should the input and the hidden state be sampled within some range?
class hGRUModule(torch.jit.ScriptModule):
    """Unroll a shared hGRU kernel for a fixed number of time steps."""

    # _t_n is declared a TorchScript constant so the compiler can treat
    # the range() bound in forward as compile-time known.
    __constants__ = ['_t_n',]

    def __init__(self, hGRU_kernel, time_step):
        super(hGRUModule, self).__init__()
        # number of recurrent iterations (JIT constant, see __constants__)
        self._t_n = time_step
        # shared recurrent cell applied at every step; presumably an
        # hGRUKernel — NOTE(review): it must itself be scriptable for
        # the @script_method forward to compile; confirm at call site.
        self.hGRU_k = hGRU_kernel

    @torch.jit.script_method
    def forward(self, forward_feat, hidden):
        # Feed each step's output hidden state back in, along with the
        # step index t (the kernel uses t to pick per-step BatchNorms).
        for t in range(self._t_n):
            hidden = self.hGRU_k(forward_feat, hidden, t)
        return hidden
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment