Skip to content

Instantly share code, notes, and snippets.

View vmoens's full-sized avatar

Vincent Moens vmoens

View GitHub Profile
@vmoens
vmoens / import_speed.py
Created February 28, 2025 15:34
import_speed
import timeit
# Create a simple module structure for demonstration purposes
class Module:
class Package:
class Subpack:
def func(self):
pass
module = Module()
@vmoens
vmoens / prob_actor_masked.py
Created February 26, 2025 16:38
prob_actor_masked.py
from tensordict import TensorDict
from tensordict.nn import TensorDictModule as Mod
from torchrl.modules import MaskedCategorical, ProbabilisticActor, MLP
from torchrl.envs import set_exploration_type
import torch
td = TensorDict(
observation=torch.randn(3, 4),
mask=torch.zeros(3, 10, dtype=torch.bool).bernoulli_(0.9),
batch_size=(3,)
@vmoens
vmoens / wr.py
Created February 5, 2025 16:42
weakref graph break
import torch
from tensordict import TensorDict
def view():
td = TensorDict(a=0)
return td.view(1)
assert view().shape == (1,)
# check that the weakref of the td points to an object that is out of scope
td = view()
import time
import torch
DELAY = 100000000
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
record_shapes=True,
profile_memory=True,
#with_stack=True
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import torch
from tensordict import TensorDict
from torch.utils.benchmark import Timer
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
import pydot
seq = Seq(
Mod(lambda x: x + 1, in_keys=["input"], out_keys=["intermediate"]),
Mod(lambda x, y: (x * y).sqrt(), in_keys=["input", "intermediate"], out_keys=["out_0"]),
Mod(lambda z, x: z - z, in_keys=["out_0", "intermediate"], out_keys=["out_1"]),
)
def edges(seq):
@vmoens
vmoens / ppo.png
Last active January 11, 2024 11:14
ppo_perf
PPO performance with 100K steps