Two-hot encoding was introduced in 2017 by Marc G. Bellemare et al. in "A distributional perspective on reinforcement learning", but the clearest description is in the DreamerV3 paper ("Mastering Diverse Domains through World Models", Danijar Hafner et al., 2023), where it is used for reward and value distributions.
Two-hot encoding is a generalization of one-hot encoding to continuous values. It produces a vector of length |B| (the number of bins) where all elements are 0 except for the two entries whose bins are closest to the encoded continuous number, at positions k and k + 1. These two entries sum to 1, with more weight given to the entry whose bin is closer to the encoded number. For example, with bins B = [0, 1, 2, 3], the value 1.3 encodes to [0, 0.7, 0.3, 0], and the expected value 0.7·1 + 0.3·2 recovers 1.3.
```python
import numpy as np
import torch


def calc_twohot(x, B):
    """
    x shape: (n_vals,)  tensor of values
    B shape: (n_bins,)  tensor of bin values
    Returns a two-hot tensor of shape (n_vals, n_bins).

    Can verify this method is correct with:
    - calc_twohot(x, B) @ B == x                    # expected value reconstructs x
    - (calc_twohot(x, B) > 0).sum(dim=-1) <= 2      # at most two bins are hot
    Code from https://github.com/RyanNavillus/PPO-v3/blob/b81083a0f41e6b74245b1e130e32c044fd34cc3e/ppo_v3/ppo_envpool_tricks_dmc.py#L125
    """
    twohot = torch.zeros(x.shape + B.shape, dtype=x.dtype, device=x.device)
    # Index of the closest bin at or below each value.
    k1 = (B[None, :] <= x[:, None]).sum(dim=-1) - 1
    k2 = k1 + 1
    k1 = torch.clip(k1, 0, len(B) - 1)
    k2 = torch.clip(k2, 0, len(B) - 1)
    # Handle the k1 == k2 case (x outside the bin range): all weight on one bin.
    equal = k1 == k2
    dist_to_below = torch.where(equal, 1, torch.abs(B[k1] - x))
    dist_to_above = torch.where(equal, 0, torch.abs(B[k2] - x))
    # Weight each bin by its distance to the *other* bin, normalized to sum to 1.
    total = dist_to_above + dist_to_below
    weight_below = dist_to_above / total
    weight_above = dist_to_below / total
    x_range = np.arange(len(x))
    twohot[x_range, k1] = weight_below  # assign left
    twohot[x_range, k2] = weight_above  # assign right
    return twohot
```
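As a quick sanity check of the two properties above, here is a compact, self-contained sketch using `torch.bucketize` (the helper name `twohot` and the bin choice are my own, not from the linked code):

```python
import torch


def twohot(x, B):
    # Compact two-hot sketch; assumes B is 1-D and sorted ascending.
    # bucketize with right=True counts bins <= x, so subtracting 1 gives
    # the bin at or below x; clamp keeps k1+1 in range.
    k1 = torch.clamp(torch.bucketize(x, B, right=True) - 1, 0, len(B) - 2)
    k2 = k1 + 1
    # Fractional position of x between its two neighboring bins.
    w = ((x - B[k1]) / (B[k2] - B[k1])).clamp(0, 1)
    out = torch.zeros(x.shape + B.shape, dtype=x.dtype)
    rows = torch.arange(len(x))
    out[rows, k1] = 1 - w
    out[rows, k2] = w
    return out


B = torch.linspace(-5, 5, 11)       # bins -5, -4, ..., 5
x = torch.tensor([-2.7, 0.0, 3.2])
T = twohot(x, B)
print(T @ B)                        # expected value reconstructs x
print((T > 0).sum(dim=-1))          # at most two bins are hot per row
```

Note that when x lands exactly on a bin (as with 0.0 above), only one entry is nonzero, so the hot count is "at most two" rather than exactly two.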
References:
- https://github.com/RyanNavillus/PPO-v3/blob/b81083a0f41e6b74245b1e130e32c044fd34cc3e/ppo_v3/ppo_envpool_tricks_dmc.py#L125
- https://github.dev/Eclectic-Sheep/sheeprl/blob/52f49be5971c5753e18bdf328d3035334fe688f1/sheeprl/utils/distribution.py#L224
- https://github.com/DuaneNielsen/dreamerv3/blob/72f86b633334dc39b75376ea7e26e79536072279/dists.py#L111
- https://github.com/google-deepmind/rlax/blob/df8e6006365ed3cba366747a00f0d1fd25a406e7/rlax/_src/transforms.py#L92
two hot encoding 2hot encoding TwoHotEncoding