import torch
from torch.nn import Parameter
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import math
import numpy as np
from torch.autograd import Function
try:
    from apex.normalization import FusedLayerNorm as LayerNorm
except ImportError:  # assumed fallback (the preview truncates here): apex is optional
    from torch.nn import LayerNorm

class Axial(nn.Module):
    def __init__(self, config):
        super(Axial, self).__init__()
        self.config = config
        self._d_embs = config.axial_d_embs  # e.g. (128, 384); these must sum to d_model
        self._shape = config.axial_pos_shape  # e.g. (64, 128); their product must equal seqlen
        self.weights = []
        for ax, d_emb in enumerate(self._d_embs):
            # each axis gets its own embedding, broadcast along all other axes
            ax_shape = [1] * len(self._shape)
            ax_shape[ax] = self._shape[ax]
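
# The preview cuts the constructor off here. Below is a minimal sketch of how a
# Reformer-style axial position embedding conventionally continues; the class
# name and initialization are assumptions, not the gist's actual code.
class AxialSketch(nn.Module):
    def __init__(self, d_embs=(128, 384), shape=(64, 128)):
        super().__init__()
        self.shape = shape
        self.weights = nn.ParameterList()  # ParameterList so the blocks register as parameters
        for ax, d_emb in enumerate(d_embs):
            ax_shape = [1] * len(shape)
            ax_shape[ax] = shape[ax]
            self.weights.append(nn.Parameter(torch.randn(*ax_shape, d_emb) * 0.02))

    def forward(self):
        # broadcast each per-axis block over the other axes, concatenate on the
        # feature dim, then flatten the position axes to (seqlen, d_model)
        full = [w.expand(*self.shape, w.size(-1)) for w in self.weights]
        pos = torch.cat(full, dim=-1)         # [64, 128, 512]
        return pos.reshape(-1, pos.size(-1))  # [8192, 512]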
def shift_(x):
    # x = [*, t_q, t_k]: relative-position logits. This is the "skewing" trick
    # (cf. Music Transformer): pad each row with t_q zeros, flatten, pad up to a
    # multiple of (l - 1), and reshape so that row i ends up shifted by i positions.
    zero_pad = torch.zeros(*x.size()[:-1], x.size(-2), device=x.device, dtype=x.dtype)
    x = torch.cat([x, zero_pad], -1)
    l = x.size(-1)
    x = x.view(*x.size()[:-2], -1)
    zero_pad = torch.zeros(*x.size()[:-1], -x.size(-1) % (l - 1), device=x.device, dtype=x.dtype)
    return torch.cat([x, zero_pad], -1).view(*x.size()[:-1], -1, l - 1)
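
# Quick sanity check of shift_ on a tiny tensor (not in the gist; added for illustration):
x = torch.arange(6.).view(1, 2, 3)  # rows [0, 1, 2] and [3, 4, 5]
print(shift_(x))
# -> [[[0., 1., 2., 0.],
#      [0., 3., 4., 5.],
#      [0., 0., 0., 0.]]]   shape [1, 3, 4]; row i is shifted right by i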
class PN_(torch.autograd.Function):
    # Custom autograd op that rescales x by the running statistic psi.
    # New-style autograd Functions are never instantiated and need no __init__;
    # invoke it as PN_.apply(x, states).
    @staticmethod
    def forward(ctx, x, states):  # x = [b, l, d]
        eps, psi, nu = states
        x_hat = x / (psi + eps)
        ctx.save_for_backward(x_hat, eps, psi, nu)  # all of these must be tensors
        return x_hat
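
# Hypothetical usage (the preview omits the backward pass, which the full file
# presumably defines); eps/psi/nu are built as tensors here because
# save_for_backward only accepts tensors:
x = torch.randn(2, 16, 64, requires_grad=True)  # [b, l, d]
eps = torch.tensor(1e-5)
psi = torch.ones(64)    # running scale statistic
nu = torch.zeros(64)    # second running statistic
x_hat = PN_.apply(x, (eps, psi, nu))  # forward: x / (psi + eps)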
# Straightforward application of scatter_add_ to the Hadamard product and the inner product.
# The following links may be helpful for understanding:
# https://github.com/rusty1s/pytorch_scatter
# https://pytorch.org/docs/stable/tensors.html#torch.Tensor.scatter_add_
# Generalizing to scatter_matmul or scatter_einsum would require a custom CUDA kernel.
# I hope somebody will make it in the future!
# Caveat: I found the current PyTorch implementation of scatter_add_ to be slower
# with float16, so keep the inputs in float32.
def scatter_inner_prod(v, w, index, dim1, dim2):
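
# The body is cut off by the preview. A 2-D sketch of the idea (hypothetical;
# the gist's version is generalized over dim1/dim2): row-wise inner products
# scatter-added into buckets given by `index`.
def scatter_inner_prod_2d(v, w, index):
    # v, w: [n, d] float32 (see the float16 caveat above); index: [n] int64
    p = (v * w).sum(-1)                      # [n] row-wise inner products
    out = v.new_zeros(int(index.max()) + 1)  # one accumulator per bucket
    out.scatter_add_(0, index, p)            # sum products within each bucket
    return out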
from functools import partial
import torch

def _const(example, val):
    # scalar tensor matching the dtype of `example`
    return torch.tensor(val, dtype=example.dtype)

def pad(x, axis, side):
    shape = list(x.size())
    if axis == -1:
        axis = len(shape) - 1
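
# The preview truncates pad after the axis normalization. One plausible
# completion (an assumption, not the gist's code): attach a single zero slice
# on the requested side of `axis`.
def pad_sketch(x, axis, side):
    shape = list(x.size())
    if axis == -1:
        axis = len(shape) - 1
    shape[axis] = 1
    zeros = x.new_zeros(shape)
    parts = [zeros, x] if side == 'left' else [x, zeros]
    return torch.cat(parts, dim=axis)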
import torch
from torch.nn import Parameter
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import math
import numpy as np
from torch.autograd import Function
try:
    from torch_scatter import scatter
except ImportError:  # assumed fallback (the preview truncates here): torch_scatter is optional
    scatter = None

import torch
from torch.nn import Parameter
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import math
import numpy as np
from torch.autograd import Function
#from torch_scatter import scatter

import types
from typing import Any, Optional
from .moving_average import ExponentialMovingAverage
from flax import linen as nn
import jax
import jax.numpy as jnp

# ported to Flax from Haiku's corresponding implementation
class VectorQuantizerEMA(nn.Module):
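
# The class body is cut off by the preview. As a reference, a minimal sketch of
# the core lookup such a module conventionally performs (following the Haiku
# VQ-VAE EMA design; the names below are assumptions, not the gist's code):
def vq_lookup_sketch(inputs, embeddings):
    # inputs: [..., embedding_dim]; embeddings: [embedding_dim, num_embeddings]
    flat = inputs.reshape(-1, embeddings.shape[0])
    distances = (jnp.sum(flat ** 2, axis=1, keepdims=True)
                 - 2 * flat @ embeddings
                 + jnp.sum(embeddings ** 2, axis=0, keepdims=True))
    indices = jnp.argmin(distances, axis=1)                  # nearest code ids
    quantized = embeddings.T[indices].reshape(inputs.shape)
    # straight-through estimator: gradients flow to inputs, not the codebook
    quantized = inputs + jax.lax.stop_gradient(quantized - inputs)
    return quantized, indices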
import argparse
import os
import json
import multiprocessing
from glob import glob
from os.path import join
from tqdm.contrib.concurrent import process_map

parser = argparse.ArgumentParser()
parser.add_argument('--data_path', default='')
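
# Hypothetical usage of the imported helpers (the preview ends at the first
# argument); `process_file` is an illustrative worker, not from the gist.
def process_file(path):
    with open(path) as f:
        return len(json.load(f))

if __name__ == '__main__':
    args = parser.parse_args()
    files = sorted(glob(join(args.data_path, '*.json')))
    counts = process_map(process_file, files,
                         max_workers=multiprocessing.cpu_count(), chunksize=16)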