Skip to content

Instantly share code, notes, and snippets.

Created April 6, 2023 21:46
Show Gist options
  • Save mattf1n/514eaeb27cbce1ed038bc83e00c81f07 to your computer and use it in GitHub Desktop.
Save mattf1n/514eaeb27cbce1ed038bc83e00c81f07 to your computer and use it in GitHub Desktop.
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib import animation
from simplex import Composition
import operator
# Number of simplex components (4 parts -> a tetrahedron in 3-D) and the
# number of random compositions whose scaling trajectories are drawn.
dim = 4
n = 20
# Random logits drawn uniformly from [1, 1000), stored as float32.
logits = torch.tensor(
    np.random.uniform(low=1, high=1000, size=(n, dim)), dtype=torch.float32
)
# One Composition per row of `logits`.
comps = list(map(Composition.from_logits, logits))
def p_to_xyz(p):
    """Map a point of the 4-part probability simplex to 3-D coordinates.

    The four simplex vertices are sent to the corners of a regular
    tetrahedron with edge length 2 whose centroid is the origin, so the
    entries of ``p`` act as barycentric weights of those corners.

    Args:
        p: Indexable with entries ``p[0]`` .. ``p[3]`` (e.g. a 1-D NumPy
            array of probabilities).

    Returns:
        Tuple ``(x, y, z)``.
    """
    s3 = 1 / np.sqrt(3.0)
    s6 = 1 / np.sqrt(6.0)
    # Base triangle (vertices 0, 1, 2) in the xy-plane: side 2, centred at 0.
    x = -p[0] + p[1]
    y = -s3 * p[0] - s3 * p[1] + 2 * s3 * p[2]
    # Base sits at z = -1/sqrt(6) and the apex (vertex 3) at z = 3/sqrt(6),
    # which makes all six edges length 2 and puts the centroid at the origin.
    # (Using -1/sqrt(3) for the base distorted the tetrahedron.)
    z = -s6 * p[0] - s6 * p[1] - s6 * p[2] + 3 * s6 * p[3]
    return x, y, z
fig = plt.figure()
ax = fig.add_subplot(projection="3d", xlim=(-1, 1), ylim=(-1, 1), zlim=(-1, 1))
# Scaling factors shared by every trajectory (hoisted out of the loop).
# a = 0 zeroes the logits, i.e. the uniform composition; a = 1 reproduces
# the composition itself.
scalars = np.linspace(0, 2, 2000)
for comp in comps:
    # float(a): pass a plain Python float so `float * Composition` dispatches
    # straight to Composition.__rmul__ instead of going through NumPy's own
    # scalar multiplication first.
    scaled_comps = (
        p_to_xyz((float(a) * comp).probs().numpy()) for a in scalars
    )
    x, y, z = zip(*scaled_comps)
    ax.plot(x, y, z)
def animate(i):
    """Per-frame callback: set the camera azimuth to ``i`` at elevation 10."""
    ax.view_init(azim=i, elev=10.0)
# One frame per degree of azimuth (frames 0..359) for a full turn of the view.
# blit is not requested, so init_func's return value is not used for redraws.
anim = animation.FuncAnimation(
    fig, animate, init_func=lambda: (fig,), frames=360, interval=50
)
# Pillow ships as a matplotlib dependency, so the GIF can be written
# without requiring ffmpeg to be installed.
anim.save("simplex.gif", writer="pillow", fps=30)
import torch
class Composition:
    """A point on the probability simplex, stored as log-probabilities.

    Arithmetic follows Aitchison (compositional-data) geometry:
    ``+`` / ``-`` add or subtract log-probabilities and renormalise,
    ``scalar * comp`` and ``comp / scalar`` scale them, and ``@`` is the
    inner product of the centred log-ratio (clr) vectors. The uniform
    composition (see ``zero``) plays the role of the zero vector.
    """

    def __init__(self, logprobs: torch.Tensor):
        """Wrap ``logprobs``, whose exponentials must sum to 1.

        Raises:
            AssertionError: if ``logprobs.exp().sum()`` is not close to 1
                (relative tolerance 1e-4).
        """
        total = logprobs.exp().sum()
        # Compare against a 1 of the same dtype/device: torch.isclose rejects
        # mismatched dtypes, so a float32 constant broke float64 inputs.
        assert torch.isclose(total, torch.ones_like(total), rtol=1e-04), (
            f"logprobs must exponentiate to a distribution summing to 1, got {total}"
        )
        self.logprobs = logprobs

    def __len__(self) -> int:
        """Number of parts in the composition."""
        return len(self.logprobs)

    def __rmul__(self, scalar: float) -> "Composition":
        """``scalar * comp``: scale the log-probabilities, then renormalise."""
        return self.from_logits(scalar * self.logprobs)

    def __truediv__(self, scalar: float) -> "Composition":
        """``comp / scalar``: divide the log-probabilities, then renormalise."""
        return self.from_logits(self.logprobs / scalar)

    def __neg__(self) -> "Composition":
        """Additive inverse: negate the log-probabilities and renormalise."""
        return self.from_logits(-self.logprobs)

    def __add__(self, other: "Composition") -> "Composition":
        """Sum log-probabilities of both compositions and renormalise."""
        return self.from_logits(self.logprobs + other.logprobs)

    def __sub__(self, other: "Composition") -> "Composition":
        """Subtract ``other``'s log-probabilities and renormalise."""
        return self.from_logits(self.logprobs - other.logprobs)

    def __matmul__(self, other: "Composition") -> float:
        """Inner product of the two compositions' clr vectors (a 0-d tensor)."""
        return self.clr() @ other.clr()

    def __repr__(self) -> str:
        return f"Composition({self.logprobs.exp()})"

    def norm(self) -> float:
        """Length of the clr vector, ``sqrt(self @ self)`` (a 0-d tensor)."""
        return torch.sqrt(self @ self)

    @staticmethod
    def zero(dim: int) -> "Composition":
        """Return the uniform composition over ``dim`` parts (the zero vector)."""
        return Composition.from_logits(torch.ones(dim))

    def unit(self) -> "Composition":
        """Return this composition rescaled to clr norm 1.

        Raises:
            ValueError: if the norm is exactly zero.
        """
        norm = self.norm()
        if norm == 0:
            raise ValueError("The 0 vector has no direction.")
        return self / norm

    def clr(self) -> torch.Tensor:
        """Centred log-ratio: the log-probabilities minus their mean."""
        return self.logprobs - self.logprobs.mean()

    @staticmethod
    def from_logits(logits) -> "Composition":
        """Build a composition from unnormalised logits (log-softmax on dim 0)."""
        # Static so it works both as Composition.from_logits(x) and as
        # self.from_logits(x) inside the operators; without the decorator the
        # instance call passed `self` as `logits` and raised TypeError.
        return Composition(logits - torch.logsumexp(logits, 0))

    def angle(self, other):
        """Angle in radians between the clr vectors of ``self`` and ``other``."""
        # Clip keeps rounding error from pushing the cosine outside [-1, 1].
        return torch.arccos(torch.clip(self.unit() @ other.unit(), -1.0, 1.0))

    def probs(self) -> torch.Tensor:
        """Probabilities, i.e. ``exp(logprobs)``."""
        return self.logprobs.exp()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment