This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# https://github.com/pytorch/pytorch/issues/19037
# https://discuss.pytorch.org/t/covariance-and-gradient-support/16217
def cov(tensor, rowvar=True, bias=False):
    """Estimate a covariance matrix (np.cov).

    The last two dimensions of ``tensor`` hold the data; any leading
    dimensions are carried through unchanged because every operation
    below acts on dims -1 and -2 only.

    Args:
        tensor: Floating-point or complex tensor with at least two dims.
        rowvar: If true, each row (dim -2) is a variable and each column
            (dim -1) an observation; otherwise the two are swapped.
        bias: If false, normalize by ``N - 1`` (unbiased estimate);
            if true, normalize by ``N``, the number of observations.

    Returns:
        Tensor of shape ``(..., V, V)`` where ``V`` is the number of
        variables. When the normalizer is not positive (e.g. a single
        observation with ``bias=False``) the estimate is undefined and
        every entry is NaN, matching ``np.cov``, instead of raising
        ``ZeroDivisionError``.
    """
    tensor = tensor if rowvar else tensor.transpose(-1, -2)
    # Center each variable over its observations.
    tensor = tensor - tensor.mean(dim=-1, keepdim=True)
    dof = tensor.shape[-1] - int(not bool(bias))
    if dof > 0:
        factor = 1 / dof
        return factor * tensor @ tensor.transpose(-1, -2).conj()
    # Degenerate sample: keep shape/dtype/device but mark all entries NaN.
    return (tensor @ tensor.transpose(-1, -2).conj()) * float("nan")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Seamless running stats for (native python, numpy.ndarray, torch.tensor).""" | |
from collections import namedtuple | |
class MeanMeter: | |
"""Estimate the mean for a stream of values.""" | |
def __init__(self): | |
"""Initialize the meter.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Frechet's distance between two multi-variate Gaussians""" | |
import torch | |
import torch.nn as nn | |
class FrechetDistance: | |
"""Frechet's distance between two multi-variate Gaussians | |
https://www.sciencedirect.com/science/article/pii/0047259X8290077X | |
""" | |
def __init__(self, double=True, num_iterations=20, eps=1e-12): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** | |
* Lambert W-function when k = 0 | |
* {@link https://gist.github.com/xmodar/baa392fc2bec447d10c2c20bbdcaf687} | |
* {@link https://link.springer.com/content/pdf/10.1007/s10444-017-9530-3.pdf} | |
*/ | |
export function lambertW(x: number, log = false): number { | |
if (log) return lambertWLog(x); // x is actually log(x) | |
if (x >= 0) return lambertWLog(Math.log(x)); // handles [0, Infinity] | |
const xE = x * Math.E; | |
if (isNaN(x) || xE < -1) return NaN; // handles NaN and [-Infinity, -1 / Math.E) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** {@link https://gist.github.com/xmodar/d3a17bf51b8399534c5f8d27104a2a38} */ | |
export const operator = { | |
lt: <T>(a: T, b: T) => a < b, | |
le: <T>(a: T, b: T) => a <= b, | |
eq: <T>(a: T, b: T) => a === b, | |
ne: <T>(a: T, b: T) => a !== b, | |
ge: <T>(a: T, b: T) => a >= b, | |
gt: <T>(a: T, b: T) => a > b, | |
not: <T>(a: T) => !a, | |
abs: (a: number) => Math.abs(a), |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Resnet + SVM""" | |
import torch | |
from torch import nn | |
import torchvision.transforms as T | |
from torchvision import models | |
class SVM(nn.Module): | |
"""Multi-Class SVM with Gaussian Kernel (Radial Basis Function) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
class RNG(): | |
"""Preserve the state of the random number generators of torch | |
https://gist.github.com/ModarTensai/2328b13bdb11c6309ba449195a6b551a | |
Inspired by torch.random.fork_rng(). | |
Seeding random number generators (RNGs): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""InvTorch: Core Invertible Utilities https://github.com/xmodar/invtorch""" | |
import itertools | |
import collections | |
import torch | |
from torch import nn | |
import torch.utils.checkpoint | |
__all__ = ['invertible_checkpoint', 'InvertibleModule'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Invertible BatchNorm""" | |
import torch | |
from torch import nn | |
class NonZero(nn.Module): | |
"""Parameterization to force the values to be nonzero""" | |
def __init__(self, eps=1e-5, preserve_sign=True): | |
super().__init__() | |
self.eps, self.preserve_sign = eps, preserve_sign |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Deconvolution https://api.semanticscholar.org/CorpusID:208192734""" | |
import torch | |
from torch import nn | |
class Deconv(nn.Module): | |
"""Inverse conv https://gist.github.com/ModarTensai/7921460648230eda5053fe06b7cd2f4d""" | |
def __init__(self, conv, output_padding=0): | |
dim = len(conv.padding) | |
if isinstance(output_padding, int): |
NewerOlder