Instantly share code, notes, and snippets.

Last active December 25, 2023 07:37
torch.cov and torch.corrcoef equivalents of np.cov and np.corrcoef, respectively, with gradient support.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
 # https://github.com/pytorch/pytorch/issues/19037
# https://discuss.pytorch.org/t/covariance-and-gradient-support/16217
def cov(tensor, rowvar=True, bias=False):
    """Estimate a covariance matrix (differentiable analogue of np.cov).

    Args:
        tensor: Input of shape ``(..., variables, observations)`` when
            ``rowvar`` is true, or ``(..., observations, variables)``
            otherwise. Leading dimensions are treated as batch dimensions.
            A 1-D tensor is treated as a single variable (``rowvar`` is
            then ignored, as in np.cov). Integer and bool tensors are
            promoted to the default floating-point dtype.
        rowvar: If true, each row is a variable and each column an
            observation; if false, the last two dimensions are swapped.
        bias: If false (default) normalize by ``N - 1``; if true, by ``N``,
            where ``N`` is the number of observations.

    Returns:
        Tensor of shape ``(..., variables, variables)``, or a 0-d tensor
        holding the variance when the input is 1-D. When the normalizer is
        zero (e.g. one observation with ``bias=False``) the result is filled
        with NaN instead of raising ``ZeroDivisionError``.
    """
    import torch  # local import keeps this block self-contained

    # `mean` rejects integer/bool dtypes, so promote them first.
    if not (tensor.is_floating_point() or tensor.is_complex()):
        tensor = tensor.to(torch.get_default_dtype())

    is_vector = tensor.dim() == 1
    if is_vector:
        # A lone vector is one variable with N observations.
        tensor = tensor.unsqueeze(0)
    elif not rowvar:
        tensor = tensor.transpose(-1, -2)

    centered = tensor - tensor.mean(dim=-1, keepdim=True)
    adjoint = centered.transpose(-1, -2).conj()

    dof = tensor.shape[-1] - int(not bool(bias))
    if dof == 0:
        # Undefined estimate: keep shape/dtype/device but fill with NaN.
        result = (centered @ adjoint) * float("nan")
    else:
        # Scale before the product, exactly as the original expression did.
        result = ((1 / dof) * centered) @ adjoint

    return result.reshape(()) if is_vector else result
Last active February 10, 2023 17:49
Seamless running stats for (native python, numpy.ndarray, torch.tensor). Also see: https://gist.github.com/davidbau/00a9b6763a260be8274f6ba22df9a145
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
 """Seamless running stats for (native python, numpy.ndarray, torch.tensor).""" from collections import namedtuple class MeanMeter: """Estimate the mean for a stream of values.""" def __init__(self): """Initialize the meter."""
Last active August 8, 2022 18:54
Fréchet distance computed entirely in PyTorch, with support for streaming data batches.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
 """Frechet's distance between two multi-variate Gaussians""" import torch import torch.nn as nn class FrechetDistance: """Frechet's distance between two multi-variate Gaussians https://www.sciencedirect.com/science/article/pii/0047259X8290077X """ def __init__(self, double=True, num_iterations=20, eps=1e-12):
Last active April 5, 2022 04:32
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
 /** * Lambert W-function when k = 0 * {@link https://gist.github.com/xmodar/baa392fc2bec447d10c2c20bbdcaf687} * {@link https://link.springer.com/content/pdf/10.1007/s10444-017-9530-3.pdf} */ export function lambertW(x: number, log = false): number { if (log) return lambertWLog(x); // x is actually log(x) if (x >= 0) return lambertWLog(Math.log(x)); // handles [0, Infinity] const xE = x * Math.E; if (isNaN(x) || xE < -1) return NaN; // handles NaN and [-Infinity, -1 / Math.E)
Last active March 21, 2022 01:53
JavaScript utilities to mimic Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
 /** {@link https://gist.github.com/xmodar/d3a17bf51b8399534c5f8d27104a2a38} */ export const operator = { lt: (a: T, b: T) => a < b, le: (a: T, b: T) => a <= b, eq: (a: T, b: T) => a === b, ne: (a: T, b: T) => a !== b, ge: (a: T, b: T) => a >= b, gt: (a: T, b: T) => a > b, not: (a: T) => !a, abs: (a: number) => Math.abs(a),
Created February 10, 2022 20:59
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
 """Resnet + SVM""" import torch from torch import nn import torchvision.transforms as T from torchvision import models class SVM(nn.Module): """Multi-Class SVM with Gaussian Kernel (Radial Basis Function)
Last active November 20, 2021 03:44
A more flexible context manager than `torch.random.fork_rng()` to preserve the state of the random number generator in PyTorch for the desired devices.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
 import torch class RNG(): """Preserve the state of the random number generators of torch https://gist.github.com/ModarTensai/2328b13bdb11c6309ba449195a6b551a Inspired by torch.random.fork_rng(). Seeding random number generators (RNGs):
Last active November 17, 2021 15:38
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
 """InvTorch: Core Invertible Utilities https://github.com/xmodar/invtorch""" import itertools import collections import torch from torch import nn import torch.utils.checkpoint __all__ = ['invertible_checkpoint', 'InvertibleModule']
Created November 3, 2021 23:24
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
 """Invertible BatchNorm""" import torch from torch import nn class NonZero(nn.Module): """Parameterization to force the values to be nonzero""" def __init__(self, eps=1e-5, preserve_sign=True): super().__init__() self.eps, self.preserve_sign = eps, preserve_sign
Last active October 25, 2021 17:47
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
 """Deconvolution https://api.semanticscholar.org/CorpusID:208192734""" import torch from torch import nn class Deconv(nn.Module): """Inverse conv https://gist.github.com/ModarTensai/7921460648230eda5053fe06b7cd2f4d""" def __init__(self, conv, output_padding=0): dim = len(conv.padding) if isinstance(output_padding, int):