This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import timeit | |
# Create a simple module structure for demonstration purposes | |
class Module: | |
class Package: | |
class Subpack: | |
def func(self): | |
pass | |
module = Module() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from tensordict import TensorDict | |
from tensordict.nn import TensorDictModule as Mod | |
from torchrl.modules import MaskedCategorical, ProbabilisticActor, MLP | |
from torchrl.envs import set_exploration_type | |
import torch | |
td = TensorDict( | |
observation=torch.randn(3, 4), | |
mask=torch.zeros(3, 10, dtype=torch.bool).bernoulli_(0.9), | |
batch_size=(3,) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from tensordict import TensorDict | |
def view(): | |
td = TensorDict(a=0) | |
return td.view(1) | |
assert view().shape == (1,) | |
# check that the weakref of the td points to an object that is out of scope | |
td = view() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import torch | |
DELAY = 100000000 | |
with torch.profiler.profile( | |
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], | |
record_shapes=True, | |
profile_memory=True, | |
#with_stack=True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import argparse | |
import torch | |
from tensordict import TensorDict | |
from torch.utils.benchmark import Timer |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq | |
import pydot | |
seq = Seq( | |
Mod(lambda x: x + 1, in_keys=["input"], out_keys=["intermediate"]), | |
Mod(lambda x, y: (x * y).sqrt(), in_keys=["input", "intermediate"], out_keys=["out_0"]), | |
Mod(lambda z, x: z - z, in_keys=["out_0", "intermediate"], out_keys=["out_1"]), | |
) | |
def edges(seq): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
PPO performance with 100K steps |