Skip to content

Instantly share code, notes, and snippets.

@mattf1n
Created April 6, 2023 21:46
Show Gist options
  • Save mattf1n/514eaeb27cbce1ed038bc83e00c81f07 to your computer and use it in GitHub Desktop.
Save mattf1n/514eaeb27cbce1ed038bc83e00c81f07 to your computer and use it in GitHub Desktop.
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib import animation
from simplex import Composition
import operator
dim = 4
n = 20
# Draw one row of `dim` logits per composition. Values come from [1, 1000), so
# gaps between entries are usually large and the resulting distributions are
# close to one-hot. NumPy samples float64; they are cast to float32 here.
raw_logits = np.random.uniform(low=1, high=1000, size=(n, dim))
logits = torch.tensor(raw_logits, dtype=torch.float32)
# Iterating a 2-D tensor yields its rows; each row becomes one Composition.
comps = [Composition.from_logits(row) for row in logits]
def p_to_xyz(p):
    """Project a point of the 4-part simplex into 3-D Cartesian coordinates.

    The four components act as barycentric weights on the vertices of a
    regular tetrahedron with edge length 2 whose centroid is at the origin:

    * ``(1, 0, 0, 0)`` -> ``(-1, -1/sqrt(3), -1/sqrt(6))``
    * ``(0, 1, 0, 0)`` -> ``( 1, -1/sqrt(3), -1/sqrt(6))``
    * ``(0, 0, 1, 0)`` -> ``( 0,  2/sqrt(3), -1/sqrt(6))``
    * ``(0, 0, 0, 1)`` -> ``( 0,  0,          3/sqrt(6))``

    so the uniform distribution maps to ``(0, 0, 0)``.

    Args:
        p: A length-4 indexable of probabilities (e.g. a NumPy array or list).

    Returns:
        The ``(x, y, z)`` coordinates of the projected point.
    """
    s3 = 1 / np.sqrt(3.0)
    s6 = 1 / np.sqrt(6.0)
    x = p[1] - p[0]
    y = s3 * (2 * p[2] - p[0] - p[1])
    # The base triangle must sit at z = -1/sqrt(6) (not -1/sqrt(3)) for the
    # apex at z = 3/sqrt(6) to give equal edge lengths and a centred body.
    z = s6 * (3 * p[3] - p[0] - p[1] - p[2])
    return x, y, z
fig = plt.figure()
ax = fig.add_subplot(projection="3d", xlim=(-1, 1), ylim=(-1, 1), zlim=(-1, 1))
# The same 2000 scale factors in [0, 2] are used for every composition, so
# build them once rather than once per trajectory. Converting to Python floats
# lets `a * comp` dispatch straight to Composition.__rmul__ instead of first
# going through NumPy's scalar operator handling.
scalars = np.linspace(0, 2, 2000).tolist()
for comp in comps:
    # Trace {a * comp : a in [0, 2]}: it starts at the uniform composition
    # (a = 0) and passes through `comp` itself at a = 1.
    scaled_comps = (p_to_xyz(operator.mul(a, comp).probs().numpy()) for a in scalars)
    x, y, z = zip(*scaled_comps)
    ax.plot(x, y, z)
ax.view_init()
ax.set_axis_off()
def animate(i):
    """Frame callback for FuncAnimation: orbit the camera around the plot.

    Args:
        i: Frame index, used directly as the camera azimuth in degrees.
    """
    elevation = 10.0
    ax.view_init(azim=i, elev=elevation)
