Skip to content

Instantly share code, notes, and snippets.

@maxastyler
Created March 23, 2021 14:08
Show Gist options
Save maxastyler/b45d3df4d7f3a13aacec1f760f0d68fc to your computer and use it in GitHub Desktop.
Centering Code
from scipy.optimize import curve_fit
import numpy as np
import matplotlib.pyplot as plt
def gaussian(positions, x, y, sigma, amplitude, background):
    """Evaluate a circularly symmetric 2D Gaussian on a constant offset.

    Args:
        positions: Indexable pair; ``positions[0]`` and ``positions[1]`` hold
            the coordinates along the first and second axis.
        x: Centre of the peak along the first coordinate.
        y: Centre of the peak along the second coordinate.
        sigma: Standard deviation, shared by both coordinates.
        amplitude: Height of the peak above ``background``.
        background: Constant added to every sample.

    Returns:
        ``amplitude * exp(-r**2 / (2 * sigma**2)) + background``, where ``r``
        is the distance of each position from ``(x, y)``.
    """
    offset_x = positions[0] - x
    offset_y = positions[1] - y
    r_squared = offset_x**2 + offset_y**2
    envelope = np.exp(-r_squared / (2 * sigma**2))
    return amplitude * envelope + background
def camera_coords(width, height):
    """Return dense pixel-index coordinate grids for a ``width`` x ``height`` sensor.

    Args:
        width: Number of samples along the first axis.
        height: Number of samples along the second axis.

    Returns:
        The ``np.mgrid`` result: an array that unpacks into two
        ``(width, height)`` grids. The first runs from 0 to ``width - 1``
        along axis 0, the second from 0 to ``height - 1`` along axis 1.
    """
    # An imaginary step tells mgrid "this many points, end point inclusive".
    first_axis = slice(0, width - 1, complex(0, width))
    second_axis = slice(0, height - 1, complex(0, height))
    return np.mgrid[first_axis, second_axis]
def centre_of_mass(image):
    """Return the intensity-weighted mean position of ``image``.

    Args:
        image: 2D array of intensities; its shape is forwarded to
            ``camera_coords`` to build the coordinate grids.

    Returns:
        Tuple ``(x, y)`` of the weighted mean index along axis 0 and axis 1.
        The weights are ``image`` divided by its total sum.
    """
    grid_x, grid_y = camera_coords(*image.shape)
    weights = image / image.sum()
    mean_x = (weights * grid_x).sum()
    mean_y = (weights * grid_y).sum()
    return (mean_x, mean_y)
def variance(image):
    """Estimate the width of the intensity distribution in ``image``.

    Despite the name, the result is a standard deviation rather than a
    variance: the square roots of the weighted second central moments along
    each axis are averaged.

    Args:
        image: 2D array of intensities.

    Returns:
        Mean of the intensity-weighted standard deviations along axis 0 and
        axis 1, in pixel-index units.
    """
    grid_x, grid_y = camera_coords(*image.shape)
    centre_x, centre_y = centre_of_mass(image)
    weights = image / image.sum()
    spread_x = np.sqrt((weights * (grid_x - centre_x) ** 2).sum())
    spread_y = np.sqrt((weights * (grid_y - centre_y) ** 2).sum())
    return (spread_x + spread_y) / 2
def fit_gaussian(image):
    """Fit the ``gaussian`` model to ``image`` by least squares.

    The starting point is taken from the image itself: centre of mass for the
    position, ``variance(image)`` for sigma, the maximum for the amplitude and
    the mean for the background.

    Args:
        image: 2D array of intensities.

    Returns:
        The optimal parameters found by ``curve_fit``, in the order
        ``[x, y, sigma, amplitude, background]``.
    """
    grid_x, grid_y = camera_coords(*image.shape)

    def flattened_model(*args):
        # curve_fit is given the image flattened, so flatten the model too.
        return gaussian(*args).ravel()

    centre_x, centre_y = centre_of_mass(image)
    initial_guess = [
        centre_x,
        centre_y,
        variance(image),
        np.max(image),
        np.mean(image),
    ]
    best_fit, _covariance = curve_fit(
        flattened_model,
        (grid_x, grid_y),
        image.ravel(),
        p0=initial_guess,
    )
    return best_fit
