Skip to content

Instantly share code, notes, and snippets.

@louity
louity / poisson_fft.py
Created October 13, 2022 08:55
Solve the Poisson equation on a doubly periodic domain
import matplotlib.pyplot as plt
import torch
def laplacian_per(f, dx, dy):
    """Five-point discrete Laplacian of ``f`` with doubly periodic boundaries.

    The centered second difference along dimension -2 is divided by
    ``dx**2`` and the one along dimension -1 by ``dy**2``. Neighbours that
    fall off one edge wrap around to the opposite edge. Any leading
    dimensions are treated as batch dimensions, and the result has the
    same shape as ``f``.
    """
    # torch.roll(f, -1, d)[i] == f[i + 1] and torch.roll(f, 1, d)[i] == f[i - 1],
    # with indices taken modulo the size of dimension d (periodic wrap).
    center = 2 * f
    d2_x = (torch.roll(f, shifts=-1, dims=-2) + torch.roll(f, shifts=1, dims=-2) - center) / dx**2
    d2_y = (torch.roll(f, shifts=-1, dims=-1) + torch.roll(f, shifts=1, dims=-1) - center) / dy**2
    return d2_x + d2_y
xmin = 0.0  # NOTE(review): presumably the lower bound of the x domain; the rest of this script is not visible here
@louity
louity / DCT-DST.py
Created May 17, 2022 10:25
Type-II DCT and DST with PyTorch. Note that the inverse DCT-II is the DCT-III up to a normalizing constant, and that the inverse DST-II is similarly the DST-III.
import torch
import scipy.fftpack
import numpy as np
# Compact numeric printing for comparing transform outputs side by side.
np.set_printoptions(precision=4, linewidth=200)
# Signal length used throughout the comparison.
N = 8
# Random float64 test signal. It is sized by N, not a literal, so it always
# matches the length of exp_vec_1 below.
x = torch.DoubleTensor(N).normal_()
# Twiddle factors 2 * exp(-i*pi*k / (2N)) for k = 0..N-1.
# NOTE(review): presumably used to post-multiply an FFT to obtain the
# type-II DCT/DST; the code that consumes it is not visible here.
exp_vec_1 = 2 * torch.exp(-1j*torch.pi*torch.arange(N)/(2*N))
@louity
louity / finite_diff_stag.py
Created April 27, 2022 07:39
Second- and fourth-order finite differences on a staggered grid
import numpy as np
import matplotlib.pyplot as plt
x = 1.3  # point at which the derivative is approximated
f = np.exp  # test function
fp = np.exp  # exact derivative of f (d/dx exp = exp), the reference for the errors
errs2, errs4 = [], []  # errors of the order-2 and order-4 stencils; presumably filled in the loop that follows (not visible)
dxs = np.linspace(1e-4, 1e-1, 200)  # 200 grid spacings to sweep, from 1e-4 to 1e-1
for dx in dxs:
@louity
louity / poisson_dst.py
Created April 20, 2022 08:04
Solve the Poisson equation with homogeneous Dirichlet boundary conditions using the Discrete Sine Transform and PyTorch
import torch
import torch.nn.functional as F
def compute_laplace_dst(nx, ny, dx, dy, arr_kwargs):
    """Eigenvalues of the 2D centered discrete Laplacian in the DST-I basis.

    The grid has ``nx`` x ``ny`` points, boundaries included, with
    homogeneous Dirichlet conditions. The interior sine modes
    ``k = 1 .. n-2`` diagonalize the 3-point second difference along each
    axis. Their eigenvalue is ``2 * (cos(pi * k / (n - 1)) - 1) / h**2``,
    and the 2D eigenvalue is the sum of the two 1D ones.

    Args:
        nx, ny: number of grid points along each axis, boundaries included.
        dx, dy: grid spacings along each axis.
        arr_kwargs: dict of keyword arguments forwarded to ``torch.arange``
            (e.g. dtype and device).

    Returns:
        Tensor of shape ``(nx - 2, ny - 2)``, where entry ``[i, j]``
        corresponds to modes ``kx = i + 1`` and ``ky = j + 1``.
    """
    kx = torch.arange(1, nx - 1, **arr_kwargs)
    ky = torch.arange(1, ny - 1, **arr_kwargs)
    eig_x = 2 * (torch.cos(torch.pi / (nx - 1) * kx) - 1) / dx**2
    eig_y = 2 * (torch.cos(torch.pi / (ny - 1) * ky) - 1) / dy**2
    # Broadcasting (nx-2, 1) + (1, ny-2) reproduces the 'ij' meshgrid layout.
    return eig_x[:, None] + eig_y[None, :]
