Skip to content

Instantly share code, notes, and snippets.

View xmodar's full-sized avatar

Modar M. Alfadly xmodar

View GitHub Profile
@xmodar
xmodar / covariance.py
Last active December 25, 2023 07:37
torch.cov and torch.corrcoef equivalents of np.cov and np.corrcoef, respectively, with gradient support.
# https://github.com/pytorch/pytorch/issues/19037
# https://discuss.pytorch.org/t/covariance-and-gradient-support/16217
def cov(tensor, rowvar=True, bias=False, *, ddof=None):
    """Estimate a covariance matrix, mirroring ``np.cov`` (gradients flow through).

    Args:
        tensor: Observations. A 1-D tensor is treated as a single variable.
            Otherwise the last two dimensions are (variables, observations)
            when ``rowvar`` is true, or (observations, variables) when it is
            false; any leading dimensions are treated as a batch.
        rowvar: If true, each row is a variable and each column an
            observation; otherwise the roles are swapped. Ignored for 1-D input.
        bias: If false (default), normalize by ``N - 1`` (unbiased estimate);
            if true, normalize by ``N``. Overridden by ``ddof``.
        ddof: If not ``None``, normalize by ``N - ddof`` regardless of
            ``bias`` (same precedence as ``np.cov``).

    Returns:
        The covariance matrix with shape ``(..., V, V)``, or a 0-d tensor for
        1-D input. Integer/bool input is promoted to the default float dtype.

    Warns:
        RuntimeWarning: if ``N - ddof <= 0``. The normalizer is then clamped
        to 0, so the result holds ``inf``/``nan`` instead of raising
        ``ZeroDivisionError`` (``np.cov`` behaves the same way).
    """
    import warnings

    import torch

    if not (tensor.is_floating_point() or tensor.is_complex()):
        # ``mean`` rejects integer dtypes; np.cov promotes to float as well.
        tensor = tensor.to(torch.get_default_dtype())
    is_vector = tensor.dim() == 1
    if is_vector:
        tensor = tensor.unsqueeze(0)  # one variable with N observations
    elif not rowvar:
        tensor = tensor.transpose(-1, -2)
    if ddof is None:
        ddof = 0 if bias else 1
    tensor = tensor - tensor.mean(dim=-1, keepdim=True)
    dof = tensor.shape[-1] - ddof
    if dof <= 0:
        warnings.warn(
            "Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2
        )
        dof = 0  # tensor division by 0 yields inf/nan, matching np.cov
    # X @ X^H (conjugate transpose) so complex inputs match np.cov.
    result = tensor @ tensor.transpose(-1, -2).conj() / dof
    return result[0, 0] if is_vector else result
@xmodar
xmodar / running_meters.py
Last active February 10, 2023 17:49
Seamless running statistics for native Python values, numpy.ndarray, and torch.Tensor. Also see: https://gist.github.com/davidbau/00a9b6763a260be8274f6ba22df9a145
"""Seamless running stats for (native python, numpy.ndarray, torch.tensor)."""
from collections import namedtuple
class MeanMeter:
"""Estimate the mean for a stream of values."""
def __init__(self):
"""Initialize the meter."""
@xmodar
xmodar / frechet.py
Last active August 8, 2022 18:54
Fréchet distance implemented entirely in PyTorch, with support for streaming data in batches.
"""Frechet's distance between two multi-variate Gaussians"""
import torch
import torch.nn as nn
class FrechetDistance:
"""Frechet's distance between two multi-variate Gaussians
https://www.sciencedirect.com/science/article/pii/0047259X8290077X
"""
def __init__(self, double=True, num_iterations=20, eps=1e-12):
/**
* Lambert W-function when k = 0
* {@link https://gist.github.com/xmodar/baa392fc2bec447d10c2c20bbdcaf687}
* {@link https://link.springer.com/content/pdf/10.1007/s10444-017-9530-3.pdf}
*/
export function lambertW(x: number, log = false): number {
if (log) return lambertWLog(x); // x is actually log(x)
if (x >= 0) return lambertWLog(Math.log(x)); // handles [0, Infinity]
const xE = x * Math.E;
if (isNaN(x) || xE < -1) return NaN; // handles NaN and [-Infinity, -1 / Math.E)
@xmodar
xmodar / python.ts
Last active March 21, 2022 01:53
JavaScript utilities to mimic Python
/** {@link https://gist.github.com/xmodar/d3a17bf51b8399534c5f8d27104a2a38} */
export const operator = {
lt: <T>(a: T, b: T) => a < b,
le: <T>(a: T, b: T) => a <= b,
eq: <T>(a: T, b: T) => a === b,
ne: <T>(a: T, b: T) => a !== b,
ge: <T>(a: T, b: T) => a >= b,
gt: <T>(a: T, b: T) => a > b,
not: <T>(a: T) => !a,
abs: (a: number) => Math.abs(a),
"""Resnet + SVM"""
import torch
from torch import nn
import torchvision.transforms as T
from torchvision import models
class SVM(nn.Module):
"""Multi-Class SVM with Gaussian Kernel (Radial Basis Function)
@xmodar
xmodar / rng.py
Last active November 20, 2021 03:44
A more flexible context manager than `torch.random.fork_rng()` to preserve the state of the random number generator in PyTorch for the desired devices.
import torch
class RNG():
"""Preserve the state of the random number generators of torch
https://gist.github.com/ModarTensai/2328b13bdb11c6309ba449195a6b551a
Inspired by torch.random.fork_rng().
Seeding random number generators (RNGs):
@xmodar
xmodar / invtorch.py
Last active November 17, 2021 15:38
"""InvTorch: Core Invertible Utilities https://github.com/xmodar/invtorch"""
import itertools
import collections
import torch
from torch import nn
import torch.utils.checkpoint
# Public API: the names exported by `from <module> import *`. Both names are
# presumably defined further down in the gist (not visible in this preview).
__all__ = ['invertible_checkpoint', 'InvertibleModule']
"""Invertible BatchNorm"""
import torch
from torch import nn
class NonZero(nn.Module):
"""Parameterization to force the values to be nonzero"""
def __init__(self, eps=1e-5, preserve_sign=True):
super().__init__()
self.eps, self.preserve_sign = eps, preserve_sign
@xmodar
xmodar / deconv.py
Last active October 25, 2021 17:47
"""Deconvolution https://api.semanticscholar.org/CorpusID:208192734"""
import torch
from torch import nn
class Deconv(nn.Module):
"""Inverse conv https://gist.github.com/ModarTensai/7921460648230eda5053fe06b7cd2f4d"""
def __init__(self, conv, output_padding=0):
dim = len(conv.padding)
if isinstance(output_padding, int):