# Render 360 frames, one per azimuth step set by `animate`. `interval` (ms)
# paces the on-screen animation; `fps` sets the frame rate of the saved GIF.
anim = animation.FuncAnimation(
    fig,
    animate,
    frames=360,
    interval=50,
    init_func=lambda: (fig,),
)
anim.save("simplex.gif", fps=30)
plt.show()
import torch
class Composition:
    """A point on the probability simplex, stored as log-probabilities.

    Operators follow Aitchison geometry:

    * ``a + b`` / ``a - b`` multiply / divide probabilities component-wise,
      then renormalize.
    * ``s * a`` / ``a / s`` raise probabilities to the power ``s`` / ``1/s``,
      then renormalize.
    * ``-a`` inverts each probability, then renormalizes.
    * ``a @ b`` is the dot product of the centred log-ratio (clr) coordinates.
    """

    def __init__(self, logprobs: torch.Tensor):
        """Wrap already-normalized log-probabilities.

        Args:
            logprobs: Log-probabilities whose exponentials sum to 1 (relative
                tolerance 1e-4). Use :meth:`from_logits` for unnormalized input.

        Raises:
            AssertionError: If the probabilities do not sum to 1. This check is
                an ``assert`` and is therefore skipped under ``python -O``.
        """
        total = logprobs.exp().sum()
        # Compare against a 1 of matching dtype/device: torch.isclose rejects
        # mismatched dtypes, so a hard-coded float32 ``torch.tensor(1.0)`` made
        # float64 (or float16) inputs fail with a RuntimeError.
        assert torch.isclose(total, torch.ones_like(total), rtol=1e-04), (
            f"probabilities must sum to 1, got {float(total)}"
        )
        self.logprobs = logprobs

    def __len__(self):
        """Return the number of parts in the composition."""
        return len(self.logprobs)

    def __rmul__(self, scalar: float) -> "Composition":
        """Power the composition: ``scalar * self`` raises probabilities to ``scalar``."""
        return self.from_logits(scalar * self.logprobs)

    def __truediv__(self, scalar: float) -> "Composition":
        """Power the composition by ``1 / scalar``."""
        return self.from_logits(self.logprobs / scalar)

    def __neg__(self) -> "Composition":
        """Return the inverse composition (reciprocal probabilities, renormalized)."""
        return self.from_logits(-self.logprobs)

    def __add__(self, other: "Composition") -> "Composition":
        """Perturb: multiply probabilities component-wise and renormalize."""
        return self.from_logits(self.logprobs + other.logprobs)

    def __sub__(self, other: "Composition") -> "Composition":
        """Divide probabilities component-wise by ``other``'s and renormalize."""
        return self.from_logits(self.logprobs - other.logprobs)

    def __matmul__(self, other: "Composition") -> torch.Tensor:
        """Return the inner product of the clr coordinates as a 0-d tensor."""
        return self.clr() @ other.clr()

    def __repr__(self) -> str:
        return f"Composition({self.logprobs.exp()})"

    def norm(self) -> torch.Tensor:
        """Return the length ``sqrt(self @ self)`` as a 0-d tensor."""
        return torch.sqrt(self @ self)

    @staticmethod
    def zero(dim: int) -> "Composition":
        """Return the uniform composition with ``dim`` parts (the identity for ``+``)."""
        return Composition.from_logits(torch.ones(dim))

    def unit(self) -> "Composition":
        """Return this composition scaled to unit norm.

        Raises:
            ValueError: If the norm is exactly zero (the uniform composition).
        """
        length = self.norm()  # computed once; used for both the check and the scale
        if length == 0:
            raise ValueError("The 0 vector has no direction.")
        return self / length

    def clr(self) -> torch.Tensor:
        """Return the centred log-ratio: log-probabilities minus their mean."""
        return self.logprobs - self.logprobs.mean()

    @staticmethod
    def from_logits(logits: torch.Tensor) -> "Composition":
        """Build a composition from unnormalized logits via log-softmax along dim 0."""
        return Composition(logits - torch.logsumexp(logits, 0))

    def angle(self, other: "Composition") -> torch.Tensor:
        """Return the angle in radians between two compositions' clr vectors.

        The cosine is clipped to [-1, 1] so rounding cannot push it outside
        the domain of arccos.
        """
        return torch.arccos(torch.clip(self.unit() @ other.unit(), -1.0, 1.0))

    def probs(self) -> torch.Tensor:
        """Return the probabilities (exponentiated log-probabilities)."""
        return self.logprobs.exp()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment