# Source: GitHub gist grey-area/98033b9708827f14a2a82d2022d1cbfa
# (last active April 28, 2022 14:29).
import torch.nn as nn
import torch
# NOTE: I've just put this here so that I don't have to import any other part of your code base
# to try out / run this model
# Labels for the control signals; not referenced in the visible part of this file.
control_signals_labels = ['rhand', 'lhand', 'head']
# Width of the Linear layers in Pose_Generator that are sized by this constant.
residual_block_linear = 1024
# NOTE: this is just because we're now doing 1d batch norm on a 3 dimensional tensor,
# and nn.BatchNorm1d assumes our feature dimension is the second dimension, not the third,
# so we have to shuffle and shuffle back
class BatchNorm1d(nn.Module):
    """Batch normalisation over the last dimension of a 3-D tensor.

    ``nn.BatchNorm1d`` treats dimension 1 as the feature dimension, while the
    tensors in this file keep their features in the last dimension. The input
    is therefore transposed before normalisation and transposed back after.

    Args:
        n_features: Size of the feature (last) dimension.
    """

    def __init__(self, n_features):
        super().__init__()
        self.bn = nn.BatchNorm1d(n_features)

    def forward(self, x):
        """Normalise ``x`` per feature and return a tensor of the same shape.

        Args:
            x: 3-D tensor with the feature dimension last.

        Returns:
            The normalised tensor, with the same dimension order as ``x``.
        """
        # Move features to dim 1 for nn.BatchNorm1d, then restore the layout.
        return self.bn(x.transpose(1, 2)).transpose(1, 2)
