conv_rnn_modules.py
# pylint: disable-all
import torch
import torch.nn as nn
from torch.nn import functional as F

def time_to_batch(x):
    """Fold the time axis of a (T, N, ...) tensor into the batch axis."""
    t, n = x.size()[:2]
    x = x.view(n * t, *x.size()[2:])
    return x, n

def batch_to_time(x, n=32):
    """Inverse of time_to_batch: unfold a (T*N, ...) tensor back to (T, N, ...)."""
    nt = x.size(0)
    time = nt // n
    x = x.view(time, n, *x.size()[1:])
    return x
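
# Illustrative round trip (not part of the module API): time_to_batch and
# batch_to_time are mutual inverses for matching n, since both are reshapes:
#   v = torch.rand(4, 2, 8, 16, 16)           # (T, N, C, H, W)
#   flat, n = time_to_batch(v)                # flat: (8, 8, 16, 16), n == 2
#   assert torch.equal(batch_to_time(flat, n), v)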

def hard_sigmoid(x, alpha=0.0):
    """Piecewise-linear sigmoid surrogate: clamps x + 0.5 to [-alpha, 1 + alpha]."""
    return torch.clamp(x + 0.5, 0 - alpha, 1 + alpha)

class GConv2d(nn.Module):
    r"""Generic convolution operator.

    Applied distributed in time when fed a 5D tensor, which is assumed to be
    in (T, N, C, H, W) format.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, dilation=1,
                 separable=True, norm="batch", bias=False, nonlinearity=F.relu):
        super(GConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                               stride=stride, dilation=dilation,
                               groups=in_channels if separable else 1,
                               # "same" padding for odd kernel sizes
                               padding=(kernel_size - 1) * dilation // 2,
                               bias=bias)
        # The normalization acts on the convolution output, hence out_channels.
        if norm == "batch":
            self.bn1 = nn.BatchNorm2d(out_channels, affine=True)
        elif norm == "instance":
            self.bn1 = nn.InstanceNorm2d(out_channels, affine=True)
        elif norm == "group":
            self.bn1 = nn.GroupNorm(num_groups=3, num_channels=out_channels, affine=True)
        elif norm == "weight":
            self.conv1 = nn.utils.weight_norm(self.conv1)
            self.bn1 = lambda x: x
        else:
            self.bn1 = lambda x: x
        self.act = nonlinearity

    def forward(self, x):
        is_volume = x.dim() == 5
        if is_volume:
            x, n = time_to_batch(x)
        h = self.conv1(x)
        h = self.bn1(h)
        h = self.act(h)
        if is_volume:
            h = batch_to_time(h, n)
        return h
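
# Usage sketch (illustrative values): GConv2d accepts either a 4D (N, C, H, W)
# tensor or a 5D (T, N, C, H, W) volume; in the 5D case the time axis is
# folded into the batch around the convolution, so for stride 1 the spatial
# shape is preserved:
#   conv = GConv2d(8, 16, kernel_size=3, stride=1, norm="batch")
#   y = conv(torch.rand(4, 2, 8, 32, 32))     # -> (4, 2, 16, 32, 32)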

def conv_dw(in_channels, out_channels, kernel_size=3, stride=1, norm='instance', dilation=1):
    """Depthwise-separable block: depthwise conv followed by a 1x1 pointwise conv."""
    return nn.Sequential(
        GConv2d(in_channels, in_channels, kernel_size, stride,
                dilation=dilation, separable=True, norm=norm),
        GConv2d(in_channels, out_channels, 1, 1,
                dilation=1, separable=False, norm=norm))

class BaseRNN(nn.Module):
    """Base class that unrolls the recurrence over time and keeps the hidden
    state as internal memory between calls."""

    def __init__(self, hard=False):
        super(BaseRNN, self).__init__()
        self.conv_x2h = lambda x: x
        self.sigmoid = hard_sigmoid if hard else torch.sigmoid
        self.tanh = F.hardtanh if hard else torch.tanh
        self.prev_hidden = None
        self.reset_hidden()

    def forward(self, x):
        xi = self.conv_x2h(x)
        inference = xi.dim() == 4
        xseq = [xi] if inference else xi.unbind(0)
        # Detach the carried state so gradients do not flow across calls
        # (truncated backpropagation through time).
        if isinstance(self.prev_hidden, list):
            for item in self.prev_hidden:
                if item is not None:
                    item.detach_()
        elif self.prev_hidden is not None:
            self.prev_hidden = self.prev_hidden.detach()
        result = []
        for xt in xseq:
            self.prev_hidden, hidden = self.update_hidden(xt)
            result.append(hidden)
        result = hidden if inference else torch.cat(result, dim=0)
        return result

    def get_hidden(self):
        return self.prev_hidden

    def set_hidden(self, hidden):
        self.prev_hidden = hidden

    def update_hidden(self, xt):
        raise NotImplementedError()

    def reset_hidden(self, mask=None):
        if mask is None or self.prev_hidden is None:
            self.prev_hidden = None
        else:
            # Zero out the state of finished sequences only.
            if isinstance(self.prev_hidden, list):
                for item in self.prev_hidden:
                    if item is not None:
                        item *= mask
            else:
                self.prev_hidden *= mask
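
# Usage note (illustrative): subclasses carry state across calls. A 5D
# (T, N, C, H, W) input is unrolled over its first axis; a 4D (N, C, H, W)
# input runs a single streaming step on the stored state. Call reset_hidden()
# between independent sequences, or pass a {0, 1} float mask broadcastable to
# the hidden state to clear only the finished batch entries, e.g. with a
# hypothetical per-sequence boolean `done` flag:
#   rnn.reset_hidden(mask=(~done).float().view(-1, 1, 1, 1))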

class ConvLSTM(BaseRNN):
    r"""ConvLSTM cell."""

    def __init__(self, in_channels, hidden_dim, conv_func, stride=1, hard=False, nonlinearity=F.relu):
        super(ConvLSTM, self).__init__(hard)
        self.hidden_dim = hidden_dim
        self.conv_x2h = conv_func(in_channels=in_channels, out_channels=4 * self.hidden_dim, stride=stride)
        self.conv_h2h = conv_func(in_channels=self.hidden_dim, out_channels=4 * self.hidden_dim, stride=1)
        self.act = nonlinearity

    def update_hidden(self, xt):
        prev_h, prev_c = (None, None) if self.prev_hidden is None else self.prev_hidden
        tmp = xt if prev_h is None else self.conv_h2h(prev_h) + xt
        cc_i, cc_f, cc_o, cc_g = torch.split(tmp, self.hidden_dim, dim=1)
        i = self.sigmoid(cc_i)
        f = self.sigmoid(cc_f)
        o = self.sigmoid(cc_o)
        g = self.act(cc_g)
        c = i * g if prev_c is None else f * prev_c + i * g
        h = o * self.act(c)
        return [h, c], h.unsqueeze(0)
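
# For reference, update_hidden above computes the standard LSTM recurrence
# with convolutional gates (conv_x2h is applied once to the whole sequence by
# BaseRNN.forward, conv_h2h per step):
#   [i, f, o, g] = split(x2h(x_t) + h2h(h_{t-1}))
#   c_t = sigmoid(f) * c_{t-1} + sigmoid(i) * act(g)
#   h_t = sigmoid(o) * act(c_t)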

class ConvGRU(BaseRNN):
    r"""ConvGRU cell."""

    def __init__(self, in_channels, hidden_dim, conv_func, stride=1, hard=False, nonlinearity=F.relu):
        super(ConvGRU, self).__init__(hard)
        self.hidden_dim = hidden_dim
        self.conv_x2h = conv_func(in_channels=in_channels, out_channels=3 * self.hidden_dim, stride=stride)
        self.conv_h2zr = conv_func(in_channels=self.hidden_dim, out_channels=2 * self.hidden_dim, stride=1)
        self.conv_h2h = conv_func(in_channels=self.hidden_dim, out_channels=self.hidden_dim, stride=1)
        self.act = nonlinearity

    def update_hidden(self, xt):
        x_zr, x_h = xt[:, :2 * self.hidden_dim], xt[:, 2 * self.hidden_dim:]
        tmp = x_zr if self.prev_hidden is None else self.conv_h2zr(self.prev_hidden) + x_zr
        cc_z, cc_r = torch.split(tmp, self.hidden_dim, dim=1)
        z = self.sigmoid(cc_z)
        r = self.sigmoid(cc_r)
        tmp = x_h if self.prev_hidden is None else self.conv_h2h(r * self.prev_hidden) + x_h
        tmp = self.act(tmp)
        hidden = z * tmp if self.prev_hidden is None else (1 - z) * self.prev_hidden + z * tmp
        return hidden, hidden.unsqueeze(0)
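
# Corresponding GRU recurrence, as computed above:
#   [z, r] = sigmoid(split(h2zr(h_{t-1}) + x_zr))
#   h~_t   = act(h2h(r * h_{t-1}) + x_h)
#   h_t    = (1 - z) * h_{t-1} + z * h~_t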

class ConvQRNN(BaseRNN):
    r"""ConvQRNN cell."""

    def __init__(self, in_channels, hidden_dim, conv_func, stride=1, hard=False, nonlinearity=F.relu):
        super(ConvQRNN, self).__init__(hard)
        self.hidden_dim = hidden_dim
        self.conv_x2h = conv_func(in_channels=in_channels, out_channels=2 * self.hidden_dim, stride=stride)
        self.conv_h2h = lambda x: x
        self.act = nonlinearity

    def update_hidden(self, xt):
        f, x = torch.split(xt, self.hidden_dim, dim=1)
        f = self.sigmoid(f)
        hidden = (1 - f) * x if self.prev_hidden is None else (1 - f) * x + f * self.prev_hidden
        hidden = self.act(hidden)
        return hidden, hidden.unsqueeze(0)
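
# Design note: with conv_h2h the identity, all convolutions live in conv_x2h
# and are applied to every timestep in one batched pass by BaseRNN.forward;
# only the elementwise blend above runs sequentially. This is the main speed
# advantage of QRNN-style cells (https://arxiv.org/abs/1611.01576).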

class ReciprocalConvLSTM(BaseRNN):
    r"""Reciprocal ConvLSTM cell, after https://arxiv.org/pdf/1807.00053.pdf"""

    def __init__(self, in_channels, hidden_dim, conv_func, stride=1, hard=False, nonlinearity=F.relu):
        super(ReciprocalConvLSTM, self).__init__(hard)
        self.hidden_dim = hidden_dim
        self.conv_x2h = conv_func(in_channels=in_channels, out_channels=2 * self.hidden_dim, stride=stride)
        self.conv_c2hc = conv_func(in_channels=self.hidden_dim, out_channels=2 * self.hidden_dim, stride=1)
        self.conv_h2hc = conv_func(in_channels=self.hidden_dim, out_channels=2 * self.hidden_dim, stride=1)
        self.act = nonlinearity

    def update_hidden(self, xt):
        if self.prev_hidden is None:
            h, c = torch.split(xt, self.hidden_dim, dim=1)
        else:
            prev_h, prev_c = self.prev_hidden
            xth, xtc = torch.split(xt, self.hidden_dim, dim=1)
            tmp_from_h = self.sigmoid(self.conv_h2hc(prev_h) + xth)
            tmp_from_c = self.sigmoid(self.conv_c2hc(prev_c) + xtc)
            hc, hh = torch.split(tmp_from_h, self.hidden_dim, dim=1)
            ch, cc = torch.split(tmp_from_c, self.hidden_dim, dim=1)
            h = (1 - ch) * xth + (1 - hh) * prev_h
            c = (1 - hc) * xtc + (1 - cc) * prev_c
        h, c = self.act(h), self.act(c)
        return [h, c], h.unsqueeze(0)

if __name__ == '__main__':
    from functools import partial

    t, n, c, h, w = 4, 4, 8, 64, 64
    x = torch.rand(t, n, c, h, w)
    # conv_func = partial(GConv2d, kernel_size=5, norm='weight')  # GConv2d pads internally
    conv_func = partial(conv_dw, kernel_size=3, norm='weight')
    rnn = ConvQRNN(c, 8, conv_func)
    for _ in range(10):
        rnn.reset_hidden()
        y = rnn(x)
    # conv = conv_func(in_channels=c, out_channels=8, stride=2)
    # print(conv(x).shape)
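    # A similar smoke test for ConvLSTM (illustrative; shapes assume the
    # settings above):
    #   lstm = ConvLSTM(c, 8, conv_func)
    #   lstm.reset_hidden()
    #   y = lstm(x)        # (t, n, 8, h, w)
    #   y1 = lstm(x[0])    # 4D input: one streaming step, (1, n, 8, h, w)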