Skip to content

Instantly share code, notes, and snippets.

@kgourgou
Last active January 11, 2024 14:49
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kgourgou/875a5fcdada7b44fd66fbd4c3929ce38 to your computer and use it in GitHub Desktop.
Save kgourgou/875a5fcdada7b44fd66fbd4c3929ce38 to your computer and use it in GitHub Desktop.
import jax
import jax.numpy as jnp
def ent(X):
    """Estimate the entropy of the samples in ``X`` (up to an additive constant).

    Uses the Kozachenko-Leonenko nearest-neighbour form
    ``mean_i log((N - 1) * r_i**D)``, where ``r_i`` is the Euclidean distance
    from sample ``i`` to its nearest other sample. Terms that do not depend on
    the sample positions are dropped.

    Args:
        X: array of shape ``(N, D)`` — ``N`` samples in ``D`` dimensions.
            Needs ``N >= 2``; with a single sample there is no neighbour and
            the result is NaN.

    Returns:
        Scalar JAX array with the estimate. Duplicate samples give
        ``r_i = 0`` and therefore ``-inf``.
    """
    N, D = X.shape
    # Pairwise squared Euclidean distances, shape (N, N).
    diffs = X[:, jnp.newaxis, :] - X[jnp.newaxis, :, :]
    dist_sq = jnp.sum(diffs ** 2, axis=-1)
    # Exclude each point from its own nearest-neighbour search. JAX arrays are
    # immutable, so mask the diagonal with `where` instead of writing into it.
    dist_sq = jnp.where(jnp.eye(N, dtype=bool), jnp.inf, dist_sq)
    min_dist_sq = jnp.min(dist_sq, axis=1)
    # log((N - 1) * r**D) == log(N - 1) + (D / 2) * log(r**2). Working in log
    # space avoids r**D under/overflowing (e.g. 0.1**400 is 0 in float32).
    return jnp.mean(jnp.log(N - 1) + 0.5 * D * jnp.log(min_dist_sq))
@jax.jit
def functional(X):
    """Compute the sample objective ``-H(X) + mean(-log_density(X))``.

    The first term negates the entropy estimate returned by ``ent``. The second
    term is the sample average of ``-log_density(X)``. If ``log_density`` is the
    target's log-density, the sum is KL(samples || target) up to an additive
    constant.

    NOTE(review): ``log_density`` is not defined in this snippet; presumably it
    maps ``X`` to per-sample log-densities — confirm against its definition.
    """
    entropy_estimate = ent(X)
    cross_term = jnp.mean(-log_density(X))
    return cross_term - entropy_estimate
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment