This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from PIL import Image | |
from skimage.color import lab2rgb, rgb2lab | |
class RandomColorToning: | |
def __init__(self, scale_mean, scale_std, shift_mean, shift_std): | |
self.scale_mean = scale_mean | |
self.scale_std = scale_std |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch import nn | |
class Generator(nn.Module): | |
def __init__(self, input_dim, image_shape, memory): | |
super().__init__() | |
self.memory = memory | |
self.input_dim = input_dim | |
self.image_shape = image_shape |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import math | |
from argparse import ArgumentParser | |
from contextlib import contextmanager | |
from pathlib import Path | |
import torch | |
import torchvision.transforms as T | |
from torch import nn | |
from torch.optim import lr_scheduler |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class LARC: | |
"""Layer-wise Adaptive Rate Control. | |
LARC is LARS that supports clipping along with scaling: | |
https://arxiv.org/abs/1708.03888 | |
This implementation is inspired by: | |
https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py | |
See also: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
from torch import functional as F | |
class Expression: | |
def __init__(self, out=None, **units): | |
self.out = out | |
self.terms = {} | |
self.coeffs = {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from argparse import Namespace | |
import torch | |
def grid_search(objective, *bounds, density=10, eps=1e-5, max_steps=None): | |
"""Perfrom coarse-to-fine grid search for the minimum objective. | |
>>> def f(x, y): | |
>>> x = x + 0.5 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# https://github.com/pytorch/pytorch/issues/19037 | |
# https://discuss.pytorch.org/t/covariance-and-gradient-support/16217 | |
def cov(tensor, rowvar=True, bias=False):
    """Estimate a covariance matrix, in the spirit of ``np.cov``.

    Args:
        tensor: Array-like exposing the torch-style ``transpose``, ``mean``
            and ``conj`` methods, with at least two dimensions. The last
            dimension is treated as observations and the second-to-last as
            variables (after the optional swap controlled by ``rowvar``).
        rowvar: When false, the last two dimensions are swapped first, so
            that columns are variables and rows are observations.
        bias: When falsy, normalize by ``N - 1`` (``N`` being the number of
            observations); when truthy, normalize by ``N``.

    Returns:
        The product of the centered data with its conjugate transpose over
        the last two dimensions, scaled by the normalization factor.

    Raises:
        ZeroDivisionError: If there is exactly one observation and ``bias``
            is falsy, since the normalizer ``N - 1`` is then zero.
    """
    if not rowvar:
        tensor = tensor.transpose(-1, -2)
    # Center every variable around its mean over the observation axis.
    centered = tensor - tensor.mean(dim=-1, keepdim=True)
    # Unbiased estimate removes one degree of freedom from the normalizer.
    ddof = 0 if bias else 1
    factor = 1 / (centered.shape[-1] - ddof)
    # Scale first, then multiply: same evaluation order as the one-line
    # ``factor * x @ x^H`` form, because ``*`` and ``@`` associate left-to-right.
    return (factor * centered) @ centered.transpose(-1, -2).conj()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
def unravel_index(index, shape):
    """Convert a flat index into a tuple of per-dimension coordinates.

    The last entry of ``shape`` varies fastest: the flat index is peeled
    apart from the trailing dimension toward the leading one.

    Args:
        index: Flat index. Anything supporting ``%`` and ``//`` with the
            entries of ``shape`` works (e.g. a Python int).
        shape: Sequence of dimension sizes; must support ``reversed``.

    Returns:
        A tuple with one coordinate per entry of ``shape``, in the same
        order as ``shape``. An empty ``shape`` yields an empty tuple.
    """
    coords = []
    for dim in reversed(shape):
        # The remainder is the coordinate along the current dimension...
        coords.append(index % dim)
        # ...and the quotient is the flat index into the remaining ones.
        index = index // dim
    # Coordinates were collected last-dimension-first; restore shape order.
    coords.reverse()
    return tuple(coords)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from math import log | |
import torch | |
from torch import nn | |
class L0Sparse(nn.Module): | |
def __init__(self, layer, init_sparsity=0.5, heat=2 / 3, stretch=0.1): | |
assert all(0 < x < 1 for x in [init_sparsity, heat, stretch]) | |
super().__init__() | |
self.layer = layer |
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.