Skip to content

Instantly share code, notes, and snippets.

@crowsonkb
crowsonkb / ring_attn.py
Created October 10, 2024 16:19
Ring attention for PyTorch.
"""Ring attention for PyTorch.
See https://github.com/nshepperd/flash_attn_jax/blob/main/src/flash_attn_jax/ring_attention.py.
"""
import flash_attn.flash_attn_interface as fai
import torch
import torch.distributed as dist
@crowsonkb
crowsonkb / mos.py
Last active April 11, 2024 21:23
Mixture of Softmaxes
"""Mixture of Softmaxes"""
import torch
from torch.nn import functional as F
class MixtureOfSoftmaxes(torch.autograd.Function):
@staticmethod
def forward(ctx, x, p):
with torch.cuda.amp.autocast(enabled=False):
"""Grouped linear layer using https://github.com/tgale96/grouped_gemm."""
from dataclasses import dataclass
import warnings
import torch
from torch import nn
try:
@crowsonkb
crowsonkb / spo_loss.py
Last active June 10, 2024 15:38
Scalar Preference Optimization
"""Scalar Preference Optimization."""
import torch
from torch.nn import functional as F
def logp_completion(logits, tokens, mask):
"""Compute the log probabilities of completions given their prompts.
Args:
@crowsonkb
crowsonkb / reinforce.py
Last active June 30, 2023 19:12
REINFORCE with exponential moving average baseline
"""REINFORCE (DiCE) with exponential moving average baseline. Implements "DiCE: The Infinitely
Differentiable Monte Carlo Estimator" (https://arxiv.org/abs/1802.05098)."""
import torch
from torch import nn
from torch.nn import functional as F
from typing import Optional, Union
import itertools
import random
class WeightedSampler:
"""Samples k elements from a stream of weighted items without replacement.
See Weighted Random Sampling (Efraimidis, Spirakis 2005).
"""
"""Stochastic beam search.
Implements "Stochastic Beams and Where to Find Them: The Gumbel-Top-k Trick for
Sampling Sequences Without Replacement" (https://arxiv.org/abs/1903.06059)."""
import math
import torch
"""Vectorial TV loss using higher accuracy order finite difference operators."""
import torch
FINITE_DIFFERENCE_COEFFS = {
1: torch.tensor([-1, 1]),
2: torch.tensor([-3 / 2, 2, -1 / 2]),
3: torch.tensor([-11 / 6, 3, -3 / 2, 1 / 3]),
4: torch.tensor([-25 / 12, 4, -3, 4 / 3, -1 / 4]),
#!/usr/bin/env python3
"""Pads an image for Bluesky."""
import argparse
import math
from pathlib import Path
from PIL import Image
@crowsonkb
crowsonkb / jax_wavelet.py
Last active February 10, 2023 17:56
JAX implementation of the 2D DWT and IDWT.
"""JAX implementation of the 2D DWT and IDWT."""
from einops import rearrange
import jax
import jax.numpy as jnp
import pywt
def get_filter_bank(wavelet, dtype=jnp.float32):
"""Get the filter bank for a given pywavelets wavelet name."""