class ResidualBlock(nn.Module):
    """Two Linear -> BatchNorm -> LeakyReLU -> Dropout stages plus a skip connection.

    Input and output have the same feature size, so the input can be added
    directly to the transformed output.

    Args:
        n_features: Size of the feature (last) dimension of the input.
    """

    def __init__(self, n_features):
        # Must run before any submodule is assigned on an nn.Module.
        super().__init__()
        self.fc_layer1 = nn.Linear(n_features, n_features)
        self.bn1 = BatchNorm1d(n_features)
        self.relu = nn.LeakyReLU(inplace=True)
        self.dropout = nn.Dropout(0.5)
        self.fc_layer2 = nn.Linear(n_features, n_features)
        self.bn2 = BatchNorm1d(n_features)

    def forward(self, x):
        """Apply both stages to ``x`` and add ``x`` back onto the result.

        Args:
            x: Tensor whose last dimension has size ``n_features``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        residual = x

        out = self.fc_layer1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.dropout(out)

        out = self.fc_layer2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.dropout(out)

        # Out-of-place add: when dropout is inactive it can hand back the very
        # tensor the in-place LeakyReLU wrote, and an in-place ``+=`` would then
        # overwrite a value autograd may still need for the backward pass.
        return out + residual
class Pose_Generator(nn.Module):
    """Predict joint angles for every frame from control and angle sequences.

    Two GRUs (one per input stream) run over the sequences. Their per-frame
    outputs are concatenated on the feature dimension and passed through
    ``e2d`` and then ``generator``, both of which act on each frame
    independently.

    Args:
        n_control_features: Feature size of each control-signal frame.
        n_angle_signal_features: Feature size of each angle-signal frame.
        no_of_angles: Number of values predicted per frame.
    """

    # NOTE: model no longer needs to know about seqlen
    def __init__(self, n_control_features, n_angle_signal_features, no_of_angles):
        # Must run before any submodule is assigned on an nn.Module.
        super().__init__()

        hidden_channels = 512
        num_layers = 1
        dropout = 0.0
        bidirectional = False

        # NOTE: This wasn't a strictly necessary change, but I've removed the wrapper class around nn.GRU
        self.control_flow_layer = nn.GRU(
            n_control_features,
            hidden_channels,
            num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=bidirectional,
        )
        self.angle_signal_flow_layer = nn.GRU(
            n_angle_signal_features,
            hidden_channels,
            num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=bidirectional,
        )

        # NOTE: input is each frame of the output of the GRU, not the flattened whole output
        # NOTE(review): only the layers visible in the source are kept here.
        # ResidualBlock is defined in this file but never used, so further
        # e2d layers may have been lost -- confirm against the original model.
        self.e2d = nn.Sequential(
            nn.Linear(hidden_channels * 2, residual_block_linear),
        )
        self.generator = nn.Sequential(
            nn.Linear(residual_block_linear, no_of_angles),
        )

    # NOTE: control and angle hidden states are optionally passed in
    def encoder2d(self, control_signal, angle_signal, control_hidden=None, angle_hidden=None):
        """Encode both input streams into one per-frame feature tensor.

        Args:
            control_signal: Batch-first control sequence.
            angle_signal: Batch-first angle sequence.
            control_hidden: Optional initial hidden state for the control GRU.
            angle_hidden: Optional initial hidden state for the angle GRU.

        Returns:
            ``(enc, (control_hidden, angle_hidden))``: the output of ``e2d``
            and the final hidden states of the two GRUs.
        """
        cs, control_hidden = self.control_flow_layer(control_signal, control_hidden)
        av, angle_hidden = self.angle_signal_flow_layer(angle_signal, angle_hidden)

        # NOTE: no flattening here, the e2d and generator operate on every frame independently (after the GRU)
        out = torch.cat((cs, av), dim=-1)
        enc = self.e2d(out)
        return enc, (control_hidden, angle_hidden)

    # NOTE: control and angle hidden states are optionally passed in
    def forward(self, control_signal, angle_signal, control_hidden=None, angle_hidden=None):
        """Run the encoder and the generator.

        Args:
            control_signal: Batch-first control sequence.
            angle_signal: Batch-first angle sequence.
            control_hidden: Optional initial hidden state for the control GRU.
            angle_hidden: Optional initial hidden state for the angle GRU.

        Returns:
            ``(gen, hidden_state)``: the per-frame predictions and the
            ``(control_hidden, angle_hidden)`` pair, which can be fed back in
            on the next call.
        """
        enc, hidden_state = self.encoder2d(control_signal, angle_signal, control_hidden, angle_hidden)
        gen = self.generator(enc)
        return gen, hidden_state
if __name__ == "__main__":
    # NOTE: made up dimensionalities for playing with the model below
    n_control_features = 10
    n_angle_signal_features = 3
    no_of_angles = 5
    seqlen = 15
    batch_size = 7

    control_in = torch.zeros(batch_size, seqlen, n_control_features)
    angles_in = torch.zeros(batch_size, seqlen, n_angle_signal_features)

    model = Pose_Generator(n_control_features, n_angle_signal_features, no_of_angles)

    # NOTE: example of using the model during training
    # NOTE: we have a target value for each of the `seqlen` frames, not just the last,
    # so effectively `seqlen` times as much training data!
    # NOTE: during training, you never have to pass in the hidden states and you ignore the hidden states that come out
    target = torch.zeros(batch_size, seqlen, no_of_angles)
    out, _ = model(control_in, angles_in)

    # NOTE: Prediction (this would be how the model would be used on the Unity side, with the initially 'blank' hidden state being passed in and returned in a loop)
    # and inputs being passed in one frame at a time.
    # NOTE: This way you can let whoever is handling the Unity side worry about normalization.
    model = model.eval()

    # Size the blank hidden states from the GRUs themselves rather than
    # repeating the model's hidden width here: (num_layers, batch=1, hidden).
    control_gru = model.control_flow_layer
    angle_gru = model.angle_signal_flow_layer
    control_hidden = torch.zeros(control_gru.num_layers, 1, control_gru.hidden_size)
    angle_hidden = torch.zeros(angle_gru.num_layers, 1, angle_gru.hidden_size)

    # No gradients are needed for prediction; this also stops the carried
    # hidden states from accumulating an ever-growing autograd graph.
    with torch.no_grad():
        for _ in range(10):  # This would be an ongoing loop
            control = torch.zeros(1, 1, n_control_features)
            angle = torch.zeros(1, 1, n_angle_signal_features)
            frame_out, (control_hidden, angle_hidden) = model(control, angle, control_hidden, angle_hidden)
# End of gist (GitHub page footer removed).