This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Ring attention for PyTorch. | |
See https://github.com/nshepperd/flash_attn_jax/blob/main/src/flash_attn_jax/ring_attention.py. | |
""" | |
import flash_attn.flash_attn_interface as fai | |
import torch | |
import torch.distributed as dist | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Mixture of Softmaxes""" | |
import torch | |
from torch.nn import functional as F | |
class MixtureOfSoftmaxes(torch.autograd.Function): | |
@staticmethod | |
def forward(ctx, x, p): | |
with torch.cuda.amp.autocast(enabled=False): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Grouped linear layer using https://github.com/tgale96/grouped_gemm.""" | |
from dataclasses import dataclass | |
import warnings | |
import torch | |
from torch import nn | |
try: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Scalar Preference Optimization.""" | |
import torch | |
from torch.nn import functional as F | |
def logp_completion(logits, tokens, mask): | |
"""Compute the log probabilities of completions given their prompts. | |
Args: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""REINFORCE (DiCE) with exponential moving average baseline. Implements "DiCE: The Infinitely | |
Differentiable Monte Carlo Estimator" (https://arxiv.org/abs/1802.05098).""" | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
from typing import Optional, Union | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools | |
import random | |
class WeightedSampler: | |
"""Samples k elements from a stream of weighted items without replacement. | |
See Weighted Random Sampling (Efraimidis, Spirakis 2005). | |
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Stochastic beam search. | |
Implements "Stochastic Beams and Where to Find Them: The Gumbel-Top-k Trick for | |
Sampling Sequences Without Replacement" (https://arxiv.org/abs/1903.06059)""" | |
import math | |
import torch | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Vectorial TV loss using higher accuracy order finite difference operators.""" | |
import torch | |
FINITE_DIFFERENCE_COEFFS = { | |
1: torch.tensor([-1, 1]), | |
2: torch.tensor([-3 / 2, 2, -1 / 2]), | |
3: torch.tensor([-11 / 6, 3, -3 / 2, 1 / 3]), | |
4: torch.tensor([-25 / 12, 4, -3, 4 / 3, -1 / 4]), |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
"""Pads an image for Bluesky.""" | |
import argparse | |
import math | |
from pathlib import Path | |
from PIL import Image |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""JAX implementation of the 2D DWT and IDWT.""" | |
from einops import rearrange | |
import jax | |
import jax.numpy as jnp | |
import pywt | |
def get_filter_bank(wavelet, dtype=jnp.float32): | |
"""Get the filter bank for a given pywavelets wavelet name.""" |
Newer Older