Skip to content

Instantly share code, notes, and snippets.

@XinyueZ
Created November 7, 2022 20:48
Show Gist options
  • Save XinyueZ/00b88c66f6570919b2e7d00e176260df to your computer and use it in GitHub Desktop.
PixelRNNs Many-To-One
class GenModel(nn.Module):
    """Many-to-one sequence model: Conv1d feature extractor + LSTM + sigmoid head.

    Input is expected as ``(batch, seq_len, input_size)`` (batch-first); it is
    permuted to channel-first for the conv stage and back for the RNN stage.

    Args:
        input_size: number of input features per timestep (conv in-channels
            and size of the final output).
        hidden_size: conv out-channels and LSTM hidden size.
        num_layers: number of stacked LSTM layers.
        bidirectional: if True, the LSTM output width doubles and the
            flatten/logits stages are sized accordingly.

    ``forward`` returns ``(probs, logits, (h, c))``.
    """

    def _init_weights(self, module):
        """Initialize a submodule's weights (used via ``Module.apply``).

        BatchNorm layers get zeroed biases; any other module exposing a
        ``weight`` entry gets Kaiming-uniform weights. Modules without a
        ``weight`` key (ReLU, AvgPool1d — and LSTM, whose parameters are
        named ``weight_ih_l*``/``weight_hh_l*``) are skipped.
        """
        # BUGFIX: the original tested `state_dict().get('weight') != None`,
        # comparing a Tensor to None with `!=` — that only works through
        # Python's NotImplemented fallback. Use an explicit `is not None`.
        if module.state_dict().get('weight') is not None:
            # `isinstance` instead of `type(...) ==` — also covers subclasses.
            if isinstance(module, nn.BatchNorm1d):
                nn.init.zeros_(module.bias)
            else:
                nn.init.kaiming_uniform_(module.weight)

    def __init__(self, input_size, hidden_size, num_layers, bidirectional):
        super().__init__()
        # Conv stage: valid (no-padding) Conv1d, kernel 4, stride 1, then
        # BatchNorm + ReLU and a width-4 average pool that shortens the
        # sequence by a factor of ~4 before it reaches the LSTM.
        self.convPool = nn.Sequential(
            nn.Conv1d(kernel_size=4,
                      in_channels=input_size,
                      out_channels=hidden_size,
                      stride=1),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),
            nn.AvgPool1d(kernel_size=4),
        )
        self.convPool.apply(self._init_weights)
        self.rnn = nn.LSTM(input_size=hidden_size,
                           hidden_size=hidden_size,
                           num_layers=num_layers,
                           batch_first=True,
                           bidirectional=bidirectional)
        self.rnn.apply(self._init_weights)
        # Output of RNN is duplicated automatically if bidirectional is true.
        rnn_units = hidden_size * 2 if bidirectional else hidden_size
        self.flatten = nn.Sequential(
            nn.Flatten(),
            nn.BatchNorm1d(rnn_units),
        )
        self.logits = nn.Sequential(
            nn.Linear(rnn_units, input_size),
            nn.BatchNorm1d(input_size),
        )
        self.logits.apply(self._init_weights)
        self.predictor = nn.Sigmoid()

    def forward(self, *X):
        """Run the model.

        Args:
            *X: ``(x,)`` or ``(x, h0, c0)`` where ``x`` is
                ``(batch, seq_len, input_size)`` and ``h0``/``c0`` are an
                optional initial LSTM state.

        Returns:
            ``(probs, logits, (h, c))`` — sigmoid output and raw logits of
            shape ``(batch, input_size)``, plus the final LSTM state.
        """
        x = X[0]
        # (batch, seq, feat) -> (batch, feat, seq) for Conv1d.
        x = torch.permute(x, (0, 2, 1))
        y = self.convPool(x)
        # Back to (batch, seq, feat) for the batch-first LSTM.
        y = torch.permute(y, (0, 2, 1))
        y, (h, c) = self.rnn(y, (X[1], X[2])) if len(X) == 3 else self.rnn(y)
        # Many-to-one: keep only the last timestep's output.
        y = y[:, -1, :]
        y = self.flatten(y)
        y_logits = self.logits(y)
        y = self.predictor(y_logits)
        return y, y_logits, (h, c)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment