# Self-contained check that a PackedSequence-based batched RNN forward over
# variable-length episode segments matches a straightforward per-timestep loop.
from typing import Tuple
import torch
import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence, invert_permutation
def _build_pack_info_from_dones(
dones, T: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Create the indexing info needed to make the PackedSequence based on the dones.
PackedSequences are PyTorch's way of supporting a single RNN forward
call where each input in the batch can have an arbitrary sequence length
They work as follows: Given the sequences [c], [x, y, z], [a, b],
we generate data [x, a, c, y, b, z] and batch_sizes [3, 2, 1]. The
data is a flattened out version of the input sequences (the ordering in
data is determined by sequence length). batch_sizes tells you that
for each index, how many sequences have a length of (index + 1) or greater.
This method will generate the new index ordering such that you can
construct the data for a PackedSequence from a (N*T, ...) tensor
via x.index_select(0, select_inds)
"""
num_samples = len(dones)
rollout_boundaries = dones.clone().detach()
rollout_boundaries[T - 1 :: T] = 1 # end of each rollout is the boundary
rollout_boundaries = rollout_boundaries.nonzero().squeeze() + 1
rollout_lengths = rollout_boundaries[1:] - rollout_boundaries[:-1]
first_len = rollout_boundaries[0]
rollout_lengths = torch.cat([first_len.unsqueeze(0), rollout_lengths])
rollout_starts_orig = rollout_boundaries - rollout_lengths
# done=True for the last step in the episode, so done flags rolled 1 step to the right will indicate
# first frames in the episodes
is_new_episode = dones.clone().detach().view((-1, T))
is_new_episode = is_new_episode.roll(1, 1)
# roll() is cyclical, so done=True in the last position in the rollout will roll to 0th position
# we want to avoid it here. (note to self: is there a function that does two of these things at once?)
is_new_episode[:, 0] = 0
is_new_episode = is_new_episode.view((-1, ))
lengths, sorted_indices = torch.sort(rollout_lengths, descending=True)
# We will want these on the CPU for torch.unique_consecutive,
# so move now.
cpu_lengths = lengths.to(device="cpu", non_blocking=True)
# We need to keep the original unpermuted rollout_starts, because the permutation is later applied
# internally in the RNN implementation.
# From modules/rnn.py:
# Each batch of the hidden state should match the input sequence that
# the user believes he/she is passing in.
# hx = self.permute_hidden(hx, sorted_indices)
rollout_starts_sorted = rollout_starts_orig.index_select(0, sorted_indices)
select_inds = torch.empty(num_samples, device=dones.device, dtype=torch.int64)
max_length = int(cpu_lengths[0].item())
# batch_sizes is *always* on the CPU
batch_sizes = torch.empty((max_length,), device="cpu", dtype=torch.int64)
offset = 0
prev_len = 0
num_valid_for_length = lengths.size(0)
unique_lengths = torch.unique_consecutive(cpu_lengths)
# Iterate over all unique lengths in reverse as they sorted
# in decreasing order
for i in range(len(unique_lengths) - 1, -1, -1):
valids = lengths[0:num_valid_for_length] > prev_len
num_valid_for_length = int(valids.float().sum().item())
next_len = int(unique_lengths[i])
batch_sizes[prev_len:next_len] = num_valid_for_length
new_inds = (
rollout_starts_sorted[0:num_valid_for_length].view(1, num_valid_for_length)
+ torch.arange(prev_len, next_len, device=rollout_starts_sorted.device).view(next_len - prev_len, 1)
).view(-1)
# for a set of sequences [1, 2, 3], [4, 5], [6, 7], [8]
# these indices will be 1,4,6,8,2,5,7,3
# (all first steps in all trajectories, then all second steps, etc.)
select_inds[offset : offset + new_inds.numel()] = new_inds
offset += new_inds.numel()
prev_len = next_len
# Make sure we have an index for all elements
assert offset == num_samples
assert is_new_episode.shape[0] == num_samples
return rollout_starts_orig, is_new_episode, select_inds, batch_sizes, sorted_indices
def build_rnn_inputs(x, dones, dones_cpu, rnn_states, T: int):
    r"""Create a PackedSequence input for an RNN such that each set of steps
    belonging to the same episode forms one sequence in the PackedSequence.

    Use the returned inverted_select_inds with build_core_out_from_seq to
    restore the original ordering of the RNN output.

    :param x: A (N*T, -1) tensor of the data to build the PackedSequence out of
    :param dones: A (N*T) tensor where dones[i] == 1.0 indicates an episode is done
    :param dones_cpu: same as dones but a CPU-bound tensor
    :param rnn_states: A (N*T, -1) tensor of the rnn_hidden_states
    :param T: The length of the rollout
    :return: tuple (x_seq, rnn_states, inverted_select_inds) where x_seq is the
        PackedSequence version of x to pass to the RNN, rnn_states are the
        matching initial states, and inverted_select_inds can be passed to
        build_core_out_from_seq to recover the RNN output ordering
    """
    pack_info = _build_pack_info_from_dones(dones_cpu, T)
    rollout_starts, is_new_episode, select_inds, batch_sizes, sorted_indices = pack_info

    inverted_select_inds = invert_permutation(select_inds)

    # Move all index tensors onto the same device as the data.
    device = x.device
    select_inds = select_inds.to(device=device)
    inverted_select_inds = inverted_select_inds.to(device=device)
    sorted_indices = sorted_indices.to(device=device)
    rollout_starts = rollout_starts.to(device=device)

    x_seq = PackedSequence(x.index_select(0, select_inds), batch_sizes, sorted_indices)

    # rollout_starts indexes the first step of every sequence (a sequence
    # begins at an episode boundary or at a rollout boundary).  The mask is 0
    # exactly where a sequence start is also the start of a new episode, so
    # multiplying by it zeroes those hidden states and prevents information
    # from leaking across episode boundaries.
    keep_state_mask = (1 - is_new_episode.view(-1, 1)).index_select(0, rollout_starts)
    rnn_states = rnn_states.index_select(0, rollout_starts) * keep_state_mask

    return x_seq, rnn_states, inverted_select_inds
def build_core_out_from_seq(x_seq: PackedSequence, inverted_select_inds):
    """Undo the PackedSequence reordering, returning a flat (N*T, ...) tensor.

    :param x_seq: PackedSequence output of the RNN built by build_rnn_inputs
    :param inverted_select_inds: inverse permutation returned by build_rnn_inputs
    :return: x_seq.data rows rearranged back into the original flat order
    """
    packed_data = x_seq.data
    return packed_data.index_select(0, inverted_select_inds)
# --- Self-check: the packed batched forward must match a per-timestep loop ---
T = 97
N = 64
D = 128

rnn = nn.GRU(D, D, 1)
total_frames = 0

for _ in range(100):
    initial_states = torch.rand(T * N, D)

    # Mark every 7th step as an episode boundary.
    dones = torch.zeros((N * T,))
    for step in range(1, N * T, 7):
        dones[step] = 1.0

    x = torch.randn(T * N, D)

    # Batched path: a single PackedSequence forward over all rollouts.
    hidden = initial_states.clone().detach()
    x_seq, seq_states, inverted_select_inds = build_rnn_inputs(
        x, dones, dones, hidden, T
    )
    new_out, _ = rnn(x_seq, seq_states.unsqueeze(0))
    new_out = build_core_out_from_seq(new_out, inverted_select_inds)

    # Reference path: T sequential single-step forwards with manual resets of
    # the hidden state at episode boundaries.
    hidden = initial_states.clone().detach()[::T].unsqueeze(0)
    step_outputs = []
    for t in range(T):
        step_out, hidden = rnn(x[t::T].view(1, N, -1), hidden)
        step_outputs.append(step_out.view(N, -1))
        hidden = hidden * (1 - dones[t::T].view(1, N, 1))
    old_outputs = torch.stack(step_outputs, dim=1).view(N * T, -1)

    # Should be ~0 up to floating-point noise.
    print(torch.norm(new_out - old_outputs))