Skip to content

Instantly share code, notes, and snippets.

View spaceybread's full-sized avatar
👾
...

spaceybread spaceybread

👾
...
View GitHub Profile
class GPT(nn.Module):
# omitted forward and init code
@torch.no_grad()
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
for _ in range(max_new_tokens):
idx_cond = idx[:, -self.block_size:]
logits, _ = self(idx_cond)
logits = logits[:, -1, :] / temperature
class Head(nn.Module):
# omitted init code
def forward(self, x):
B, T, C = x.shape
k = self.key(x)
q = self.query(x)
wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5
wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
# Re-seed "dead" codebook entries: any code whose cluster_size is below 1.0
# is replaced by a randomly chosen input vector plus small Gaussian noise.
# NOTE(review): cluster_size / embed_avg look like EMA statistics for the
# codebook and z_flat like (N, embedding_dim) encoder outputs — confirm
# against the enclosing method.
dead = self.cluster_size < 1.0
if dead.any():
    n_dead = int(dead.sum().item())
    random_idx = torch.randint(0, z_flat.size(0), (n_dead,), device=z_flat.device)
    noise = torch.randn(n_dead, self.embedding_dim, device=z_flat.device) * 0.01
    replacement = z_flat[random_idx] + noise
    # Boolean-mask indexing (`t[dead]`) returns a *copy*, so calling
    # `.copy_()` / `.fill_()` on it only mutates a temporary and leaves the
    # real tensor untouched. Masked assignment writes through in place.
    # no_grad keeps the in-place writes legal if any of these tensors is a
    # leaf that requires grad, and keeps the reset out of the autograd graph.
    with torch.no_grad():
        # Masked assignment requires matching dtypes (unlike copy_), so cast.
        self.codebook[dead] = replacement.to(self.codebook.dtype)
        self.cluster_size[dead] = 1.0
        # With cluster_size reset to 1, the running sum equals the new code.
        self.embed_avg[dead] = self.codebook[dead].to(self.embed_avg.dtype)
class VectorQuantizer(nn.Module):
# omitted init code
def forward(self, z):
B, C, D, H, W = z.shape
z_flat = z.permute(0,2,3,4,1).reshape(-1, C).detach()
z_sq = (z_flat ** 2).sum(1, keepdim=True)
cb_sq = (self.codebook ** 2).sum(1, keepdim=True).T
import matplotlib.pyplot as plt
from collections import Counter
import random
class Coin:
def __init__(self, isFair = True):
self.isFair = isFair
def flip(self):
if self.isFair:
@spaceybread
spaceybread / binPrime.py
Created March 19, 2024 05:16
A019565, A087006, A261144
def base(n, b):
if n == 0:
return 0
result = ""
while n > 0:
r = n % b
result = str(r) + result
n //= b