Created
April 18, 2019 14:04
-
-
Save iamhatesz/3ef34254febe482aa48e3e489f89b07b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import dataclass | |
import numpy as np | |
import torch | |
from memory_profiler import profile | |
from tqdm import tqdm | |
@dataclass
class NumPyTransition:
    """One replay-buffer transition whose fields are NumPy ``float32`` arrays.

    Instances are produced in bulk by ``main`` to measure how much memory a
    replay buffer of NumPy-backed transitions occupies.
    """

    state: np.ndarray       # (state_dim,) float32 when built by randn
    action: np.ndarray      # (action_dim,) float32 when built by randn
    next_state: np.ndarray  # (state_dim,) float32 when built by randn
    reward: np.ndarray      # (1,) float32 when built by randn
    # (1,) float32 when built by randn; holds random filler, not a 0/1 flag.
    done: np.ndarray

    @classmethod
    def randn(cls, *, state_dim: int = 10, action_dim: int = 1) -> "NumPyTransition":
        """Return a transition filled with standard-normal ``float32`` samples.

        Args:
            state_dim: Length of ``state`` and ``next_state`` (default 10).
            action_dim: Length of ``action`` (default 1).

        Returns:
            A new instance of ``cls`` (a subclass calling this gets its own
            type back), with every field freshly sampled.
        """

        def sample(size: int) -> np.ndarray:
            # np.random.randn yields float64; cast so the stored arrays are float32.
            return np.random.randn(size).astype(np.float32)

        # Fields are sampled in the same order as the dataclass declares them.
        return cls(
            state=sample(state_dim),
            action=sample(action_dim),
            next_state=sample(state_dim),
            reward=sample(1),
            done=sample(1),
        )
@dataclass
class PyTorchTransition:
    """One replay-buffer transition whose fields are PyTorch ``float32`` tensors.

    Instances are produced in bulk by ``main`` to compare the memory footprint
    of tensor-backed transitions against the NumPy-backed ones.
    """

    state: torch.Tensor       # (state_dim,) float32 when built by randn
    action: torch.Tensor      # (action_dim,) float32 when built by randn
    next_state: torch.Tensor  # (state_dim,) float32 when built by randn
    reward: torch.Tensor      # (1,) float32 when built by randn
    # (1,) float32 when built by randn; holds random filler, not a 0/1 flag.
    done: torch.Tensor

    @classmethod
    def randn(cls, *, state_dim: int = 10, action_dim: int = 1) -> "PyTorchTransition":
        """Return a transition filled with standard-normal ``float32`` samples.

        Args:
            state_dim: Length of ``state`` and ``next_state`` (default 10).
            action_dim: Length of ``action`` (default 1).

        Returns:
            A new instance of ``cls`` (a subclass calling this gets its own
            type back), with every field freshly sampled.
        """

        def sample(size: int) -> torch.Tensor:
            # torch.float32 is the same dtype as torch.float; spelled out to
            # mirror the NumPy transition's np.float32.
            return torch.randn(size, dtype=torch.float32)

        # Fields are sampled in the same order as the dataclass declares them.
        return cls(
            state=sample(state_dim),
            action=sample(action_dim),
            next_state=sample(state_dim),
            reward=sample(1),
            done=sample(1),
        )
@profile
def main(num_samples: int = 1_000_000):
    """Build and discard two replay buffers of random transitions under ``@profile``.

    First a list of ``num_samples`` ``NumPyTransition`` objects is built and
    dropped, then a list of ``num_samples`` ``PyTorchTransition`` objects.
    Each buffer is built on its own line so memory_profiler's report shows
    the cost of each representation separately.

    Args:
        num_samples: Number of transitions in each buffer. Defaults to
            1,000,000, the same value as the original ``int(1e6)``.
    """
    replay = [NumPyTransition.randn() for _ in tqdm(range(num_samples))]
    # Drop the only reference to the NumPy buffer before building the next one,
    # so its objects can be reclaimed and do not inflate the second measurement.
    del replay
    replay = [PyTorchTransition.randn() for _ in tqdm(range(num_samples))]
    del replay
if __name__ == '__main__':
    # Run the benchmark only when executed as a script, never on import.
    main()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.