Phil Wang lucidrains

ahennequ /
Created September 15, 2022 18:04
Use this program to find out about tensor core's accumulator warp register layout
#include <stdio.h>
// Check tensor core's warp register layout
// nvcc -arch=sm_75 -o mapping
// ./mapping
// Define some error checking macros.
#define cudaErrCheck(stat) { cudaErrCheck_((stat), __FILE__, __LINE__); }
void cudaErrCheck_(cudaError_t stat, const char *file, int line) {
    if (stat != cudaSuccess) {
rom1504 /
Last active February 22, 2024 20:25
distributed dalle2 laion

Issue title: (working implementation) Fused multi-head attention for arbitrary sequence lengths.

TL;DR you can run multi-head attention (fwd+bwd) faster and with no extra memory – with any sequence length and head dim. We’d love to make it available via apex. We need your advice on how best to do that.

Why should I care? Here's how it compares against the standard multihead attention (blue) for one multi-head attention layer of GPT-J on an RTX 3080Ti.

Plots (images not shown): time with backward (ms), and peak VRAM allocated (MB).
rmrao /
Last active October 26, 2021 14:12
Initial implementation of vectorized Smith-Waterman in torch
from typing import Tuple
import torch
import torch.nn.functional as F
import itertools
from timeit import default_timer as timer
class SoftmaxWeightedMean(torch.autograd.Function):
import torch
import torch.utils.dlpack
import jax
import jax.dlpack
# A generic mechanism for turning a JAX function into a PyTorch function.
def j2t(x_jax):
    x_torch = torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x_jax))
    return x_torch
crowsonkb /
Created June 4, 2021 16:56
Complex momentum SGD and Adam. See
"""Complex momentum SGD and Adam. See"""
import math
import torch
from torch import optim
class ComplexSGD(optim.Optimizer):
    def __init__(self, params, lr=1e-2, momentum=0.9, angle=math.pi / 8, weight_decay=0.):
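The preview above is cut off before the update rule. As a rough illustration only, here is a minimal sketch of one common formulation of complex momentum on a single scalar parameter: the momentum buffer is complex, the momentum coefficient is `momentum * exp(1j * angle)`, and only the real part of the scaled buffer is applied to the parameter. The function name `complex_momentum_step` is hypothetical, weight decay is left out, and the gist's actual implementation may differ.

```python
import cmath
import math


def complex_momentum_step(param, grad, buf, lr=1e-2, momentum=0.9, angle=math.pi / 8):
    """One complex-momentum step on a scalar; returns (new_param, new_buf).

    Hypothetical helper for illustration, not the gist's code.
    """
    # Complex momentum coefficient with magnitude `momentum` and phase `angle`.
    beta = cmath.rect(momentum, angle)
    # The buffer is complex even though the gradient is real.
    buf = beta * buf - grad
    # Only the real part of the scaled buffer moves the (real) parameter.
    param = param + (lr * buf).real
    return param, buf


# Minimize f(x) = x^2, whose gradient is 2x, starting from x = 1.
x, buf = 1.0, 0j
for _ in range(500):
    x, buf = complex_momentum_step(x, 2.0 * x, buf, lr=0.1)
```

With `angle=0` the coefficient is real and the step reduces to ordinary heavy-ball momentum SGD, which is a quick way to sanity-check the sketch.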
AranKomat /
Last active November 11, 2021 02:52
Log-linear version of cumsum and cumprod
from functools import partial
import torch
def _const(example, val):
    return torch.tensor(val, dtype=example.dtype)
def pad(x, axis, side):
    shape = list(x.size())
    if axis == -1:
        axis = len(shape) - 1
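The preview ends before the scan itself. As a hedged sketch of what "log-linear" usually means here, the following pure-Python version computes an inclusive scan in about log2(n) passes: each pass shifts the sequence right by a doubling offset, pads the front with the identity element, and combines elementwise. It uses plain lists rather than torch tensors, and the names `log_linear_scan`, `cumsum`, and `cumprod` are illustrative assumptions, not the gist's API.

```python
from operator import add, mul


def log_linear_scan(values, op, identity):
    """Inclusive scan in O(log n) passes of O(n) elementwise work each."""
    out = list(values)
    n = len(out)
    shift = 1
    while shift < n:
        # Shift right by `shift`, padding the front with the identity element.
        shifted = [identity] * shift + out[:n - shift]
        # Combine each element with the partial result `shift` positions back.
        out = [op(a, b) for a, b in zip(shifted, out)]
        shift *= 2
    return out


def cumsum(values):
    return log_linear_scan(values, add, 0)


def cumprod(values):
    return log_linear_scan(values, mul, 1)


print(cumsum([1, 2, 3, 4, 5]))  # [1, 3, 6, 10, 15]
print(cumprod([1, 2, 3, 4]))    # [1, 2, 6, 24]
```

Each pass is a single shift-and-combine over the whole sequence, so on tensors it maps to a pad plus one elementwise operation per pass, at the cost of O(n log n) total work instead of O(n).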