Skip to content

Instantly share code, notes, and snippets.

@dannys4
Created September 2, 2025 19:07
Show Gist options
  • Select an option

  • Save dannys4/817394000cb75eeef22aefcd2a5645c2 to your computer and use it in GitHub Desktop.

Select an option

Save dannys4/817394000cb75eeef22aefcd2a5645c2 to your computer and use it in GitHub Desktop.
Joker distribution logpdf and sampler
import numpy as np
from dataclasses import dataclass
from functools import partial
from typing import Callable, Any
# Public API of this module: the Joker log-density and its sampler.
# Python only consults `__all__` for `from ... import *`; the previous
# spelling `__any__` was an ordinary variable and had no effect.
__all__ = ['logpdf', 'sample']
def _gaussian_log_normalizer(inv_std: np.ndarray, dim: int) -> np.ndarray:
    """Return the log normalizing constant of a whitened ``dim``-D Gaussian.

    If ``z = (x - mean) @ inv_std`` is standard normal in ``dim`` dimensions,
    the density of ``x`` is ``N(z; 0, I) * |det(inv_std)|``, whose log
    constant is ``log|det(inv_std)| - dim/2 * log(2*pi)``.
    """
    # slogdet returns log|det| directly, so a negative determinant (e.g. a
    # reflection) does not produce NaN the way log(det) would.
    log_abs_det = np.linalg.slogdet(inv_std)[1]
    return log_abs_det - 0.5 * dim * np.log(2 * np.pi)


def joker_eye_logpdf(mean: np.ndarray, inv_std: np.ndarray, x: np.ndarray) -> np.ndarray:
    """Log-density of a Gaussian "eye" component.

    Parameters
    ----------
    mean : np.ndarray, shape (d,)
        Component location.
    inv_std : np.ndarray, shape (d, d)
        Matrix mapping ``x - mean`` to standard-normal coordinates.
    x : np.ndarray, shape (..., d)
        Evaluation points.

    Returns
    -------
    np.ndarray, shape ``x.shape[:-1]``
        Normalized Gaussian log-density at each point.
    """
    z = (x - mean) @ inv_std
    dim = np.shape(inv_std)[-1]
    return -0.5 * np.sum(np.square(z), axis=-1) + _gaussian_log_normalizer(inv_std, dim)


def joker_smile_logpdf(mean: np.ndarray, inv_std: np.ndarray, x: np.ndarray) -> np.ndarray:
    """Log-density of the banana-shaped "smile" component.

    In whitened coordinates ``z = (x - mean) @ inv_std``, the pair
    ``(z1, z2 - z1**2)`` is standard normal.  That bending map has unit
    Jacobian, so the normalizer is the same as for a 2-D Gaussian.

    Parameters
    ----------
    mean : np.ndarray, shape (2,)
    inv_std : np.ndarray, shape (2, 2)
    x : np.ndarray, shape (..., 2)

    Returns
    -------
    np.ndarray, shape ``x.shape[:-1]``
    """
    z = (x - mean) @ inv_std
    z1 = z[..., 0]
    z2 = z[..., 1] - z1**2
    return -0.5 * (z1**2 + z2**2) + _gaussian_log_normalizer(inv_std, 2)
@dataclass
class JokerData:
    """Parameters of the Joker density: two Gaussian eyes, a banana-shaped
    smile, and a global affine "face" map.

    NOTE: none of the attributes below carry a type annotation, so
    ``@dataclass`` generates no fields for them.  They are plain class
    attributes shared by every instance, and ``JokerData()`` takes no
    arguments.
    """

    # Right eye: mean and inverse standard deviations (stds 2 and 2).
    right_eye_shift = np.array([4., 10.])
    right_eye_inv_stds = np.diag(np.reciprocal([2.0, 2.0]))

    # Left eye: mean and inverse standard deviations (stds 1 and 3).
    left_eye_shift = np.array([-5., 10.])
    left_eye_inv_stds = np.diag(np.reciprocal([1.0, 3.0]))

    # Smile: location and inverse standard deviations (stds 5 and 3).
    smile_shift = np.array([0.0, -15.])
    smile_inv_stds = np.diag(np.reciprocal([5.0, 3.0]))

    # Face: offset plus the inverse of the 2x2 matrix the sampler applies to
    # component draws (y @ inv(face_inv_scale) + face_shift).
    face_shift = np.array([0., 0.25])
    face_inv_scale = np.linalg.inv(
        np.array([
            [0.197907, 0.000539511],
            [0.000539511, 0.0911001],
        ])
    )
