Skip to content

Instantly share code, notes, and snippets.

@fzimmermann89
Last active July 28, 2021 14:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save fzimmermann89/f3e8b61abc0f085a00baf4d86a3dcfc4 to your computer and use it in GitHub Desktop.
Save fzimmermann89/f3e8b61abc0f085a00baf4d86a3dcfc4 to your computer and use it in GitHub Desktop.
simplecorrelator
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"#!pip install cupy-cuda113\n",
"#import cupy as cp\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"running tests on CPU\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/cds/home/z/zimmf/.conda/envs/ml/lib/python3.7/site-packages/ipykernel_launcher.py:102: RuntimeWarning: divide by zero encountered in true_divide\n",
"/cds/home/z/zimmf/.conda/envs/ml/lib/python3.7/site-packages/ipykernel_launcher.py:102: RuntimeWarning: invalid value encountered in true_divide\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"running tests on GPU\n",
"all tests passed\n"
]
}
],
"source": [
"import numpy as np\n",
"import scipy as sp\n",
"import abc\n",
"\n",
"\n",
"class simplecorrelator(abc.ABC):\n",
" \"\"\"\n",
" abstract base class; subclasses provide the array backend and the forward/inverse fft functions\n",
" \"\"\"\n",
"\n",
" _arraybackend = np # should provide .asarray and .empty\n",
"\n",
" @staticmethod\n",
" def _tonumpy(x):\n",
" # override this if anything else has to be done to return a numpy array\n",
" return x\n",
"\n",
" @abc.abstractstaticmethod\n",
" def _fftfun(*args, **kwargs):\n",
" # the forward fft to use\n",
" pass\n",
"\n",
" @abc.abstractstaticmethod\n",
" def _ifftfun(*args, **kwargs):\n",
" # the ifft to use\n",
" pass\n",
"\n",
" @classmethod\n",
" def _update(cls, fft, old, f):\n",
" # can be overridden to vectorize better\n",
" return (cls._abs2(fft) - old) * f\n",
"\n",
" @staticmethod\n",
" def _abs2(array):\n",
" # can be overridden to vectorize better\n",
" return array.real ** 2 + array.imag ** 2\n",
"\n",
" def _shiftfun(self, x):\n",
" # fftshift the zero shift to the center, then crop the padded fft grid to 2*inputshape\n",
" return np.fft.fftshift(x)[tuple(slice(f // 2 - i, f // 2 + i) for i, f in zip(self._inshape, self._fftshape))]\n",
"\n",
" def __init__(self, mask=None, shape=None, dtype=np.float64, com=np.inf):\n",
" \"\"\"\n",
" simple correlator\n",
" takes a zero-padded fft of each real image passed to add() and averages |fft|^2 over them\n",
" (optionally as an exponentially weighted moving average); meanCor returns the inverse fft of\n",
" that average divided by the autocorrelation of the mask\n",
"\n",
" at least one of mask or shape is required.\n",
" mask: use only areas of images where mask is positive\n",
" shape: shape of images that will be added (must equal mask.shape if both are given)\n",
" dtype: dtype used (float64/float32)\n",
" com: center of mass of the moving average. 0 -> the mean is just the last image, inf -> plain average over all images\n",
"\n",
" the shape of meanCor will always be 2*inputshape, with the zero-shift value at index inputshape.\n",
" \"\"\"\n",
" if mask is None:\n",
" if shape is None:\n",
" raise ValueError(\"mask or shape must be given\")\n",
" else:\n",
" shape = tuple(shape)\n",
" mask = np.ones(shape)\n",
" elif shape is None:\n",
" shape = mask.shape\n",
" elif mask.shape != shape:\n",
" raise ValueError(\"shape is different from mask.shape\")\n",
" self._inshape = shape\n",
" self._mask = self._arraybackend.asarray(mask > 0, dtype=dtype)\n",
" self._fftshape = tuple([sp.fft.next_fast_len(2 * s) for s in shape]) # pad to avoid circular correlation\n",
" self._outshape = tuple([(2 * s) for s in shape])\n",
" self._meanF = None\n",
" self._N = 0\n",
" self._meanCor = None\n",
" self._dtype = dtype\n",
" self._counts = None\n",
" self._com = com\n",
" self._meanInput = None\n",
"\n",
" def add(self, image):\n",
" \"\"\"\n",
" add an image\n",
" \"\"\"\n",
" self._meanCor = None\n",
" self._N += 1\n",
" inputimage = self._arraybackend.asarray(image, dtype=self._dtype) * self._mask\n",
" fft = self._fftfun(inputimage, self._fftshape)\n",
" if self._N == 1:\n",
" self._meanF = self._abs2(fft)\n",
" self._meanInput = inputimage\n",
" else:\n",
" factor = 1 / min(self._N, self._com + 1)\n",
" self._meanF += self._update(fft, self._meanF, factor)\n",
" self._meanInput += (inputimage - self._meanInput) * factor\n",
"\n",
" @property\n",
" def meanCor(self):\n",
" \"\"\"\n",
" get the current mean of the correlations\n",
" \"\"\"\n",
" if self._N == 0:\n",
" ret = self._arraybackend.empty(self._outshape, dtype=self._dtype)\n",
" ret[:] = np.nan\n",
" return ret\n",
" counts = self._maskCor\n",
" if self._meanCor is None:\n",
" mean = self._ifftfun(self._meanF) / counts\n",
" mean[counts < 0.5] = np.nan\n",
" mean = self._shiftfun(mean)\n",
" self._meanCor = mean\n",
" return self._meanCor\n",
"\n",
" @property\n",
" def _maskCor(self):\n",
" \"\"\"\n",
" autocorrelation of the mask, computed once and cached; gives the number of pixel pairs contributing to each shift in the dft sum\n",
" \"\"\"\n",
" if self._counts is None:\n",
" self._counts = self._ifftfun(self._abs2(self._fftfun(self._mask, self._fftshape)))\n",
" return self._counts\n",
"\n",
" @property\n",
" def meanCor_numpy(self):\n",
" \"\"\"\n",
" return mean of correlations as numpy array\n",
" \"\"\"\n",
" return self._tonumpy(self.meanCor)\n",
"\n",
" def reset(self, mask=None, com=None):\n",
" \"\"\"\n",
" reset the average; optionally switch to a new mask (which also updates the shapes) and/or a new com\n",
" \"\"\"\n",
" self._N = 0\n",
" self._meanCor = None\n",
" self._meanF = None\n",
" self._meanInput = None\n",
"\n",
" if mask is not None:\n",
" self._inshape = mask.shape\n",
" self._mask = self._arraybackend.asarray(mask > 0, dtype=self._dtype)\n",
" self._fftshape = [sp.fft.next_fast_len(2 * s) for s in self._inshape]\n",
" self._outshape = [(2 * s) for s in self._inshape]\n",
" self._counts = None\n",
" if com is not None:\n",
" self._com = com\n",
"\n",
" @property\n",
" def meanInput(self):\n",
" \"\"\"\n",
" get the mean of the (masked) inputs\n",
" \"\"\"\n",
" return self._meanInput\n",
"\n",
" @property\n",
" def meanInput_numpy(self):\n",
" \"\"\"\n",
" get the mean of the (masked) inputs as numpy array\n",
" \"\"\"\n",
" return self._tonumpy(self.meanInput)\n",
"\n",
" @property\n",
" def corMeanInput(self):\n",
" \"\"\"\n",
" autocorrelation of the mean of the inputs\n",
" \"\"\"\n",
" if self._N == 0:\n",
" ret = self._arraybackend.empty(self._outshape, self._dtype)\n",
" ret[:] = np.nan\n",
" return ret\n",
" counts = self._maskCor\n",
" cor = self._ifftfun(self._abs2(self._fftfun(self._meanInput, self._fftshape))) / counts\n",
" cor[counts < 0.5] = np.nan\n",
" cor = self._shiftfun(cor)\n",
" return cor\n",
"\n",
" @property\n",
" def corMeanInput_numpy(self):\n",
" \"\"\"\n",
" autocorrelation of the mean of the inputs as numpy array\n",
" \"\"\"\n",
" return self._tonumpy(self.corMeanInput)\n",
"\n",
" @property\n",
" def N(self):\n",
" return self._N\n",
"\n",
" def __len__(self):\n",
" return self._N\n",
"\n",
" @property\n",
" def corshape(self):\n",
" \"\"\"\n",
" output shape of the correlations (2*inputshape)\n",
" \"\"\"\n",
" return tuple(self._outshape)\n",
"\n",
" @property\n",
" def inputshape(self):\n",
" \"\"\"\n",
" expected shape of the inputs\n",
" \"\"\"\n",
" return tuple(self._inshape)\n",
"\n",
" @property\n",
" def mask(self):\n",
" \"\"\"\n",
" currently used mask\n",
" \"\"\"\n",
" return self._mask\n",
"\n",
" @property\n",
" def dtype(self):\n",
" return self._dtype\n",
"\n",
"\n",
"import cupy as cp\n",
"\n",
"\n",
"class simplecorrelator_GPU(simplecorrelator):\n",
" \"\"\"\n",
" using cupy for gpu\n",
" \"\"\"\n",
"\n",
" _arraybackend = cp\n",
"\n",
" @staticmethod\n",
" def _tonumpy(x):\n",
" return x.get()\n",
"\n",
" @staticmethod\n",
" def _fftfun(*args, **kwargs):\n",
" return cp.fft.rfftn(*args, **kwargs)\n",
"\n",
" @staticmethod\n",
" def _ifftfun(*args, **kwargs):\n",
" return cp.fft.irfftn(*args, **kwargs)\n",
"\n",
" @staticmethod\n",
" @cp.fuse\n",
" def _update(fft, old, f):\n",
" # small optimisation to use fusing\n",
" return (simplecorrelator_GPU._abs2(fft) - old) * f\n",
"\n",
" @staticmethod\n",
" @cp.fuse\n",
" def _abs2(array):\n",
" # small optimisation to use fusing\n",
" return cp.real(array) ** 2 + cp.imag(array) ** 2\n",
"\n",
"\n",
"import numexpr as ne\n",
"\n",
"\n",
"class simplecorrelator_CPU(simplecorrelator):\n",
" _arraybackend = np\n",
" _workers = 8\n",
"\n",
" def _fftfun(self, *args, **kwargs):\n",
" return sp.fft.rfftn(*args, **kwargs, workers=self._workers)\n",
"\n",
" def _ifftfun(self, *args, **kwargs):\n",
" return sp.fft.irfftn(*args, **kwargs, workers=self._workers)\n",
"\n",
" @staticmethod\n",
" def _update(fft, old, f):\n",
" return ne.evaluate(\"(real(fft)**2+imag(fft)**2 - old) * f\")\n",
"\n",
" @staticmethod\n",
" def _abs2(array):\n",
" return ne.evaluate(\"real(array)**2+imag(array)**2\")\n",
"\n",
"\n",
"def test(correlator):\n",
" \"\"\"\n",
" some basic tests, not written as a unittest.TestCase for easier use in a notebook\n",
" \"\"\"\n",
" import scipy.signal as ss\n",
" import pandas as pd\n",
"\n",
" mask = (np.random.rand(32, 19) > 0.3).astype(float)\n",
" data = np.random.randn(50, 32, 19)\n",
"\n",
" # ground truth scipy.signal\n",
" with np.errstate(all=\"ignore\"):\n",
" ssr = np.array([ss.correlate2d(d * mask, d * mask) for d in data]) / ss.correlate2d(mask, mask)\n",
" # init\n",
" cor = correlator(shape=(10, 10), dtype=np.float32)\n",
"\n",
" # empty correlator\n",
" np.testing.assert_allclose(cor.meanCor_numpy, np.nan * np.ones((20, 20)))\n",
"\n",
" # reset test\n",
" cor.add(np.ones((10, 10)))\n",
" cor.reset(mask=mask)\n",
" np.testing.assert_allclose(cor.meanCor_numpy, np.nan * np.ones((data.shape[1] * 2, data.shape[2] * 2)))\n",
"\n",
" # now actual correlations\n",
" for d in data:\n",
" cor.add(d)\n",
" # basic properties\n",
" np.testing.assert_allclose(len(cor), len(data))\n",
" np.testing.assert_allclose(cor.inputshape, data.shape[1:])\n",
" np.testing.assert_allclose(cor.corshape, cor.meanCor.shape)\n",
" np.testing.assert_allclose(cor.meanInput_numpy, data.mean(0) * mask, atol=1e-6)\n",
"\n",
" # mean of correlation\n",
" np.testing.assert_allclose(np.unravel_index(np.nanargmax(cor.meanCor_numpy), cor.corshape), cor.inputshape, err_msg=\"zero peak position wrong\")\n",
" np.testing.assert_allclose(cor.meanCor_numpy[1:, 1:], ssr.mean(0), atol=1e-5, err_msg=\"failed comparison with scipy on mean of correlations\")\n",
"\n",
" # correlation of mean\n",
" with np.errstate(all=\"ignore\"):\n",
" ssmeanr = ss.correlate2d(data.mean(0) * mask, data.mean(0) * mask) / ss.correlate2d(mask, mask)\n",
" np.testing.assert_allclose(cor.corMeanInput_numpy[1:, 1:], ssmeanr, atol=1e-5, err_msg=\"failed comparison with scipy on correlation of mean\")\n",
"\n",
" # moving average\n",
" corEWM = correlator(mask, com=3)\n",
" for d in data:\n",
" corEWM.add(d)\n",
" # for correlations\n",
" pdc = pd.DataFrame(ssr.reshape(len(data), -1)).ewm(com=3).mean().to_numpy().reshape(ssr.shape)[-1, ...]\n",
" np.testing.assert_allclose(corEWM.meanCor_numpy[1:, 1:], pdc, atol=1e-3, err_msg=\"failed moving average correlation comparison\")\n",
" # for input\n",
" pdi = pd.DataFrame(data.reshape(len(data), -1)).ewm(com=3).mean().to_numpy().reshape(data.shape)[-1, ...] * mask\n",
" np.testing.assert_allclose(corEWM.meanInput_numpy, pdi, atol=1e-3, err_msg=\"failed moving average input comparison\")\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" print(\"running tests on CPU\")\n",
" test(simplecorrelator_CPU)\n",
"\n",
" print(\"running tests on GPU\")\n",
" test(simplecorrelator_CPU)\n",
"\n",
" print(\"all tests passed\")\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import torch, torch.fft\n",
"\n",
"\n",
"class _torch_array_funs:\n",
" def asarray(*args, **kwargs):\n",
" if \"dtype\" in kwargs:\n",
" kwargs[\"dtype\"] = getattr(torch, np.dtype(kwargs[\"dtype\"]).name)\n",
" return torch.as_tensor(*args, **kwargs, device=\"cuda\")\n",
"\n",
" def empty(*args, **kwargs):\n",
" if \"dtype\" in kwargs:\n",
" kwargs[\"dtype\"] = getattr(torch, np.dtype(kwargs[\"dtype\"]).name)\n",
" return torch.empty(*args, **kwargs, device=\"cuda\")\n",
"\n",
"\n",
"class simplecorrelator_TORCH(simplecorrelator):\n",
" \"\"\"\n",
" torch gpu version\n",
" \"\"\"\n",
"\n",
" _arraybackend = _torch_array_funs\n",
"\n",
" @staticmethod\n",
" def _fftfun(*args, **kwargs):\n",
" return torch.fft.rfftn(*args, **kwargs)\n",
"\n",
" @staticmethod\n",
" def _ifftfun(*args, **kwargs):\n",
" return torch.fft.irfftn(*args, **kwargs)\n",
"\n",
" @staticmethod\n",
" def _tonumpy(x):\n",
" if isinstance(x, torch.Tensor):\n",
" return x.cpu().numpy()\n",
" else:\n",
" return np.asarray(x)\n",
"\n",
" def _shiftfun(self, x):\n",
" # move the tensor to a numpy array on the CPU, fftshift the zero shift to the center, then crop to 2*inputshape\n",
" return np.fft.fftshift(x.cpu().numpy())[tuple(slice(f // 2 - i, f // 2 + i) for i, f in zip(self._inshape, self._fftshape))]\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" test(simplecorrelator_TORCH)\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 loop, best of 5: 2.13 s per loop\n"
]
}
],
"source": [
"%%timeit data=np.random.rand(100,2048,2048)\n",
"# including transfer time for each image\n",
"c = simplecorrelator_GPU(mask=np.ones_like(data[0]))\n",
"for d in data:\n",
" c.add(d)\n",
"r = c.meanCor_numpy\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 loop, best of 5: 2.14 s per loop\n"
]
}
],
"source": [
"%%timeit data=np.random.rand(100,2048,2048)\n",
"# and testing torch as a backend\n",
"c = simplecorrelator_TORCH(mask=np.ones_like(data[0]))\n",
"for d in data:\n",
" c.add(d)\n",
"r = c.meanCor_numpy\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/cds/home/z/zimmf/.conda/envs/ml/lib/python3.7/site-packages/ipykernel_launcher.py:102: RuntimeWarning: divide by zero encountered in true_divide\n",
"/cds/home/z/zimmf/.conda/envs/ml/lib/python3.7/site-packages/ipykernel_launcher.py:102: RuntimeWarning: invalid value encountered in true_divide\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 loop, best of 5: 22 s per loop\n"
]
}
],
"source": [
"%%timeit data=np.random.rand(100,2048,2048)\n",
"# CPU comparison\n",
"c = simplecorrelator_CPU(mask=np.ones_like(data[0]))\n",
"for d in data:\n",
" c.add(d)\n",
"r = c.meanCor_numpy"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 loop, best of 5: 14.8 s per loop\n"
]
}
],
"source": [
"%%timeit data=np.random.rand(100,2048,2048)\n",
"# worst case, ask for the numpy mean after each image (two ffts and two transfers)\n",
"c = simplecorrelator_GPU(mask=np.ones_like(data[0]))\n",
"for d in data:\n",
" c.add(d)\n",
" r = c.meanCor_numpy\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/cds/home/z/zimmf/.conda/envs/ml/lib/python3.7/site-packages/ipykernel_launcher.py:102: RuntimeWarning: divide by zero encountered in true_divide\n",
"/cds/home/z/zimmf/.conda/envs/ml/lib/python3.7/site-packages/ipykernel_launcher.py:102: RuntimeWarning: invalid value encountered in true_divide\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 loop, best of 5: 1min 21s per loop\n"
]
}
],
"source": [
"%%timeit data=np.random.rand(100,2048,2048)\n",
"# worst case, ask for the numpy mean after each image, CPU\n",
"c = simplecorrelator_CPU(mask=np.ones_like(data[0]))\n",
"for d in data:\n",
" c.add(d)\n",
" r = c.meanCor_numpy"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 loop, best of 5: 1.61 s per loop\n"
]
}
],
"source": [
"%%timeit data=cp.array(np.random.rand(100,2048,2048))\n",
"# already on gpu\n",
"c = simplecorrelator_GPU(shape=data.shape[1:])\n",
"for d in data:\n",
" c.add(d)\n",
"r = c.meanCor_numpy"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 loop, best of 5: 552 ms per loop\n"
]
}
],
"source": [
"%%timeit data=np.random.rand(100,2048,2048).astype(np.float32)\n",
"# float32\n",
"c = simplecorrelator_GPU(shape=data.shape[1:], dtype=np.float32)\n",
"for d in data:\n",
" c.add(d)\n",
"r = c.meanCor_numpy\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 loop, best of 5: 400 ms per loop\n"
]
}
],
"source": [
"%%timeit data=cp.array(np.random.rand(100,2048,2048).astype(np.float32))\n",
"# float32 already on gpu\n",
"c = simplecorrelator_GPU(shape=data.shape[1:], dtype=np.float32)\n",
"for d in data:\n",
" c.add(d)\n",
"r = c.meanCor_numpy"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 ml",
"language": "python",
"name": "python3ml"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment