This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def add(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """PyTorch wrapper for `add_kernel`: return the element-wise sum of X and Y.

    Args:
        X: First input tensor. The output is allocated with `zeros_like(X)`,
            so it takes X's shape, dtype and device.
        Y: Second input tensor; must contain the same number of elements as X.

    Returns:
        The freshly allocated `output` tensor that `add_kernel` writes into.

    Raises:
        ValueError: If X and Y hold a different number of elements.
    """
    # Only one element count (taken from X) is handed to the kernel for both
    # inputs, so a smaller Y would presumably be read out of bounds.
    if X.numel() != Y.numel():
        raise ValueError(
            f"add: X and Y must have the same number of elements, "
            f"got {X.numel()} and {Y.numel()}"
        )
    # NOTE(review): the kernel receives raw pointers plus a flat element count;
    # non-contiguous X/Y may need `.contiguous()` first - confirm against add_kernel.
    output = torch.zeros_like(X)  # allocate memory for the output
    n_elements = output.numel()  # total number of elements in X (and Y)
    # cdiv = ceil div: launch enough BLOCK_SIZE-wide programs to cover every element
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    # calling the kernel will automatically store `BLOCK_SIZE` in `meta`
    # and write the result into `output`
    add_kernel[grid](X, Y, output, n_elements, BLOCK_SIZE=1024)
    return output
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import triton | |
import triton.language as tl | |
@triton.jit | |
def add_kernel( | |
x_ptr, # pointer to the first memory entry of x | |
y_ptr, # pointer to the first memory entry of y | |
output_ptr, # pointer to the first memory entry of the output | |
n_elements, # dimension of x and y | |
BLOCK_SIZE: tl.constexpr, # size of a single block |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class NoisyDistributionalDuelingQNetwork(nn.Module): | |
num_atoms: int | |
vmax: float | |
vmin: float | |
action_dim: int | |
epsilon: float | |
layer_sizes: Sequence[int] | |
sigma_zero: float | |
activation: str = "relu" | |
use_layer_norm: bool = False |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# --- Collect the n-step data for the new loss function ---
# Draw a batch of length-N trajectory windows from the replay buffer.
transition_sample = buffer_sample_fn(buffer_state, sample_key)
transition_sequence: Transition = transition_sample.experience

# Slice only the fields that are needed: the observation and action at the
# first step of each window, and the next-observation at the last step.
# Indexing assumes axis 1 is the step/time axis (batch, N, ...) - confirm.
step_0_obs = jax.tree_util.tree_map(
    lambda leaf: leaf[:, 0], transition_sequence.obs
)
step_0_actions = transition_sequence.action[:, 0]
step_n_obs = jax.tree_util.tree_map(
    lambda leaf: leaf[:, -1], transition_sequence.next_obs
)

# Flag every window that contains at least one terminal step; used downstream
# to decide whether bootstrapping from step_n_obs is needed.
n_step_done = jnp.any(transition_sequence.done, axis=-1)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# --- Instantiate the replay buffer --- | |
buffer_fn = fbx.make_prioritised_trajectory_buffer( | |
max_size=config.system.buffer_size, # Number of experiences that the buffer can contain | |
min_length_time_axis=config.system.n_step, # Number of experiences required before we can sample | |
add_batch_size=config.arch.num_envs, # Number of experiences added to the replay buffer at once | |
sample_batch_size=config.system.batch_size, # Number of batches of experiences returned when sampling | |
sample_sequence_length=config.system.n_step, # Sequence length of a single batch of experiences | |
priority_exponent=config.system.priority_exponent, # Alpha parameter in Prioritized Experience Replay | |
period=1, | |
device="tpu", |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# --- Network head --- | |
class DistributionalDiscreteQNetwork(nn.Module): | |
action_dim: int | |
epsilon: float | |
num_atoms: int | |
vmin: float | |
vmax: float | |
kernel_init: Initializer = lecun_normal() | |
@nn.compact |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class DuelingQNetwork(nn.Module): | |
action_dim: int | |
epsilon: float | |
layer_sizes: Sequence[int] | |
activation: str = "relu" | |
use_layer_norm: bool = False | |
kernel_init: Initializer = orthogonal(np.sqrt(2.0)) | |
@nn.compact |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class NoisyLinear(nn.Module): | |
""" | |
Noisy Linear Layer using independent Gaussian noise | |
as defined by Fortunato et al. | |
""" | |
features: int | |
use_bias: bool = True | |
dtype: Optional[Dtype] = None | |
param_dtype: Dtype = jnp.float32 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def double_q_learning( | |
q_tm1: chex.Array, | |
q_t_value: chex.Array, | |
a_tm1: chex.Array, | |
r_t: chex.Array, | |
d_t: chex.Array, | |
q_t_selector: chex.Array, # value prediction of the target network | |
huber_loss_parameter: chex.Array, | |
) -> jnp.ndarray: | |
"""Computes the double Q-learning loss. Each input is a batch.""" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def q_learning( | |
q_tm1: chex.Array, # predicted value | |
a_tm1: chex.Array, # selected action | |
r_t: chex.Array, # reward | |
d_t: chex.Array, # discount | |
q_t: chex.Array, # predicted value | |
huber_loss_parameter: chex.Array, | |
) -> jnp.ndarray: | |
"""Computes the Q-learning loss. Each input is a batch.""" | |
batch_indices = jnp.arange(a_tm1.shape[0]) |
NewerOlder