Skip to content

Instantly share code, notes, and snippets.

@iamhatesz
Created April 18, 2019 14:04
Show Gist options
  • Save iamhatesz/3ef34254febe482aa48e3e489f89b07b to your computer and use it in GitHub Desktop.
Save iamhatesz/3ef34254febe482aa48e3e489f89b07b to your computer and use it in GitHub Desktop.
from dataclasses import dataclass
import numpy as np
import torch
from memory_profiler import profile
from tqdm import tqdm
@dataclass
class NumPyTransition:
    """A single transition record whose fields are NumPy arrays.

    Instances built by ``randn`` hold 1-D ``float32`` arrays filled with
    standard-normal noise; this file uses them to measure how much memory a
    large list of such records occupies.
    """

    state: np.ndarray
    action: np.ndarray
    next_state: np.ndarray
    reward: np.ndarray
    done: np.ndarray

    @classmethod
    def randn(cls, state_dim: int = 10, action_dim: int = 1) -> "NumPyTransition":
        """Build a transition whose fields are standard-normal ``float32`` noise.

        Args:
            state_dim: Length of ``state`` and ``next_state`` (default 10).
            action_dim: Length of ``action`` (default 1).

        Returns:
            A new instance of ``cls`` (so subclasses get their own type).
            ``reward`` and ``done`` always have length 1. Note that ``done``
            is random noise rather than a boolean flag; presumably only the
            array sizes matter for the memory benchmark.
        """
        def sample(size: int) -> np.ndarray:
            # randn yields float64; cast to float32 to match PyTorch's default.
            return np.random.randn(size).astype(np.float32)

        # Keyword arguments are evaluated left to right, so the RNG draw
        # order is state, action, next_state, reward, done.
        return cls(
            state=sample(state_dim),
            action=sample(action_dim),
            next_state=sample(state_dim),
            reward=sample(1),
            done=sample(1),
        )
@dataclass
class PyTorchTransition:
    """A single transition record whose fields are PyTorch tensors.

    Instances built by ``randn`` hold 1-D ``torch.float`` (float32) tensors
    filled with standard-normal noise; this file uses them to measure how
    much memory a large list of such records occupies.
    """

    state: torch.Tensor
    action: torch.Tensor
    next_state: torch.Tensor
    reward: torch.Tensor
    done: torch.Tensor

    @classmethod
    def randn(cls, state_dim: int = 10, action_dim: int = 1) -> "PyTorchTransition":
        """Build a transition whose fields are standard-normal float32 noise.

        Args:
            state_dim: Length of ``state`` and ``next_state`` (default 10).
            action_dim: Length of ``action`` (default 1).

        Returns:
            A new instance of ``cls`` (so subclasses get their own type).
            ``reward`` and ``done`` always have length 1. Note that ``done``
            is random noise rather than a boolean flag; presumably only the
            tensor sizes matter for the memory benchmark.
        """
        def sample(size: int) -> torch.Tensor:
            return torch.randn(size, dtype=torch.float)

        # Keyword arguments are evaluated left to right, so the RNG draw
        # order is state, action, next_state, reward, done.
        return cls(
            state=sample(state_dim),
            action=sample(action_dim),
            next_state=sample(state_dim),
            reward=sample(1),
            done=sample(1),
        )
@profile
def main(num_samples: int = 1_000_000):
    """Compare the memory cost of NumPy- vs PyTorch-backed transition lists.

    Builds a list of ``num_samples`` random ``NumPyTransition`` objects and
    frees it, then does the same with ``PyTorchTransition``. The function is
    wrapped by memory_profiler's ``@profile`` so the two allocations can be
    compared.

    Args:
        num_samples: Number of transitions in each list (default 1,000,000,
            the value previously hard-coded as ``int(1e6)``).
    """
    replay = [NumPyTransition.randn() for _ in tqdm(range(num_samples))]
    # Release the NumPy list before building the PyTorch one so the two
    # measurements do not overlap in memory.
    del replay
    replay = [PyTorchTransition.randn() for _ in tqdm(range(num_samples))]
    del replay
if __name__ == '__main__':
    # Run the memory benchmark only when this file is executed as a script,
    # not when it is imported as a module.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment