Gist by @adamconkey, last active October 19, 2020 18:28
See all of the possible distribution combinations that can be used in PyTorch's kl_divergence function.
from torch.distributions.kl import _KL_REGISTRY


def view_kl_options():
    """
    Displays all combinations of distributions that can be used in
    torch's kl_divergence function. Iterates through the registry
    and prints the registered (p, q) class-name pairs, one per line.
    """
    # _KL_REGISTRY is a private dict keyed by (type_p, type_q) class pairs.
    names = [(p.__name__, q.__name__) for p, q in _KL_REGISTRY.keys()]
    # Width of the longest first-argument name, used to right-align that column.
    max_name_len = max(len(p) for p, _ in names)
    for arg1, arg2 in sorted(names):
        print(f" {arg1:>{max_name_len}} || {arg2}")


if __name__ == "__main__":
    view_kl_options()