Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save grey-area/17c0d03c3990515a40d169af8f16b7cb to your computer and use it in GitHub Desktop.
import torch
from torch import nn
class MyModule(nn.Module):
    """GRU followed by a per-timestep linear projection.

    The module offers two ways of running a sequence:

    * ``forward`` consumes a whole sequence at once, starting from a zero
      hidden state (as during training).
    * ``predict_frame`` consumes a chunk (typically a single frame) together
      with the hidden state returned by the previous call, so a host
      application can step the recurrence itself.

    Args:
        num_input_features: Size of the last dimension of the input.
        num_hidden: GRU hidden size. Defaults to 5.
        num_layers: Number of stacked GRU layers. Defaults to 2.
        num_output_features: Size of the last dimension of the output.
            Defaults to 2.
    """

    def __init__(
        self,
        num_input_features,
        num_hidden=5,
        num_layers=2,
        num_output_features=2,
    ):
        super().__init__()
        # batch_first is left at its default (False), so inputs are laid out
        # as (sequence_length, batch_size, num_input_features).
        self.gru = nn.GRU(
            input_size=num_input_features,
            hidden_size=num_hidden,
            num_layers=num_layers,
        )
        self.linear = nn.Linear(num_hidden, num_output_features)

    def forward(self, x):
        """Run the whole sequence ``x`` from a zero initial hidden state.

        Args:
            x: Input of shape (sequence_length, batch_size, num_input_features).

        Returns:
            Tensor of shape (sequence_length, batch_size, num_output_features).
        """
        # Delegate so the whole-sequence and frame-by-frame paths share one
        # implementation; gru_hidden=None is the GRU's zero initial state.
        linear_out, _ = self.predict_frame(x)
        return linear_out

    def predict_frame(self, x, gru_hidden=None):
        """Run ``x`` starting from ``gru_hidden`` and return the new state.

        Args:
            x: Input of shape (frames, batch_size, num_input_features);
                typically ``frames == 1``.
            gru_hidden: Hidden state returned by a previous call, of shape
                (num_layers, batch_size, num_hidden), or ``None`` to start
                from zeros.

        Returns:
            Tuple ``(output, gru_hidden)``. ``output`` has shape
            (frames, batch_size, num_output_features), and ``gru_hidden`` is
            the state to pass to the next call.
        """
        gru_out, gru_hidden = self.gru(x, gru_hidden)
        return self.linear(gru_out), gru_hidden
if __name__ == "__main__":
    # Check that running the sequence in one call and running it frame by
    # frame (feeding back the hidden state) give the same outputs.
    num_input_features = 3
    batch_size = 1
    sequence_length = 100

    model = MyModule(num_input_features)
    # Conventional inference mode; this GRU is built without dropout, so it
    # currently has no numerical effect, but keeps the check honest if that
    # changes.
    model.eval()
    model_input = torch.randn(size=(sequence_length, batch_size, num_input_features))

    # Pure inference: skip recording an autograd graph for every call below.
    with torch.no_grad():
        # This is the output of the model when passing in the whole sequence
        # without any hidden state, as during training.
        output_whole_sequence = model(model_input)

        # This is the output of the model when passing in the sequence one
        # frame at a time, plus the previous hidden state, to a method that
        # knows how to handle the hidden state and returns the new one.
        # In practice, this loop would be done in the application, i.e.,
        # handled in Unity/C#.
        output_frames = []
        gru_hidden = None
        # split(1, dim=0) yields (1, batch_size, num_input_features) slices,
        # keeping the sequence dimension the GRU expects.
        for model_input_frame in model_input.split(1, dim=0):
            output_frame, gru_hidden = model.predict_frame(model_input_frame, gru_hidden)
            output_frames.append(output_frame)
        output_frame_by_frame = torch.cat(output_frames, dim=0)

    # Expected to be at floating-point noise level if both paths agree.
    max_absolute_output_difference = (output_whole_sequence - output_frame_by_frame).abs().max()
    print(max_absolute_output_difference.item())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment