Skip to content

Instantly share code, notes, and snippets.

@grey-area
Created August 19, 2022 12:44
Show Gist options
  • Save grey-area/d040283ee7391e45ca0aaeca2350f254 to your computer and use it in GitHub Desktop.
Save grey-area/d040283ee7391e45ca0aaeca2350f254 to your computer and use it in GitHub Desktop.
import torch
from math import log
import matplotlib.pyplot as plt
def get_positional_encoding(cycle_limit, *, max_len=5000, d_model=256):
    """Build a sinusoidal positional-encoding table.

    Channel pairs ``(2i, 2i + 1)`` hold ``sin`` / ``cos`` of ``position * w_i``,
    where ``w_i = (cycle_limit / (2 * pi)) ** (-2i / d_model)``.  Channel 0/1
    therefore has angular frequency 1 (period ``2 * pi``).  When
    ``cycle_limit > 2 * pi`` the frequency falls geometrically towards the
    last channels.  The last channel's period is slightly shorter than
    ``cycle_limit`` and approaches it as ``d_model`` grows.  Compared with the
    usual fixed base of 10000, the base here is ``cycle_limit / (2 * pi)``.

    Args:
        cycle_limit: Positive number controlling the longest wavelength.
        max_len: Number of positions (rows) to encode. Defaults to 5000.
        d_model: Number of encoding channels. Defaults to 256. May be odd,
            in which case the final channel is a sine.

    Returns:
        Float tensor of shape ``(max_len, 1, d_model)``.

    Raises:
        ValueError: If ``cycle_limit`` is not positive.
    """
    if cycle_limit <= 0:
        raise ValueError(f"cycle_limit must be positive, got {cycle_limit!r}")
    position = torch.arange(max_len).unsqueeze(1)
    div_term = torch.exp(
        -torch.arange(0, d_model, 2) / d_model * log(cycle_limit / (2 * torch.pi))
    )
    pe = torch.zeros(max_len, 1, d_model)
    pe[:, 0, 0::2] = torch.sin(position * div_term)
    # An odd d_model has one fewer cosine slot than sine slot, so the cosine
    # half uses only the first d_model // 2 frequencies.
    pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
    return pe
def plot_lowest_frequency(ax, cycle_limit):
    """Plot the last channel of the encoding for ``cycle_limit`` onto ``ax``.

    With the default even channel count, the last channel is the cosine of the
    lowest frequency in the table (for ``cycle_limit > 2 * pi``).  The subplot
    is titled ``cycle_limit=<value>``.
    """
    encoding = get_positional_encoding(cycle_limit)
    last_channel = encoding[:, 0, -1]
    ax.plot(last_channel.numpy())
    ax.set_title(f'{cycle_limit=}')
if __name__ == "__main__":
    # Compare the lowest-frequency channel across several cycle limits,
    # one subplot row per value.
    cycle_limits = [250, 500, 1000, 2000]
    # Derive the row count from the list instead of hard-coding it.
    # squeeze=False keeps a 2-D Axes array even for a single row, so the
    # zip below works for any list length.
    _, axes = plt.subplots(len(cycle_limits), squeeze=False)
    for ax, cycle_limit in zip(axes[:, 0], cycle_limits):
        plot_lowest_frequency(ax, cycle_limit)
    plt.tight_layout()
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment