Created
April 6, 2023 21:46
-
-
Save mattf1n/514eaeb27cbce1ed038bc83e00c81f07 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt | |
import numpy as np | |
import torch | |
from matplotlib import animation | |
from simplex import Composition | |
import operator | |
# Simplex dimension (parts per composition) and number of curves to draw.
dim = 4
n = 20

# One row of random logits per composition, drawn uniformly from [1, 1000)
# with NumPy's global RNG, then cast to a float32 tensor.
logits = torch.from_numpy(
    np.random.uniform(low=1, high=1000, size=(n, dim))
).to(torch.float32)

# Normalise every row of logits into a Composition.
comps = [Composition.from_logits(row) for row in logits]
def p_to_xyz(p):
    """Map a 4-part composition to 3-D coordinates inside a regular tetrahedron.

    The unit vectors e0..e3 are sent to

        (-1, -1/sqrt(3), -1/sqrt(6)),  (1, -1/sqrt(3), -1/sqrt(6)),
        (0,  2/sqrt(3), -1/sqrt(6)),   (0,  0,          3/sqrt(6)),

    the vertices of a regular tetrahedron with edge length 2 whose centroid is
    the origin. The map is linear, so a probability vector (non-negative,
    summing to 1) lands on the matching convex combination of those vertices.

    Args:
        p: indexable with four entries ``p[0]..p[3]`` (list, NumPy array, ...).

    Returns:
        Tuple ``(x, y, z)`` of coordinates.
    """
    s3 = 1 / np.sqrt(3.0)
    s6 = 1 / np.sqrt(6.0)
    x = -p[0] + p[1]
    y = -s3 * p[0] - s3 * p[1] + 2 * s3 * p[2]
    # All three base vertices share z = -1/sqrt(6). Using -1/sqrt(3) here
    # (as before) pushed the base down, stretching the tetrahedron vertically
    # and moving its centroid off the origin.
    z = -s6 * (p[0] + p[1] + p[2]) + 3 * s6 * p[3]
    return x, y, z
fig = plt.figure()
ax = fig.add_subplot(projection="3d", xlim=(-1, 1), ylim=(-1, 1), zlim=(-1, 1))

# Scaling factors applied to every composition. The grid is identical for
# every curve, so build it once rather than once per composition.
scalars = np.linspace(0, 2, 2000)
for comp in comps:
    # Trace the curve a -> a * comp for a in [0, 2]. float(a) turns the NumPy
    # scalar into a plain Python float so the product dispatches directly to
    # Composition.__rmul__, instead of relying on how a NumPy scalar handles
    # an unknown right-hand operand.
    points = [p_to_xyz((float(a) * comp).probs().numpy()) for a in scalars]
    x, y, z = zip(*points)
    ax.plot(x, y, z)

# Restore the initial view angles; animate() overrides them on every frame.
ax.view_init()
ax.set_axis_off()
def animate(i):
    """FuncAnimation frame callback: orbit the camera around the plot.

    The frame index ``i`` is used directly as the azimuth in degrees, while
    the elevation is held at 10 degrees. Returns None.
    """
    elevation = 10.0
    ax.view_init(azim=i, elev=elevation)
# Frame rate shared by the interactive preview and the saved GIF, so both
# rotate at the same speed. Previously the preview used 50 ms per frame
# (20 fps) while the GIF was written at 30 fps. Timer intervals are whole
# milliseconds, hence the integer division.
FPS = 30
anim = animation.FuncAnimation(
    fig, animate, init_func=lambda: (fig,), frames=360, interval=1000 // FPS
)
# Name the writer explicitly. Otherwise matplotlib picks the one configured
# in rcParams["animation.writer"], which may not be installed. Pillow ships
# as a matplotlib dependency.
anim.save("simplex.gif", writer="pillow", fps=FPS)
plt.show()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
class Composition:
    """A point of the probability simplex, stored as log-probabilities.

    Operators follow the Aitchison geometry of the simplex:

    * ``+`` / ``-`` add or subtract log-probabilities, then renormalise;
    * ``scalar * comp``, ``comp * scalar`` and ``comp / scalar`` scale the
      log-probabilities, then renormalise;
    * ``-comp`` negates the log-probabilities, then renormalises;
    * ``a @ b`` is the dot product of the centred log-ratio (clr) vectors.
    """

    def __init__(self, logprobs: torch.Tensor):
        """Wrap already-normalised log-probabilities.

        Args:
            logprobs: tensor whose exponentials must sum to 1 (relative
                tolerance 1e-4). Use ``from_logits`` to build a Composition
                from unnormalised logits.

        Raises:
            ValueError: if ``logprobs.exp()`` does not sum to 1.
        """
        total = logprobs.exp().sum()
        # Compare against a 1 of the *same* dtype: torch.isclose rejects mixed
        # dtypes, so comparing with a float32 torch.tensor(1.0) made every
        # float64 input fail. An explicit raise (not assert) also keeps the
        # check active under ``python -O``.
        if not torch.isclose(total, torch.ones_like(total), rtol=1e-04):
            raise ValueError(
                f"log-probabilities must exponentiate to a distribution "
                f"summing to 1, got a sum of {total.item()}"
            )
        self.logprobs = logprobs

    def __len__(self):
        """Number of parts in the composition."""
        return len(self.logprobs)

    def __rmul__(self, scalar: float) -> "Composition":
        """Scale by ``scalar``: multiply the log-probabilities, then renormalise."""
        if isinstance(scalar, Composition):
            # Composition * Composition is undefined; let Python raise TypeError.
            return NotImplemented
        return self.from_logits(scalar * self.logprobs)

    # Scalar multiplication commutes, so ``comp * a`` equals ``a * comp``.
    __mul__ = __rmul__

    def __truediv__(self, scalar: float) -> "Composition":
        """Scale by ``1 / scalar``: divide the log-probabilities, then renormalise."""
        return self.from_logits(self.logprobs / scalar)

    def __neg__(self) -> "Composition":
        """Additive inverse: negate the log-probabilities, then renormalise."""
        return self.from_logits(-self.logprobs)

    def __add__(self, other: "Composition") -> "Composition":
        """Add log-probabilities element-wise, then renormalise."""
        return self.from_logits(self.logprobs + other.logprobs)

    def __sub__(self, other: "Composition") -> "Composition":
        """Subtract log-probabilities element-wise, then renormalise."""
        return self.from_logits(self.logprobs - other.logprobs)

    def __matmul__(self, other: "Composition") -> torch.Tensor:
        """Inner product: dot product of the two clr vectors (a tensor)."""
        return self.clr() @ other.clr()

    def __repr__(self) -> str:
        return f"Composition({self.logprobs.exp()})"

    def norm(self) -> torch.Tensor:
        """Length induced by ``@``: ``sqrt(self @ self)`` (a tensor)."""
        return torch.sqrt(self @ self)

    @classmethod
    def zero(cls, dim: int) -> "Composition":
        """The neutral element: the uniform composition with ``dim`` parts."""
        return cls.from_logits(torch.ones(dim))

    def unit(self, *, atol: float = 0.0) -> "Composition":
        """Rescale to norm 1.

        Args:
            atol: norms ``<= atol`` are treated as the zero vector. The
                default of 0.0 rejects only an exactly-zero norm.

        Raises:
            ValueError: if the norm is ``<= atol``.
        """
        norm = self.norm()  # computed once and reused for the division
        if norm <= atol:
            raise ValueError("The 0 vector has no direction.")
        return self / norm

    def clr(self) -> torch.Tensor:
        """Centred log-ratio: log-probabilities minus their mean."""
        return self.logprobs - self.logprobs.mean()

    @classmethod
    def from_logits(cls, logits) -> "Composition":
        """Build a Composition from unnormalised logits (normalised along dim 0).

        Args:
            logits: a tensor or anything ``torch.as_tensor`` accepts. Integer
                input is converted to torch's default floating dtype.
        """
        logits = torch.as_tensor(logits)
        if not logits.is_floating_point():
            logits = logits.to(torch.get_default_dtype())
        return cls(logits - torch.logsumexp(logits, 0))

    def angle(self, other: "Composition") -> torch.Tensor:
        """Angle in radians between the two unit vectors.

        The cosine is clipped to [-1, 1] so that rounding cannot push it
        outside arccos's domain.
        """
        return torch.arccos(torch.clip(self.unit() @ other.unit(), -1.0, 1.0))

    def probs(self) -> torch.Tensor:
        """The probabilities, i.e. ``exp(logprobs)``."""
        return self.logprobs.exp()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment