View running_meters.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Seamless running stats for (native python, numpy.ndarray, torch.tensor).""" | |
from collections import namedtuple | |
class MeanMeter: | |
"""Estimate the mean for a stream of values.""" | |
def __init__(self): | |
"""Initialize the meter.""" |
View covariance.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# https://github.com/pytorch/pytorch/issues/19037 | |
# https://discuss.pytorch.org/t/covariance-and-gradient-support/16217 | |
def cov(tensor, rowvar=True, bias=False):
    """Estimate a covariance matrix, mirroring ``np.cov``.

    Args:
        tensor: Samples with at least two dimensions; the last two hold the
            variables/observations and any leading dimensions are carried
            through the matrix product unchanged. Presumably a floating or
            complex ``torch.Tensor`` (``mean(dim=..., keepdim=...)`` is
            called on it) -- confirm against callers.
        rowvar: If true, each row is a variable and each column is an
            observation. If false, the last two dimensions are swapped
            first so that columns are treated as variables.
        bias: If true, normalize by the number of observations ``N``;
            otherwise normalize by ``N - 1``.

    Returns:
        The centered samples multiplied by their conjugate transpose and
        divided by the normalization count.
    """
    if not rowvar:
        tensor = tensor.transpose(-1, -2)
    # Remove the per-variable mean taken over the observation axis.
    centered = tensor - tensor.mean(dim=-1, keepdim=True)
    dof = centered.shape[-1] - (0 if bias else 1)
    # Divide on the tensor side rather than evaluating ``1 / dof`` in Python:
    # with a single observation and ``bias=False`` the count is zero, and
    # this yields NaN (as np.cov does) instead of raising ZeroDivisionError.
    return (centered @ centered.transpose(-1, -2).conj()) / dof
View frechet.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Frechet's distance between two multi-variate Gaussians""" | |
import torch | |
import torch.nn as nn | |
class FrechetDistance: | |
"""Frechet's distance between two multi-variate Gaussians | |
https://www.sciencedirect.com/science/article/pii/0047259X8290077X | |
""" | |
def __init__(self, double=True, num_iterations=20, eps=1e-12): |
View lambertw.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** | |
* Lambert W-function when k = 0 | |
* {@link https://gist.github.com/xmodar/baa392fc2bec447d10c2c20bbdcaf687} | |
* {@link https://link.springer.com/content/pdf/10.1007/s10444-017-9530-3.pdf} | |
*/ | |
export function lambertW(x: number, log = false): number { | |
if (log) return lambertWLog(x); // x is actually log(x) | |
if (x >= 0) return lambertWLog(Math.log(x)); // handles [0, Infinity] | |
const xE = x * Math.E; | |
if (isNaN(x) || xE < -1) return NaN; // handles NaN and [-Infinity, -1 / Math.E) |
View python.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** {@link https://gist.github.com/xmodar/d3a17bf51b8399534c5f8d27104a2a38} */ | |
export const operator = { | |
lt: <T>(a: T, b: T) => a < b, | |
le: <T>(a: T, b: T) => a <= b, | |
eq: <T>(a: T, b: T) => a === b, | |
ne: <T>(a: T, b: T) => a !== b, | |
ge: <T>(a: T, b: T) => a >= b, | |
gt: <T>(a: T, b: T) => a > b, | |
not: <T>(a: T) => !a, | |
abs: (a: number) => Math.abs(a), |
View resnet_svm.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Resnet + SVM""" | |
import torch | |
from torch import nn | |
import torchvision.transforms as T | |
from torchvision import models | |
class SVM(nn.Module): | |
"""Multi-Class SVM with Gaussian Kernel (Radial Basis Function) |
View rng.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
class RNG(): | |
"""Preserve the state of the random number generators of torch | |
https://gist.github.com/ModarTensai/2328b13bdb11c6309ba449195a6b551a | |
Inspired by torch.random.fork_rng(). | |
Seeding random number generators (RNGs): |
View invtorch.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""InvTorch: Core Invertible Utilities https://github.com/xmodar/invtorch""" | |
import itertools | |
import collections | |
import torch | |
from torch import nn | |
import torch.utils.checkpoint | |
__all__ = ['invertible_checkpoint', 'InvertibleModule'] |
View invertible_batchnorm.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Invertible BatchNorm""" | |
import torch | |
from torch import nn | |
class NonZero(nn.Module): | |
"""Parameterization to force the values to be nonzero""" | |
def __init__(self, eps=1e-5, preserve_sign=True): | |
super().__init__() | |
self.eps, self.preserve_sign = eps, preserve_sign |
View deconv.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Deconvolution https://api.semanticscholar.org/CorpusID:208192734""" | |
import torch | |
from torch import nn | |
class Deconv(nn.Module): | |
"""Inverse conv https://gist.github.com/ModarTensai/7921460648230eda5053fe06b7cd2f4d""" | |
def __init__(self, conv, output_padding=0): | |
dim = len(conv.padding) | |
if isinstance(output_padding, int): |
NewerOlder