def joker_logpdf_full(data: JokerData, x: np.ndarray) -> np.ndarray:
    """Log-density of the Joker distribution at ``x``.

    The density is the equal-weight (1/3 each) mixture of the two eye
    Gaussians and the smile, evaluated in face coordinates
    ``y = (x - face_shift) @ face_inv_scale``.  This is the inverse of the map
    ``y @ inv(face_inv_scale) + face_shift`` that the sampler applies.  The
    ``log(1/3)`` mixture weight and the log-Jacobian
    ``log|det(face_inv_scale)|`` of that change of variables are included, so
    the result is a normalized log-density whenever the component logpdfs
    are normalized.

    Parameters
    ----------
    data : JokerData
        Component and face parameters.
    x : np.ndarray, last axis of length 2
        Evaluation points; each point is evaluated independently of the
        others in the array.

    Returns
    -------
    np.ndarray, shape ``x.shape[:-1]``
        Log-density at each point.  NaN results (e.g. NaN inputs, or points
        where every component log-density is -inf) are replaced by -1000.
        Infinities become the largest finite float of matching sign, which is
        the ``np.nan_to_num`` default.
    """
    y = (x - data.face_shift) @ data.face_inv_scale
    component_evals = (
        joker_eye_logpdf(data.left_eye_shift, data.left_eye_inv_stds, y),
        joker_eye_logpdf(data.right_eye_shift, data.right_eye_inv_stds, y),
        joker_smile_logpdf(data.smile_shift, data.smile_inv_stds, y),
    )
    # Log-sum-exp stabilized separately for every point.  Subtracting the
    # per-point maximum keeps that point's largest term at exp(0) = 1.  A
    # batch-wide maximum would make distant points underflow to log(0).
    eval_shift = np.maximum(
        np.maximum(component_evals[0], component_evals[1]), component_evals[2]
    )
    mix_pdf = sum(np.exp(ev - eval_shift) for ev in component_evals)
    log_mix = np.log(mix_pdf) + eval_shift
    # Equal mixture weights plus the Jacobian of the face change of variables.
    log_const = -np.log(3.0) + np.linalg.slogdet(data.face_inv_scale)[1]
    return np.nan_to_num(log_mix + log_const, nan=-1000)
def sample_eye(std, shift, sample):
    """Map standard-normal draws onto a Gaussian eye component.

    Computes ``sample @ std + shift``.  When ``std`` is the inverse of
    ``inv_std``, this undoes the whitening ``(x - mean) @ inv_std`` used by
    ``joker_eye_logpdf``.
    """
    stretched = sample @ std
    return stretched + shift
def sample_mouth(std, shift, samples):
    """Map standard-normal draws onto the banana-shaped "smile" component.

    Each draw ``(z1, z2)`` is bent to ``(z1, z2 + z1**2)`` and then mapped by
    ``@ std + shift``.  When ``std`` is the inverse of ``inv_std``, this
    inverts the transform in ``joker_smile_logpdf``.

    Parameters
    ----------
    std : np.ndarray, shape (2, 2)
    shift : np.ndarray, shape (2,)
    samples : np.ndarray, shape (..., 2)
        Standard-normal draws.  Leading batch axes are preserved.

    Returns
    -------
    np.ndarray, shape ``samples.shape[:-1] + (2,)``
    """
    z1 = samples[..., 0]
    z2 = samples[..., 1]
    # Stacking on the last axis keeps every leading axis intact.
    # np.column_stack would turn a single (2,) draw into (1, 2) and would
    # concatenate batched (B, N) columns into a (B, 2N) array.
    bent = np.stack((z1, z2 + z1**2), axis=-1)
    return bent @ std + shift
def JokerSampler(data: JokerData) -> Callable[[Any, int], Any]:
    """Build a sampler for the Joker distribution described by ``data``.

    Returns ``sampler(rng, N_samples)``, which draws ``N_samples`` points:
    - Each draw picks one of the three components (left eye, right eye,
      smile) uniformly at random.
    - A standard-normal draw is pushed through that component.
    - The face map ``y @ face_scale + face_shift`` is applied, with
      ``face_scale = inv(face_inv_scale)``.

    The output has shape ``(N_samples, 2)``.  Row ``i`` comes from the
    component chosen for draw ``i``, so rows are not grouped by component and
    any slice of the output is itself an i.i.d. sample.
    """
    left_std = np.linalg.inv(data.left_eye_inv_stds)
    right_std = np.linalg.inv(data.right_eye_inv_stds)
    smile_std = np.linalg.inv(data.smile_inv_stds)
    face_scale = np.linalg.inv(data.face_inv_scale)
    # Component samplers, indexed by the mode labels drawn in `sampler`.
    components = (
        partial(sample_eye, left_std, data.left_eye_shift),
        partial(sample_eye, right_std, data.right_eye_shift),
        partial(sample_mouth, smile_std, data.smile_shift),
    )

    def sampler(rng: np.random.Generator, N_samples: int) -> np.ndarray:
        z_samples = rng.normal(size=(N_samples, 2))
        which_modes = rng.integers(len(components), size=N_samples)
        # Fill rows in place so row i keeps the component drawn for it.
        # Concatenating per-component batches would sort the output by mode.
        face_coords = np.empty((N_samples, 2))
        for mode, component in enumerate(components):
            mask = which_modes == mode
            face_coords[mask] = component(z_samples[mask])
        return face_coords @ face_scale + data.face_shift

    return sampler
# Module-level entry points bound to the default Joker parameters.
#   logpdf(x)      -> log-density at points x (last axis of length 2)
#   sample(rng, n) -> n draws returned as an (n, 2) array
logpdf: Callable[[np.ndarray], np.ndarray] = partial(joker_logpdf_full, JokerData())
sample: Callable[[np.random.Generator, int], np.ndarray] = JokerSampler(JokerData())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment