import torch
from torch.nn import Parameter
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import math
import numpy as np
from torch.autograd import Function
try:
    from apex.normalization import FusedLayerNorm as LayerNorm
except ImportError:
    # fall back to the stock LayerNorm if apex is not installed
    from torch.nn import LayerNorm
class Axial(nn.Module):
    def __init__(self, config):
        super(Axial, self).__init__()
        self.config = config
        self._d_embs = config.axial_d_embs  # e.g. (128, 384); these two numbers should sum to d_model
        self._shape = config.axial_pos_shape  # e.g. (64, 128); the product of these two numbers should equal seqlen
        self.weights = []
        for ax, d_emb in enumerate(self._d_embs):
            ax_shape = [1] * len(self._shape)
            ax_shape[ax] = self._shape[ax]
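The config comments above describe the axial factorization: each axis gets its own embedding table, which is broadcast over the other axis and concatenated along the feature dimension. A minimal standalone sketch of that construction (the concrete shapes, names, and 0.02 init scale are illustrative assumptions, not taken from the gist):

import torch

d_embs, shape = (128, 384), (64, 128)       # sum -> d_model = 512, product -> seqlen = 8192
tables = []
for ax, d_emb in enumerate(d_embs):
    ax_shape = [1] * len(shape)
    ax_shape[ax] = shape[ax]                # (64, 1) for the first axis, (1, 128) for the second
    tables.append(0.02 * torch.randn(*ax_shape, d_emb))

# Broadcast each per-axis table over the other axis, then concatenate feature-wise.
pos = torch.cat([t.expand(*shape, t.size(-1)) for t in tables], dim=-1)
pos = pos.reshape(-1, sum(d_embs))          # (8192, 512): one embedding per position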
# Shift trick used for relative-position attention: pad, flatten, re-pad and
# reshape so that each query row's relative-position logits are realigned.
def shift_(x):
    # x = [*, t_q, t_k]
    zero_pad = torch.zeros(*x.size()[:-1], x.size(-2), device=x.device, dtype=x.dtype)
    x = torch.cat([x, zero_pad], -1)
    l = x.size(-1)
    x = x.view(*x.size()[:-2], -1)
    zero_pad = torch.zeros(*x.size()[:-1], -x.size(-1) % (l - 1), device=x.device, dtype=x.dtype)
    return torch.cat([x, zero_pad], -1).view(*x.size()[:-1], -1, l - 1)
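A quick shape check for shift_ (illustrative only; any downstream slicing of the realigned block is not shown in this preview):

x = torch.randn(1, 3, 5)   # [b, t_q, t_k]
y = shift_(x)
print(y.shape)             # torch.Size([1, 4, 7]), i.e. [b, ceil(t_q*(t_q+t_k)/(t_q+t_k-1)), t_q+t_k-1]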
class PN_(torch.autograd.Function):
    def __init__(self):
        super(PN_, self).__init__()

    @staticmethod
    def forward(ctx, x, states):  # x = [b, l, d]
        eps, psi, nu = states
        x_hat = x / (psi + eps)
        ctx.save_for_backward(x_hat, eps, psi, nu)
        return x_hat
# A straightforward application of scatter_add_ to the Hadamard product and the inner product.
# The following links may be helpful for understanding:
# https://github.com/rusty1s/pytorch_scatter
# https://pytorch.org/docs/stable/tensors.html#torch.Tensor.scatter_add_
# Generalizing to scatter_matmul or scatter_einsum requires a custom CUDA kernel.
# I hope somebody will make it in the future!
# Caveat: the current PyTorch implementation of scatter_add_ is slower with float16, so cast the inputs to float32.
def scatter_inner_prod(v, w, index, dim1, dim2):
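The preview cuts off at the signature. A minimal, self-contained sketch of the idea the comments describe, under assumed semantics (the shapes, the num_buckets argument, and the choice of reduction are illustrative guesses, not the gist's actual body):

import torch

def scatter_inner_prod_sketch(v, w, index, num_buckets):
    # v, w: [b, l, d]; index: [l] torch.long with values in [0, num_buckets)
    prod = (v.float() * w.float()).sum(-1)            # [b, l] per-position inner products (float32, see caveat above)
    out = prod.new_zeros(prod.size(0), num_buckets)   # [b, num_buckets]
    idx = index.unsqueeze(0).expand_as(prod)          # broadcast the bucket index to [b, l]
    return out.scatter_add_(1, idx, prod)             # sum the inner products per bucket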
AranKomat / log_linear_cumsum_prod.py
Last active November 11, 2021 02:52
Log-linear version of cumsum and cumprod
from functools import partial
import torch

def _const(example, val):
    return torch.tensor(val, dtype=example.dtype)

def pad(x, axis, side):
    shape = list(x.size())
    if axis == -1:
        axis = len(shape) - 1
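The preview above shows only the helpers. For orientation, a minimal standalone sketch of the log-linear idea named in the gist description: a Hillis-Steele style inclusive scan computes cumsum (or cumprod, using multiplication) in O(log n) sequential steps and O(n log n) total work. This is illustrative, not the gist's implementation:

import torch

def log_linear_cumsum(x, dim=-1):
    dim = dim % x.dim()
    n = x.size(dim)
    shift = 1
    while shift < n:
        # add to each element the partial sum `shift` positions behind it
        pad_shape = list(x.shape)
        pad_shape[dim] = shift
        shifted = torch.cat([x.new_zeros(pad_shape), x.narrow(dim, 0, n - shift)], dim=dim)
        x = x + shifted
        shift *= 2
    return x

t = torch.arange(1., 9.)
assert torch.allclose(log_linear_cumsum(t), torch.cumsum(t, 0))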
import torch
from torch.nn import Parameter
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import math
import numpy as np
from torch.autograd import Function
try:
    from torch_scatter import scatter
except ImportError:
    scatter = None  # torch_scatter is optional; only needed by the scatter-based code paths
AranKomat / arch.py
Created August 13, 2020 09:41
Incomplete implementation of an extended MARGE architecture
import torch
from torch.nn import Parameter
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import math
import numpy as np
from torch.autograd import Function
#from torch_scatter import scatter
import types
from typing import Any, Optional
from .moving_average import ExponentialMovingAverage
from flax import linen as nn
import jax
import jax.numpy as jnp
# adapted to Flax from Haiku's corresponding code
class VectorQuantizerEMA(nn.Module):
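The class body is not shown in this preview. For orientation, a minimal standalone sketch of the EMA codebook update used by VQ-VAE style quantizers like the one declared above; the function name, argument layout, and smoothing constant are illustrative assumptions, not the gist's code:

import jax
import jax.numpy as jnp

def ema_codebook_update(x, codebook, ema_count, ema_sum, decay=0.99, eps=1e-5):
    # x: [n, d] encoder outputs; codebook, ema_sum: [k, d]; ema_count: [k]
    d2 = ((x[:, None, :] - codebook[None, :, :]) ** 2).sum(-1)           # [n, k] squared distances
    onehot = jax.nn.one_hot(jnp.argmin(d2, axis=-1), codebook.shape[0])  # [n, k] nearest-code assignments
    ema_count = decay * ema_count + (1 - decay) * onehot.sum(0)
    ema_sum = decay * ema_sum + (1 - decay) * onehot.T @ x
    # Laplace smoothing so that unused codes do not divide by zero
    n = ema_count.sum()
    count = (ema_count + eps) / (n + codebook.shape[0] * eps) * n
    return ema_sum / count[:, None], ema_count, ema_sum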
import argparse
import os
import json
import multiprocessing
from glob import glob
from os.path import join
from tqdm.contrib.concurrent import process_map
parser = argparse.ArgumentParser()
parser.add_argument('--data_path', default='')