Skip to content

Instantly share code, notes, and snippets.

@NathanGavenski
Created December 13, 2023 15:40
Show Gist options
  • Save NathanGavenski/ec904c7c3bf06b6361a0897b798206ac to your computer and use it in GitHub Desktop.
Save NathanGavenski/ec904c7c3bf06b6361a0897b798206ac to your computer and use it in GitHub Desktop.
Sequence BaselineDataset from IL-Dataset
from imitation_datasets.dataset import BaselineDataset
class SequenceDataset(BaselineDataset):
    """Sequence dataset for the BaselineDataset from IL-Dataset.

    Instead of serving single transitions, every item of this dataset is one
    whole episode. Episodes are cut out of the flat ``obs`` and ``actions``
    arrays using the ``episode_starts`` flags and then padded with
    ``pad_sequence`` so they can be stacked into a single tensor.

    Attributes:
        lenghts: Unpadded number of steps of each episode. The misspelt name
            is kept for backward compatibility; ``lengths`` is an alias.
        sequences: Padded observations, episode index first.
        sequences_actions: Padded actions, episode index first.
    """

    def __init__(
        self,
        path: str,
        source: str = "local",
        split: str = "train",
        n_episodes: int = None,
    ) -> None:
        """Load the data through the parent class and build padded episodes.

        Args:
            path: Forwarded unchanged to ``BaselineDataset``.
            source: Forwarded unchanged to ``BaselineDataset``.
            split: ``"train"`` keeps the first ``n_episodes`` episodes, any
                other value keeps the episodes from index ``n_episodes`` on.
                Only used when ``n_episodes`` is given.
            n_episodes: Episode index at which the data is split. ``None``
                keeps every episode.

        Raises:
            ValueError: If the selection contains no episode at all.
        """
        super().__init__(path, source, split, n_episodes)
        # NOTE(review): relies on the parent constructor having populated
        # ``self.data`` with "obs", "actions" and "episode_starts" entries.

        # Indices where an episode begins, plus one sentinel past the end so
        # that consecutive pairs delimit every episode, including the last.
        episode_starts = list(np.where(self.data["episode_starts"] == 1)[0])
        episode_starts.append(len(self.data["episode_starts"]))

        if n_episodes is not None:
            if split == "train":
                # n_episodes episodes need n_episodes + 1 boundaries.
                episode_starts = episode_starts[:n_episodes + 1]
            else:
                episode_starts = episode_starts[n_episodes:]

        boundaries = list(zip(episode_starts, episode_starts[1:]))
        if not boundaries:
            # pad_sequence cannot handle an empty list; fail with a message
            # that points at the actual cause instead.
            raise ValueError(
                "No episodes selected: check 'episode_starts', 'split' and "
                f"'n_episodes' (n_episodes={n_episodes!r}, split={split!r})."
            )

        # Look the arrays up once instead of once per episode.
        all_obs = self.data["obs"]
        all_actions = self.data["actions"]

        self.lenghts = []
        self.sequences = []
        self.sequences_actions = []
        for start, end in tqdm(boundaries, desc="Creating sequence"):
            episode = torch.from_numpy(all_obs[start:end])
            actions = all_actions[start:end]
            # Keep exactly one row per step. 1-D actions become a column
            # vector, multi-dimensional actions keep their feature axis
            # instead of being flattened into extra (bogus) steps.
            actions = torch.from_numpy(actions.reshape((actions.shape[0], -1)))

            self.lenghts.append(episode.shape[0])
            self.sequences.append(episode)
            self.sequences_actions.append(actions)

        # Stack the variable-length episodes; shorter ones are padded at the
        # end up to the longest episode.
        self.sequences = pad_sequence(self.sequences, batch_first=True)
        self.sequences_actions = pad_sequence(self.sequences_actions, batch_first=True)

    @property
    def lengths(self) -> list[int]:
        """Correctly spelt, read-only alias of ``lenghts``."""
        return self.lenghts

    def __len__(self) -> int:
        """Return the number of episodes in the dataset."""
        return self.sequences.shape[0]

    def __getitem__(self, index: int) -> tuple[torch.Tensor, int, torch.Tensor]:
        """Return one episode.

        Args:
            index: Position of the episode.

        Returns:
            The padded observations, the unpadded episode length and the
            padded actions, in that order.
        """
        return self.sequences[index], self.lenghts[index], self.sequences_actions[index]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment