This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import jax | |
import jax.numpy as jnp | |
from flax import linen as nn | |
from flax.core import freeze, unfreeze | |
from mingpt.utils import CfgNode as CN | |
# ----------------------------------------------------------------------------- |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import jax | |
import jax.numpy as jnp | |
from functools import partial | |
from jax import vmap | |
def scatter(input, dim, index, src, reduce=None): | |
# Works like PyTorch's scatter. See https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html | |
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import io | |
import zipfile | |
from pathlib import Path | |
import threading as th | |
from multiprocessing.pool import ThreadPool | |
import cv2 | |
import urllib | |
from contextlib import contextmanager | |
from datadings.tools.cached_property import cached_property | |
from string import hexdigits |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import json | |
from time import time | |
from time import time | |
t = time() | |
TXT_PATH = "./80m-dataset/img" | |
paths = [] | |
max_num = 90000000 | |
max_idx = max_num // 1000 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import os | |
import json | |
import multiprocessing | |
from glob import glob | |
from os.path import join | |
from tqdm.contrib.concurrent import process_map | |
parser = argparse.ArgumentParser() | |
parser.add_argument('--data_path', default='') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import types | |
from typing import Any, Optional | |
from .moving_average import ExponentialMovingAverage | |
from flax import linen as nn | |
import jax | |
import jax.numpy as jnp | |
# inspired from Haiku's corresponding code to Flax | |
class VectorQuantizerEMA(nn.Module): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.nn import Parameter | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.nn.init as init | |
import math | |
import numpy as np | |
from torch.autograd import Function | |
#from torch_scatter import scatter |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.nn import Parameter | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.nn.init as init | |
import math | |
import numpy as np | |
from torch.autograd import Function | |
try: | |
from torch_scatter import scatter |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import partial | |
import torch | |
def _const(example, val): | |
return torch.tensor(val, dtype=example.dtype) | |
def pad(x, axis, side): | |
shape = list(x.size()) | |
if axis == -1: | |
axis = len(shape) - 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Trivial application of scatter_add_ to hadamard product and inner product | |
# The following links may be helpful for understanding: | |
# https://github.com/rusty1s/pytorch_scatter | |
# https://pytorch.org/docs/stable/tensors.html#torch.Tensor.scatter_add_ | |
# Generalization to scatter_matmul or scatter_einsum requires custom cuda kernel. | |
# I hope somebody will make it in the future! | |
# Caveat: I found the current PyTorch implementation of scatter_add_ is slower with float16, so make the inputs float32. | |
def scatter_inner_prod(v, w, index, dim1, dim2): |
NewerOlder