import math
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.core import freeze, unfreeze
from mingpt.utils import CfgNode as CN
# -----------------------------------------------------------------------------
from functools import partial
from jax import vmap
def scatter(input, dim, index, src, reduce=None):
    # Works like PyTorch's scatter. See https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html
    dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))
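# The preview cuts off here. A minimal sketch of the reduce='add' case for a
# 1-D operand with the dimension numbers above (an assumed completion, not the
# gist's original body; the general case would vmap over the remaining axes):
def scatter_add_1d(operand, index, src):
    dnums = jax.lax.ScatterDimensionNumbers(
        update_window_dims=(), inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,))
    # scatter_indices needs a trailing index-vector dimension of size 1
    return jax.lax.scatter_add(operand, index[:, None], src, dnums)

# Example: scatter_add_1d(jnp.zeros(5), jnp.array([0, 2, 2]), jnp.array([1., 2., 3.]))
# -> [1., 0., 5., 0., 0.]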
import io
import zipfile
from pathlib import Path
import threading as th
from multiprocessing.pool import ThreadPool
import cv2
import urllib
from contextlib import contextmanager
from datadings.tools.cached_property import cached_property
from string import hexdigits
import os
import json
from time import time
t = time()  # wall-clock start, for timing the scan below
TXT_PATH = "./80m-dataset/img"
paths = []
max_num = 90000000           # upper bound on the image count (presumably ~90M images)
max_idx = max_num // 1000    # presumably one shard per 1000 images
import argparse
import os
import json
import multiprocessing
from glob import glob
from os.path import join
from tqdm.contrib.concurrent import process_map
parser = argparse.ArgumentParser()
parser.add_argument('--data_path', default='')
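# The preview ends after the argument definition. An illustrative sketch of how
# these imports are typically wired together; `process_file` and the '*.json'
# glob pattern are hypothetical, not from the original gist:
def process_file(path):
    with open(path) as f:
        return json.load(f)

if __name__ == '__main__':
    args = parser.parse_args()
    files = sorted(glob(join(args.data_path, '*.json')))
    # process_map runs the worker in a process pool with a tqdm progress bar
    results = process_map(process_file, files,
                          max_workers=multiprocessing.cpu_count(), chunksize=16)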
import types
from typing import Any, Optional
from .moving_average import ExponentialMovingAverage
from flax import linen as nn
import jax
import jax.numpy as jnp
# a Flax port inspired by Haiku's corresponding implementation
class VectorQuantizerEMA(nn.Module):
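    # The gist preview shows only the class line. A minimal sketch of the usual
    # EMA vector-quantizer forward pass (VQ-VAE style); every field name below
    # is an illustrative assumption, and the EMA codebook update itself (stored
    # alongside the codebook in a mutable collection) is elided:
    embedding_dim: int = 64
    num_embeddings: int = 512
    decay: float = 0.99
    epsilon: float = 1e-5

    @nn.compact
    def __call__(self, inputs):
        codebook = self.variable(
            'vq_stats', 'codebook',
            lambda: jax.random.normal(jax.random.PRNGKey(0),
                                      (self.num_embeddings, self.embedding_dim)))
        flat = inputs.reshape(-1, self.embedding_dim)
        # nearest codebook entry by squared Euclidean distance
        d = (jnp.sum(flat ** 2, axis=1, keepdims=True)
             - 2.0 * flat @ codebook.value.T
             + jnp.sum(codebook.value ** 2, axis=1))
        codes = jnp.argmin(d, axis=1)
        quantized = codebook.value[codes].reshape(inputs.shape)
        # straight-through estimator: gradients bypass the discrete lookup
        quantized = inputs + jax.lax.stop_gradient(quantized - inputs)
        return quantized, codes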
@AranKomat
AranKomat / arch.py
Created August 13, 2020 09:41
Incomplete implementation of an extended MARGE architecture
import torch
from torch.nn import Parameter
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import math
import numpy as np
from torch.autograd import Function
try:
    from torch_scatter import scatter
except ImportError:
    # assumed fallback: torch_scatter is an optional dependency here
    scatter = None
@AranKomat
AranKomat / log_linear_cumsum_prod.py
Last active November 11, 2021 02:52
Log-linear version of cumsum and cumprod
from functools import partial
import torch
def _const(example, val):
    return torch.tensor(val, dtype=example.dtype)

def pad(x, axis, side):
    # assumed completion (the preview cuts off): pad one zero on `side` along `axis`
    shape = list(x.size())
    if axis == -1:
        axis = len(shape) - 1
    shape[axis] = 1
    zeros = x.new_zeros(shape)
    return torch.cat([zeros, x] if side == 'left' else [x, zeros], dim=axis)
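# For reference, the log-linear trick the gist's description refers to can be
# written as a Hillis-Steele doubling scan: O(n log n) work but only O(log n)
# sequential steps. An illustrative sketch, not necessarily the gist's method:
def log_linear_cumsum(x, dim=-1):
    n = x.size(dim)
    dim = dim % x.dim()
    shift = 1
    while shift < n:
        # add a copy of x shifted right by `shift` positions along dim
        shifted = torch.zeros_like(x)
        dst = [slice(None)] * x.dim(); dst[dim] = slice(shift, n)
        src = [slice(None)] * x.dim(); src[dim] = slice(0, n - shift)
        shifted[tuple(dst)] = x[tuple(src)]
        x = x + shifted
        shift *= 2
    return x

# cumprod is the same scan with `+` replaced by `*` and zeros by ones;
# torch.allclose(log_linear_cumsum(t, -1), torch.cumsum(t, -1)) holds.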
# Trivial application of scatter_add_ to the Hadamard product and the inner product.
# The following links may be helpful for understanding:
# https://github.com/rusty1s/pytorch_scatter
# https://pytorch.org/docs/stable/tensors.html#torch.Tensor.scatter_add_
# Generalizing to scatter_matmul or scatter_einsum would require a custom CUDA kernel.
# I hope somebody will make it in the future!
# Caveat: I found that the current PyTorch implementation of scatter_add_ is slower with float16, so cast the inputs to float32.
def scatter_inner_prod(v, w, index, dim1, dim2):
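    # The listing ends at this signature, so the body below is a hypothetical
    # completion. It assumes `index` has the shape of the reduced product and
    # maps positions along dim1 to output segments, and that dim2 > dim1 so
    # dim1 keeps its meaning after the contraction:
    prod = (v * w).sum(dim=dim2)                   # inner product over dim2
    out_shape = list(prod.shape)
    out_shape[dim1] = int(index.max().item()) + 1  # number of segments
    out = prod.new_zeros(out_shape)
    # per torch.Tensor.scatter_add_, index must match prod's shape
    return out.scatter_add_(dim1, index, prod)

# Example (hypothetical): v, w of shape (6, 4), index = tensor([0, 0, 1, 1, 2, 2])
# -> scatter_inner_prod(v, w, index, dim1=0, dim2=1) has shape (3,)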