@mattf1n · Created April 6, 2023 21:46
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib import animation

from simplex import Composition

dim = 4
n = 20

# Draw n random logit vectors and turn each one into a composition on the simplex.
logits = torch.tensor(
    np.random.uniform(low=1, high=1000, size=(n, dim)), dtype=torch.float32
)
comps = list(map(Composition.from_logits, logits))


def p_to_xyz(p):
    # Embed a point p on the 3-simplex in 3D, sending the four vertices of the
    # simplex to the corners of a regular tetrahedron centred at the origin.
    s3 = 1 / np.sqrt(3.0)
    s6 = 1 / np.sqrt(6.0)
    x = -1 * p[0] + 1 * p[1] + 0 * p[2] + 0 * p[3]
    y = -s3 * p[0] - s3 * p[1] + 2 * s3 * p[2] + 0 * p[3]
    z = -s6 * p[0] - s6 * p[1] - s6 * p[2] + 3 * s6 * p[3]
    return x, y, z


fig = plt.figure()
ax = fig.add_subplot(projection="3d", xlim=(-1, 1), ylim=(-1, 1), zlim=(-1, 1))

# For each composition, plot the curve traced by its scalar multiples a * comp
# for a in [0, 2] (powering in Aitchison geometry).
for comp in comps:
    scalars = np.linspace(0, 2, 2000)
    scaled_comps = (p_to_xyz((a * comp).probs().numpy()) for a in scalars)
    x, y, z = zip(*scaled_comps)
    ax.plot(x, y, z)

ax.view_init()
ax.set_axis_off()


def animate(i):
    # Rotate the camera one degree per frame.
    ax.view_init(elev=10.0, azim=i)


anim = animation.FuncAnimation(
    fig, animate, init_func=lambda: (fig,), frames=360, interval=50
)
anim.save("simplex.gif", fps=30)
plt.show()
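
As a quick sanity check, not in the original gist, one can push the four one-hot probability vectors through p_to_xyz and confirm that they land on the corners of the tetrahedron, with the uniform composition at its centre:

for vertex in np.eye(dim):
    print(p_to_xyz(vertex))             # corners of the tetrahedron
print(p_to_xyz(np.full(dim, 1 / dim)))  # uniform composition, at the centre
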
simplex.py

import torch


class Composition:
    """A point on the probability simplex, stored as log-probabilities.

    The operator overloads implement Aitchison geometry: + is perturbation,
    scalar * is powering, and @ is the Aitchison inner product, computed via
    the centered log-ratio (clr) transform.
    """

    def __init__(self, logprobs: torch.Tensor):
        assert torch.isclose(logprobs.exp().sum(), torch.tensor(1.0), rtol=1e-04)
        self.logprobs = logprobs

    def __len__(self):
        return len(self.logprobs)

    def __rmul__(self, scalar: float) -> "Composition":
        return self.from_logits(scalar * self.logprobs)

    def __truediv__(self, scalar: float) -> "Composition":
        return self.from_logits(self.logprobs / scalar)

    def __neg__(self) -> "Composition":
        return self.from_logits(-self.logprobs)

    def __add__(self, other: "Composition") -> "Composition":
        return self.from_logits(self.logprobs + other.logprobs)

    def __sub__(self, other: "Composition") -> "Composition":
        return self.from_logits(self.logprobs - other.logprobs)

    def __matmul__(self, other: "Composition") -> float:
        return self.clr() @ other.clr()

    def __repr__(self) -> str:
        return f"Composition({self.logprobs.exp()})"

    def norm(self) -> float:
        return torch.sqrt(self @ self)

    @staticmethod
    def zero(dim: int) -> "Composition":
        # The uniform composition is the additive identity of Aitchison geometry.
        return Composition.from_logits(torch.ones(dim))

    def unit(self) -> "Composition":
        if self.norm() == 0:
            raise ValueError("The 0 vector has no direction.")
        else:
            return self / self.norm()

    def clr(self) -> torch.Tensor:
        # Centered log-ratio transform: log-probabilities minus their mean.
        return self.logprobs - self.logprobs.mean()

    @staticmethod
    def from_logits(logits) -> "Composition":
        # Normalize arbitrary logits into log-probabilities.
        return Composition(logits - torch.logsumexp(logits, 0))

    def angle(self, other):
        return torch.arccos(torch.clip(self.unit() @ other.unit(), -1.0, 1.0))

    def probs(self) -> torch.Tensor:
        return self.logprobs.exp()
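
A minimal usage sketch, not part of the gist, showing the Aitchison operations that the overloads above provide (assuming simplex.py is importable):

import torch
from simplex import Composition

p = Composition.from_logits(torch.tensor([1.0, 2.0, 3.0, 4.0]))
q = Composition.from_logits(torch.tensor([4.0, 3.0, 2.0, 1.0]))

print(p + q)                 # perturbation: probabilities multiply, then renormalize
print(0.5 * p)               # powering: probabilities raised to 0.5, then renormalized
print(p @ q)                 # Aitchison inner product of the clr vectors
print(p.norm(), p.angle(q))  # Aitchison norm and the angle between p and q
print(Composition.zero(4))   # the uniform composition, the additive identity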