@louity
louity / gaussian_filter_pytorch.py
Last active April 19, 2022 13:34
Gaussian filtering in pytorch
import torch
import matplotlib.pyplot as plt
# Random 1x1x32x32 input image with values drawn uniformly from [-1, 1).
inp = torch.FloatTensor(1,1,32,32).uniform_(-1,1)
plt.imshow(inp[0,0])
plt.show()
# Gaussian kernel (7x7)
# NOTE(review): allocated uninitialized; presumably filled further down (not visible here).
gauss_ker_7 = torch.FloatTensor(1,1,7,7)
# 7x7 grid of coordinates in [-3, 3]; 'xy' indexing puts x along the columns.
x,y = torch.meshgrid(torch.linspace(-3,3,7), torch.linspace(-3,3,7), indexing='xy')
@louity
louity / dst_typeI_2D.py
Last active November 25, 2021 05:38
PyTorch implementation of the two-dimensional type-I Discrete Sine Transform
"""DST I using FFT routines, Louis Thiry
Method 1 is 'naive' and uses FFTs on an input signal twice as large.
Method 2 is more sophisticated and uses an iRFFT on half the input signal size.
The naive method 1 nevertheless seems to be more efficient, and JIT compilation makes little difference.
"""
import numpy as np
import scipy.fftpack
import torch
@louity
louity / knn.py
Created March 5, 2021 06:26
K-nearest-neighbors on CIFAR-10
# 58.4 % accuracy with K-nearest-neighbor classifier on CIFAR.
# Images are whitened and normalized
import pickle
import numpy as np
import os
from sklearn.neighbors import KNeighborsClassifier
def compute_whitening_op(X, reg=0.1):
X = X.astype('float64')
mean = X.mean(axis=0, keepdims=True)
@louity
louity / AFS.py
Last active April 19, 2022 15:35
Angular Fourier Series in python
"""Python implementation of the Angular Fourier Series descriptors defined in the paper
'On representing chemical environments', DOI: https://doi.org/10.1103/PhysRevB.87.184115
"""
import argparse
import os
import numpy as np
import scipy
import scipy.spatial as spatial
from mpl_toolkits.mplot3d import axes3d # noqa: f401 unused import
import scipy.spatial as spatial
import numpy as np
# Random 3D point cloud: 1024 points drawn uniformly from the unit cube [0, 1)^3.
# np.random.rand takes each dimension as a separate integer argument;
# passing a tuple, as in rand((1024, 3)), raises a TypeError.
configuration = np.random.rand(1024, 3)
# k-d tree over the points, used for fast radius (ball) queries.
point_tree = spatial.cKDTree(configuration)
r_cut = 0.5  # cutoff radius defining the neighbourhood
i_atom = 0  # index of the central atom
# Indices of all points within distance r_cut of atom i_atom. The atom
# itself is included, since its distance to itself is 0.
neighbors_indices = point_tree.query_ball_point(configuration[i_atom], r_cut)
@louity
louity / sinkhorn_logsumexp.py
Last active July 5, 2024 05:16
Minimal logsumexp sinkhorn
def sinkhorn_logsumexp(cost_matrix, reg=1e-1, maxiter=30, momentum=0.):
"""Log domain version on sinkhorn distance algorithm ( https://arxiv.org/abs/1306.0895 ).
Inspired by https://github.com/gpeyre/SinkhornAutoDiff/blob/master/sinkhorn_pointcloud.py ."""
m, n = cost_matrix.size()
mu = torch.FloatTensor(m).fill_(1./m)
nu = torch.FloatTensor(n).fill_(1./n)
if torch.cuda.is_available():
mu, nu = mu.cuda(), nu.cuda()