Last active December 25, 2023 07:37
torch.cov and torch.corrcoef equivalents of np.cov and np.corrcoef, respectively, with gradient support.
 # https://github.com/pytorch/pytorch/issues/19037
# https://discuss.pytorch.org/t/covariance-and-gradient-support/16217
def cov(tensor, rowvar=True, bias=False, ddof=None):
    """Estimate a covariance matrix (np.cov) with gradient support.

    Args:
        tensor: A 0-D, 1-D, or (batched) 2-D tensor. With ``rowvar=True``,
            each row of the last two dimensions is a variable and each column
            an observation; ``rowvar=False`` swaps those roles. A 0-D/1-D
            tensor is treated as a single variable regardless of ``rowvar``,
            as ``np.cov`` does.
        rowvar: Whether rows (True) or columns (False) hold the variables.
        bias: If False, normalize by N - 1; if True, by N. Ignored when
            ``ddof`` is given.
        ddof: If not None, normalize by N - ddof (overrides ``bias``).

    Returns:
        A (..., V, V) covariance matrix, where V is the number of variables
        ((1, 1) for 0-D/1-D input). When N - ddof <= 0, the entries are
        inf/nan (mirroring np.cov) instead of an exception being raised.
    """
    if tensor.dim() < 2:
        # np.cov treats a 1-D input as one variable, ignoring rowvar.
        tensor = tensor.reshape(1, -1)
    elif not rowvar:
        tensor = tensor.transpose(-1, -2)
    if ddof is None:
        ddof = 0 if bias else 1
    tensor = tensor - tensor.mean(dim=-1, keepdim=True)
    # Clamp non-positive degrees of freedom to 0 (as np.cov does) so tensor
    # division yields inf/nan instead of a Python ZeroDivisionError.
    dof = max(tensor.shape[-1] - ddof, 0)
    # X @ X^H; conj() keeps complex inputs consistent with np.cov.
    return tensor @ tensor.transpose(-1, -2).conj() / dof
Last active February 10, 2023 17:49
Seamless running statistics for native Python values, numpy.ndarray, and torch.Tensor. Also see: https://gist.github.com/davidbau/00a9b6763a260be8274f6ba22df9a145
 """Seamless running stats for (native python, numpy.ndarray, torch.tensor).""" from collections import namedtuple class MeanMeter: """Estimate the mean for a stream of values.""" def __init__(self): """Initialize the meter."""
Last active August 8, 2022 18:54
Fréchet distance implemented entirely in PyTorch, with support for streaming data in batches.
 """Frechet's distance between two multi-variate Gaussians""" import torch import torch.nn as nn class FrechetDistance: """Frechet's distance between two multi-variate Gaussians https://www.sciencedirect.com/science/article/pii/0047259X8290077X """ def __init__(self, double=True, num_iterations=20, eps=1e-12):
Last active April 5, 2022 04:32
 /** * Lambert W-function when k = 0 * {@link https://gist.github.com/xmodar/baa392fc2bec447d10c2c20bbdcaf687} * {@link https://link.springer.com/content/pdf/10.1007/s10444-017-9530-3.pdf} */ export function lambertW(x: number, log = false): number { if (log) return lambertWLog(x); // x is actually log(x) if (x >= 0) return lambertWLog(Math.log(x)); // handles [0, Infinity] const xE = x * Math.E; if (isNaN(x) || xE < -1) return NaN; // handles NaN and [-Infinity, -1 / Math.E)
Last active March 21, 2022 01:53
JavaScript utilities to mimic Python
 /** {@link https://gist.github.com/xmodar/d3a17bf51b8399534c5f8d27104a2a38} */ export const operator = { lt: (a: T, b: T) => a < b, le: (a: T, b: T) => a <= b, eq: (a: T, b: T) => a === b, ne: (a: T, b: T) => a !== b, ge: (a: T, b: T) => a >= b, gt: (a: T, b: T) => a > b, not: (a: T) => !a, abs: (a: number) => Math.abs(a),
Created February 10, 2022 20:59
 """Resnet + SVM""" import torch from torch import nn import torchvision.transforms as T from torchvision import models class SVM(nn.Module): """Multi-Class SVM with Gaussian Kernel (Radial Basis Function)
Last active November 20, 2021 03:44
A more flexible context manager than `torch.random.fork_rng()` for preserving the states of PyTorch's random number generators on the desired devices.
 import torch class RNG(): """Preserve the state of the random number generators of torch https://gist.github.com/ModarTensai/2328b13bdb11c6309ba449195a6b551a Inspired by torch.random.fork_rng(). Seeding random number generators (RNGs):
Last active November 17, 2021 15:38
 """InvTorch: Core Invertible Utilities https://github.com/xmodar/invtorch""" import itertools import collections import torch from torch import nn import torch.utils.checkpoint __all__ = ['invertible_checkpoint', 'InvertibleModule']
Created November 3, 2021 23:24
 """Invertible BatchNorm""" import torch from torch import nn class NonZero(nn.Module): """Parameterization to force the values to be nonzero""" def __init__(self, eps=1e-5, preserve_sign=True): super().__init__() self.eps, self.preserve_sign = eps, preserve_sign
Last active October 25, 2021 17:47
 """Deconvolution https://api.semanticscholar.org/CorpusID:208192734""" import torch from torch import nn class Deconv(nn.Module): """Inverse conv https://gist.github.com/ModarTensai/7921460648230eda5053fe06b7cd2f4d""" def __init__(self, conv, output_padding=0): dim = len(conv.padding) if isinstance(output_padding, int):