
@wassname
Last active January 14, 2024 02:09
two-hot encoding notes

What is two-hot encoding?

Description

Two-hot encoding was introduced in 2017 in "A distributional perspective on reinforcement learning" by Marc G. Bellemare et al. The clearest description, however, is in the 2023 DreamerV3 paper ("Mastering Diverse Domains through World Models" by Danijar Hafner et al.), where it is used for the reward and value distributions.

Two-hot encoding is a generalization of one-hot encoding to continuous values. Given a sorted set of bins B, it produces a vector of length |B| where all elements are 0 except for the two entries whose bins are closest to the encoded number: positions k and k + 1, the bins just below and just above it. These two entries sum to 1, with more weight given to the bin that is closer to the encoded number, so for a number inside the range of B the weighted sum of the bins recovers the number exactly.
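A minimal pure-Python sketch of the definition above, for a single scalar and a sorted list of bin centres. The names `twohot`, `decode` and `bins` are illustrative, not from any library. It assumes values outside the bin range are clamped onto the nearest edge bin, which is the same convention the tensor code sample below follows.

```python
from bisect import bisect_right


def twohot(x: float, bins: list[float]) -> list[float]:
    """Two-hot encode scalar x over sorted bin centres `bins` (illustrative helper)."""
    out = [0.0] * len(bins)
    # Assumption: out-of-range values are clamped onto the edge bins.
    if x <= bins[0]:
        out[0] = 1.0
        return out
    if x >= bins[-1]:
        out[-1] = 1.0
        return out
    # k is the index of the largest bin <= x, so bins[k] <= x < bins[k + 1].
    k = bisect_right(bins, x) - 1
    width = bins[k + 1] - bins[k]
    # Linear interpolation: each bin's weight is the distance to the *other* bin,
    # so the closer bin gets the larger weight and the two weights sum to 1.
    out[k] = (bins[k + 1] - x) / width
    out[k + 1] = (x - bins[k]) / width
    return out


def decode(weights: list[float], bins: list[float]) -> float:
    """Recover x as the expected value of the bins under the two-hot weights."""
    return sum(w * b for w, b in zip(weights, bins))


bins = [-1.0, 0.0, 1.0, 2.0]
w = twohot(0.25, bins)
print(w)                # [0.0, 0.75, 0.25, 0.0]
print(decode(w, bins))  # 0.25
```

A value that sits exactly on a bin centre, or outside the range of the bins, yields a single non-zero entry, so "two-hot" in practice means *at most* two hot entries.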

Code samples

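import torch


def calc_twohot(x, B):
    """
    x shape: (n_vals,) tensor of values to encode
    B shape: (n_bins,) tensor of bin values, sorted in ascending order
    returns a twohot tensor of shape (n_vals, n_bins)

    can verify this method is correct with:
     - torch.allclose(calc_twohot(x, B) @ B, x)  # expected value reconstructs x (for B[0] <= x <= B[-1])
     - ((calc_twohot(x, B) > 0).sum(dim=-1) <= 2).all()  # at most two bins are hot
    values outside [B[0], B[-1]] are clamped onto the edge bin, and a value lying
    exactly on a bin gets a single hot entry, hence "at most" two

    code from https://github.com/RyanNavillus/PPO-v3/blob/b81083a0f41e6b74245b1e130e32c044fd34cc3e/ppo_v3/ppo_envpool_tricks_dmc.py#L125
    """
    twohot = torch.zeros(x.shape + B.shape, dtype=x.dtype, device=x.device)
    # k1: index of the largest bin <= x (-1 if x is below every bin); k2: the next bin up
    k1 = (B[None, :] <= x[:, None]).sum(dim=-1) - 1
    k2 = k1 + 1
    # Clamp both indices into range; out-of-range x then gives k1 == k2 (an edge bin)
    k1 = torch.clip(k1, 0, len(B) - 1)
    k2 = torch.clip(k2, 0, len(B) - 1)

    # Handle k1 == k2 case: all of the weight goes to that single bin
    equal = (k1 == k2)
    dist_to_below = torch.where(equal, torch.ones_like(x), torch.abs(B[k1] - x))
    dist_to_above = torch.where(equal, torch.zeros_like(x), torch.abs(B[k2] - x))

    # Assign values to two-hot tensor; each bin's weight is the distance to the *other* bin
    total = dist_to_above + dist_to_below
    weight_below = dist_to_above / total
    weight_above = dist_to_below / total
    x_range = torch.arange(len(x), device=x.device)
    twohot[x_range, k1] = weight_below   # assign left
    twohot[x_range, k2] = weight_above   # assign right (overwrites left when k1 == k2)
    return twohot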
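In DreamerV3 the two-hot vector serves as a training target: the network outputs logits over the bins, is trained with cross-entropy against the two-hot encoding of the observed reward or return, and its scalar prediction is read out as the expected bin value. The sketch below illustrates that loop for one scalar in plain Python. The helper names (`twohot`, `softmax`, `twohot_cross_entropy`, `predict`) are hypothetical, the bins are evenly spaced for simplicity, and the paper's other details (such as how its bins are spaced) are deliberately left out.

```python
import math
from bisect import bisect_right


def twohot(x, bins):
    """Hypothetical helper: clamped linear-interpolation two-hot encoding."""
    out = [0.0] * len(bins)
    if x <= bins[0]:
        out[0] = 1.0
    elif x >= bins[-1]:
        out[-1] = 1.0
    else:
        k = bisect_right(bins, x) - 1
        width = bins[k + 1] - bins[k]
        out[k] = (bins[k + 1] - x) / width
        out[k + 1] = (x - bins[k]) / width
    return out


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]


def twohot_cross_entropy(logits, target, bins):
    """Cross-entropy between the predicted bin distribution and twohot(target)."""
    probs = softmax(logits)
    return -sum(t * math.log(p) for t, p in zip(twohot(target, bins), probs) if t > 0)


def predict(logits, bins):
    """Read the scalar prediction out as the expected bin value."""
    return sum(p * b for p, b in zip(softmax(logits), bins))


bins = [-2.0, -1.0, 0.0, 1.0, 2.0]
logits = [0.0] * len(bins)                      # start from a uniform prediction
print(predict(logits, bins))                    # approximately 0.0
print(twohot_cross_entropy(logits, 0.5, bins))  # approximately log(5), about 1.609

# The gradient of this loss w.r.t. the logits is softmax(logits) - twohot(target),
# so plain gradient-descent steps pull the expected value towards the target.
target = 0.5
for _ in range(200):
    grad = [p - t for p, t in zip(softmax(logits), twohot(target, bins))]
    logits = [z - 1.0 * g for z, g in zip(logits, grad)]
print(predict(logits, bins))                    # close to 0.5
```

Because the loss is a cross-entropy over bins rather than a squared error on the raw scalar, the size of its gradient does not grow with the magnitude of the target, which is one motivation for using a two-hot target for rewards and values.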

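References:

- Marc G. Bellemare, Will Dabney, Rémi Munos. "A Distributional Perspective on Reinforcement Learning." ICML 2017.
- Danijar Hafner, Jurgis Pasukonis, Jimmy Ba, Timothy Lillicrap. "Mastering Diverse Domains through World Models" (DreamerV3). arXiv:2301.04104, 2023.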

Synonyms

two hot encoding, 2hot encoding, TwoHotEncoding