def k_freq(width, height):
    """Return the 2D grids of discrete Fourier sample frequencies.

    Args:
        width: Number of samples for the first frequency axis.
        height: Number of samples for the second frequency axis.

    Returns:
        The ``np.meshgrid`` result for ``fftfreq(width)`` and
        ``fftfreq(height)``: two arrays of shape ``(height, width)``
        (meshgrid's default ``'xy'`` indexing), in cycles per sample.
    """
    freqs_for_width = np.fft.fftfreq(width)
    freqs_for_height = np.fft.fftfreq(height)
    return np.meshgrid(freqs_for_width, freqs_for_height)
def transmission(width, height, wavelength, distance):
    """Build the Fourier-domain transfer function for free-space propagation.

    Args:
        width: Number of samples along the first axis of the field.
        height: Number of samples along the second axis of the field.
        wavelength: Wavelength used to form ``k = 2 * pi / wavelength``.
        distance: Propagation distance.

    Returns:
        Complex array of shape ``(width, height)`` to multiply with the 2D FFT
        of the field. Components with ``k_z**2 >= 0`` get the phase factor
        ``exp(-1j * k_z * distance)``; components with ``k_z**2 < 0``
        (evanescent) get the real decay factor ``exp(-|k_z| * |distance|)``.
    """
    k_x, k_y = k_freq(width, height)
    k = 2 * np.pi / wavelength
    # NOTE(review): k is angular (2*pi/wavelength) while k_x, k_y come from
    # np.fft.fftfreq in cycles per sample; a unit-consistent kernel would
    # scale the transverse terms by 2*pi. Left as in the original -- confirm
    # the intended units before changing, since it alters every result.
    k_z_squared = k**2 - k_x**2 - k_y**2

    # A plain np.sqrt of the negative entries would yield NaN, and a single
    # NaN in the kernel turns the whole inverse FFT into NaN. Take the
    # magnitude first and choose the factor per component instead.
    propagating = k_z_squared >= 0
    k_z_magnitude = np.sqrt(np.abs(k_z_squared))
    kernel = np.where(
        propagating,
        np.exp(-1j * k_z_magnitude * distance),
        # abs(distance) keeps evanescent terms decaying (never growing) for
        # either sign of the propagation distance.
        np.exp(-k_z_magnitude * abs(distance)),
    )
    # k_freq uses meshgrid 'xy' ordering (height, width); transpose so the
    # kernel lines up with a (width, height) field.
    return kernel.T
def propagate(mode, distance, wavelength=1):
    """Propagate a 2D field by filtering it in the Fourier domain.

    Args:
        mode: 2D array holding the field to propagate.
        distance: Propagation distance passed on to ``transmission``.
        wavelength: Wavelength passed on to ``transmission`` (default 1).

    Returns:
        Complex 2D array: the inverse FFT of the field's FFT multiplied by the
        transfer function from ``transmission``.
    """
    kernel = transmission(*mode.shape, wavelength, distance)
    spectrum = np.fft.fft2(mode)
    return np.fft.ifft2(kernel * spectrum)
# Demo: propagate a Gaussian beam through a pi phase step and track how the
# centroid of the propagated intensity moves as the step edge is swept.
size = (1000, 800)
x, y = camera_coords(*size)
# Gaussian beam centred on the grid: sigma 40, amplitude 3, no background.
mode = gaussian((x, y), size[0]/2, size[1]/2, 40, 3, 0)

# FIX: `mask` was read here before its first assignment (inside the loop
# below), which raises NameError on a fresh run.
# NOTE(review): the mask intended for this preview is not shown anywhere; a
# phase step at the centre (offset 0 of the sweep below) is assumed -- confirm.
mask = np.ones(size, dtype=np.complex128)
mask[:size[0]//2] = -1
sim_im = np.abs(propagate(mode * mask, 5e4, wavelength=1))

plt.imshow(mode)
plt.show()
plt.imshow(sim_im)
plt.show()
#plt.plot(sim_im.sum(axis=1))
#plt.show()

# Sweep the phase-step edge along axis 0 and record, for each position, the
# centroid (along axis 0) of the propagated amplitude.
data = []
for i in np.linspace(-300, 300, 200):
    # Rows before the edge get a factor of -1, the rest stay at +1.
    mask = np.ones(size, dtype=np.complex128)
    mask[:size[0]//2+int(i)] = -1
    # Collapse axis 1 to get a 1D profile along axis 0.
    axis_sum = np.abs(propagate(mode * mask, 5e5, wavelength=1)).sum(axis=1)
    axis_norm = axis_sum / axis_sum.sum()
    # Weighted mean row index of the normalised profile.
    data.append((axis_norm * np.arange(len(axis_norm))).sum())
plt.plot(data)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment