Skip to content

Instantly share code, notes, and snippets.

View RPegoud's full-sized avatar

Ryan Pégoud RPegoud

View GitHub Profile
@RPegoud
RPegoud / add_kernel.py
Created September 25, 2025 10:34
PyTorch wrapper around Triton vector addition kernel
def add(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """PyTorch wrapper for `add_kernel`.

    Allocates an output tensor shaped like `X`, launches `add_kernel` on a
    1D grid large enough to cover every element, and returns that output.

    Args:
        X: First operand.
        Y: Second operand. Assumed to hold the same number of elements as
            `X`; this is not checked here.

    Returns:
        The tensor passed to `add_kernel` as its output buffer.
    """
    # The parameter is `X` (upper case); the lower-case `x` used previously
    # was undefined and raised NameError on every call.
    output = torch.zeros_like(X)  # allocate memory for the output
    n_elements = output.numel()  # dimension of X and Y

    # cdiv = ceil div, computes the number of blocks to use
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    # calling the kernel will automatically store `BLOCK_SIZE` in `meta`
    # and update `output`
    add_kernel[grid](X, Y, output, n_elements, BLOCK_SIZE=1024)

    # The signature promises a tensor, so hand the filled buffer back.
    return output
@RPegoud
RPegoud / add_kernel.py
Created September 25, 2025 10:32
Triton Vector Addition Kernel
import triton
import triton.language as tl
@triton.jit
def add_kernel(
x_ptr, # pointer to the first memory entry of x
y_ptr, # pointer to the first memory entry of y
output_ptr, # pointer to the first memory entry of the output
n_elements, # dimension of x and y
BLOCK_SIZE: tl.constexpr, # size of a single block
@RPegoud
RPegoud / stoix_rainbow_head.py
Last active July 11, 2024 06:58
Stoix Noisy Distributional Dueling Q-Network
class NoisyDistributionalDuelingQNetwork(nn.Module):
num_atoms: int
vmax: float
vmin: float
action_dim: int
epsilon: float
layer_sizes: Sequence[int]
sigma_zero: float
activation: str = "relu"
use_layer_norm: bool = False
@RPegoud
RPegoud / stoix_mutli_step_C51.py
Last active July 11, 2024 06:58
Stoix Multi-step C51
# --- Collect the n-step data for the new loss function ---
# Draw a sample from the buffer; its `experience` field holds the sequence
# of transitions used below.
transition_sample = buffer_sample_fn(buffer_state, sample_key) # sample a trajectory with length N
transition_sequence: Transition = transition_sample.experience
# Extract the first and last observations.
# Every leaf of the sequence is indexed with [:, 0], then `.obs` is read.
# NOTE(review): assumes axis 0 is batch and axis 1 is the step axis - confirm.
step_0_obs = jax.tree_util.tree_map(lambda x: x[:, 0], transition_sequence).obs
step_0_actions = transition_sequence.action[:, 0]
# Same indexing with [:, -1]; `next_obs` of the last step is presumably the
# observation to bootstrap from - verify against the loss that consumes it.
step_n_obs = jax.tree_util.tree_map(lambda x: x[:, -1], transition_sequence).next_obs
# check if any of the transitions are done - this will be used to decide
# if bootstrapping is needed
# `jnp.any` reduces over the last axis of `done`, giving one flag per
# remaining index (presumably one per sampled sequence - TODO confirm).
n_step_done = jnp.any(transition_sequence.done, axis=-1)
@RPegoud
RPegoud / stoix_PER_snippets.py
Last active July 11, 2024 06:56
Stoix Prioritized Experience Replay Snippets
# --- Instantiate the replay buffer ---
buffer_fn = fbx.make_prioritised_trajectory_buffer(
max_size=config.system.buffer_size, # Number of experiences that the buffer can contain
min_length_time_axis=config.system.n_step, # Number of experiences required before we can sample
add_batch_size=config.arch.num_envs, # Number of experiences added to the replay buffer at once
sample_batch_size=config.system.batch_size, # Number of batches of experiences returned when sampling
sample_sequence_length=config.system.n_step, # Sequence length of a single batch of experiences
priority_exponent=config.system.priority_exponent, # Alpha parameter in Prioritized Experience Replay
period=1,
device="tpu",
@RPegoud
RPegoud / stoix_c51_network_and_loss.py
Last active July 10, 2024 07:26
Stoix Distributional DQN and categorical loss
# --- Network head ---
class DistributionalDiscreteQNetwork(nn.Module):
action_dim: int
epsilon: float
num_atoms: int
vmin: float
vmax: float
kernel_init: Initializer = lecun_normal()
@nn.compact
@RPegoud
RPegoud / stoix_dueling_q_network.py
Last active July 10, 2024 07:25
Stoix dueling Q-network
class DuelingQNetwork(nn.Module):
action_dim: int
epsilon: float
layer_sizes: Sequence[int]
activation: str = "relu"
use_layer_norm: bool = False
kernel_init: Initializer = orthogonal(np.sqrt(2.0))
@nn.compact
@RPegoud
RPegoud / stoix_noisy_layer.py
Last active July 11, 2024 15:06
Stoix noisy layer
class NoisyLinear(nn.Module):
"""
Noisy Linear Layer using independent Gaussian noise
as defined by Fortunato et al.
"""
features: int
use_bias: bool = True
dtype: Optional[Dtype] = None
param_dtype: Dtype = jnp.float32
@RPegoud
RPegoud / stoix_double_q_learning.py
Last active July 9, 2024 19:37
Stoix double Q learning
def double_q_learning(
q_tm1: chex.Array,
q_t_value: chex.Array,
a_tm1: chex.Array,
r_t: chex.Array,
d_t: chex.Array,
q_t_selector: chex.Array, # value prediction of the target network
huber_loss_parameter: chex.Array,
) -> jnp.ndarray:
"""Computes the double Q-learning loss. Each input is a batch."""
@RPegoud
RPegoud / stoix_q_learning.py
Last active July 9, 2024 19:37
Stoix Q learning
def q_learning(
q_tm1: chex.Array, # predicted value
a_tm1: chex.Array, # selected action
r_t: chex.Array, # reward
d_t: chex.Array, # discount
q_t: chex.Array, # predicted value
huber_loss_parameter: chex.Array,
) -> jnp.ndarray:
"""Computes the Q-learning loss. Each input is a batch."""
batch_indices = jnp.arange(a_tm1.shape[0])