Skip to content

Instantly share code, notes, and snippets.

@nikitinvv
Last active March 31, 2024 00:38
Show Gist options
  • Save nikitinvv/fef87cebeb882ae4fdcefb2148570d9a to your computer and use it in GitHub Desktop.
Save nikitinvv/fef87cebeb882ae4fdcefb2148570d9a to your computer and use it in GitHub Desktop.
import numpy as np
import matplotlib.pyplot as plt
# ---------------------------------------------------------------------------
# Physical constants and acquisition geometry
# ---------------------------------------------------------------------------
PLANCK_CONSTANT = 4.135667696e-18  # [keV*s]
SPEED_OF_LIGHT = 299792458  # [m/s]

n = 1024  # 1d signal size
ns = 128  # number of shifted copies summed in fwd/adj (written as 64+64 originally)
voxelsize = 2e-8  # object voxel size (same value as the original 2*10e-9)
energy = 33.35  # [keV] xray energy
wavelength = PLANCK_CONSTANT * SPEED_OF_LIGHT / energy
focusToDetectorDistance = 1.28  # [m] distance between focus and detector

# Sample to detector propagation.
# [m] distances between focus and sample planes: only the first two entries of
# the list are kept, and they are doubled.
# (An earlier variant of the list used 4.765e-3 and 5.488e-3 for entries 2-3.)
z1 = (np.array([4.584e-3, 4.765e-2, 5.488e-2, 5.7e-3, 6.9895e-3]) - 3.7e-4)[:2] * 2
z2 = focusToDetectorDistance - z1  # [m] distances between sample planes and detector

# The plane-wave equivalent distance (z1*z2)/focusToDetectorDistance is
# disabled; z1 itself is used as the propagation distance (same array object).
distances = z1
# magnification when propagating from the sample plane to the detector
magnifications = focusToDetectorDistance / z1
# normalized magnifications; the trailing "*0+1" makes every entry equal to 1,
# so the magnification-dependent rescaling of distances is switched off
norm_magnifications = magnifications / magnifications[0] * 0 + 1

# ---------------------------------------------------------------------------
# Fresnel kernels on a frequency grid of twice the signal length
# ---------------------------------------------------------------------------
fx = np.fft.fftfreq(2 * n, d=voxelsize)
fP = np.zeros((len(distances), 2 * n), dtype='complex64')  # forward kernels
fPa = np.zeros((len(distances), 2 * n), dtype='complex64')  # opposite-sign kernels
for i, d in enumerate(distances):
    fP[i] = np.exp(-1j*np.pi*wavelength*d*fx**2)
    fPa[i] = np.exp(1j*np.pi*wavelength*d*fx**2)
def propagate(f, fP):
    """Propagate one 1D signal with a Fourier-domain kernel.

    The signal is extended by symmetric reflection to the length of the
    kernel, multiplied by the kernel in Fourier space, transformed back and
    cropped to its original support.

    The padding is derived from ``len(fP) - len(f)`` rather than from a
    module-level size, so the function works for any signal/kernel pair.
    For the kernels built in this script (signal length n, kernel length
    2*n, n even) this is exactly the original n//2 padding on each side.

    Parameters
    ----------
    f : 1D array_like
        Signal to propagate.
    fP : 1D array_like
        Kernel sampled on the FFT frequency grid of the padded signal;
        must be at least as long as ``f``.

    Returns
    -------
    ndarray
        Complex array with the same length as ``f``.

    Raises
    ------
    ValueError
        If the kernel is shorter than the signal.
    """
    size = np.shape(f)[-1]
    extra = np.shape(fP)[-1] - size
    if extra < 0:
        raise ValueError(
            f"kernel length {np.shape(fP)[-1]} is shorter than signal length {size}")
    # Split the padding as evenly as possible; any odd remainder goes right.
    left = extra // 2
    right = extra - left
    padded = np.pad(f, (left, right), 'symmetric')
    padded = np.fft.ifft(np.fft.fft(padded) * fP)
    # Crop with explicit bounds: a negative stop such as -n//2 parses as
    # (-n)//2 and would be off by one for odd sizes (and empty for zero pad).
    return padded[left:left + size]
def fwd(f):
    """Forward operator: sum of propagated, circularly shifted copies of f.

    For every row k of the module-level kernels ``fP`` (one per entry of
    ``distances``) the result accumulates ``propagate`` applied to ``ns``
    copies of ``f`` rolled by -ns//2 ... ns - 1 - ns//2 samples.

    Parameters
    ----------
    f : 1D array of length n

    Returns
    -------
    complex64 array of shape (len(distances), n)
    """
    half = ns // 2
    num_dist = len(distances)
    # All shifted versions of the input, in increasing shift order.
    shifted_copies = [np.roll(f, j - half) for j in range(ns)]
    g = np.zeros([num_dist, n], dtype='complex64')
    for k in range(num_dist):
        for f0 in shifted_copies:
            g[k] += propagate(f0, fP[k])
    return g
def adj(g):
    """Back-project multi-distance data onto the object grid.

    Each row ``g[k]`` is propagated with the opposite-sign kernel ``fPa[k]``
    and the result is accumulated over the ``ns`` circular shifts, each the
    negative of the corresponding shift used in ``fwd``.

    The back-propagation of ``g[k]`` does not depend on the shift index, so
    it is computed once per distance instead of once per (shift, distance)
    pair. The summation order is unchanged, so the result is identical.

    NOTE(review): this is used as the adjoint of ``fwd`` in the gradient
    step; whether it is the exact adjoint given the symmetric padding in
    ``propagate`` is not established here - the script prints a dot-product
    check for that purpose.

    Parameters
    ----------
    g : array of shape (len(distances), n)

    Returns
    -------
    complex64 array of length n
    """
    half = ns // 2
    # Loop-invariant work: one back-propagation per distance.
    back = [propagate(g[k], fPa[k]) for k in range(len(distances))]
    f = np.zeros(n, dtype='complex64')
    for j in range(ns):
        shift = -(j - half)
        for g0 in back:
            f += np.roll(g0, shift)
    return f
# ---------------------------------------------------------------------------
# Synthetic object: unit amplitude with two phase bumps, then simulated data
# ---------------------------------------------------------------------------
s = 64  # feature size
f = np.ones(n, dtype='complex64')
# two features of width s carrying a phase of 0.01 rad, one on each side of
# the centre and separated from it by s//2 samples
f[n//2 - 3*s//2:n//2 - s//2] = np.exp(1j*0.01)
f[n//2 + s//2:n//2 + 3*s//2] = np.exp(1j*0.01)

# Gaussian window on centred, normalized coordinates in [-0.5, 0.5)
x = np.arange(-n/2, n/2) / n
v = np.exp(-x**2*n/4)

# Smooth the object: multiply its centred spectrum by the window.
f = np.fft.fftshift(f)
f = np.fft.fftshift(np.fft.fft(f)) * v
f = np.fft.ifft(np.fft.fftshift(f))
f = np.fft.fftshift(f)

# Forward model, back-projection, and a dot-product check <g, g> vs <f, adj(g)>.
g = fwd(f)
ff = adj(g)
print(np.sum(g*np.conj(g)), np.sum(f*np.conj(ff)))

# measured intensities
data = np.abs(g)**2
def CTFPurePhase(rads, wlen, dists, fx, alpha):
    """Multi-distance CTF phase retrieval for a pure-phase object.

    Weak phase approximation from Cloetens et al. 2002. With
    ``s_j = sin(pi * wlen * dists[j] * fx**2)`` the retrieved phase is

        ifft( mean_j(s_j * fft(rads[j])) / (mean_j(2 * s_j**2) + alpha) )

    of which the real part is returned.

    Parameters
    ----------
    rads : sequence of 1D arrays
        Measured intensity for each distance; ``rads[j]`` is paired with
        ``dists[j]`` and must have the same length as ``fx``.
    wlen : float
        X-ray wavelength (monochromatic source assumed).
    dists : sequence of float
        Propagation distances. ``wlen * dists[j] * fx**2`` must be
        dimensionless, i.e. ``wlen``, ``dists`` and ``1/fx`` share one
        length unit.
    fx : ndarray
        Spatial frequency coordinates in FFT order (as from
        ``np.fft.fftfreq``).
    alpha : float
        Regularization constant added to the denominator.

    Returns
    -------
    ndarray
        Real-valued retrieved phase in real space, same length as ``fx``.

    Raises
    ------
    ValueError
        If ``dists`` is empty.
    """
    num_dist = len(dists)
    if num_dist == 0:
        raise ValueError("CTFPurePhase needs at least one propagation distance")

    numerator = 0
    denominator = 0
    for j, dist in enumerate(dists):
        rad_freq = np.fft.fft(rads[j])
        ctf_sin = np.sin(np.pi * wlen * dist * (fx**2))
        numerator = numerator + ctf_sin * rad_freq
        denominator = denominator + 2 * ctf_sin**2

    numerator = numerator / num_dist
    # alpha keeps the quotient finite where all sine terms vanish (e.g. fx = 0)
    denominator = (denominator / num_dist) + alpha
    return np.real(np.fft.ifft(numerator / denominator))
# ---------------------------------------------------------------------------
# CTF reconstruction used as the starting point for the iterative solver
# ---------------------------------------------------------------------------
wlen = PLANCK_CONSTANT * SPEED_OF_LIGHT / energy
# propagation distances rescaled by the squared normalized magnifications
distances_rec = distances / norm_magnifications**2
# frequency grid of the unpadded signal (rebinds the module-level fx)
fx = np.fft.fftfreq(n, d=voxelsize)
recCTFPurePhase = CTFPurePhase(
    rads=data, wlen=wlen, dists=distances_rec, fx=fx, alpha=1e-7)
# empirical rescaling by 2 / ns**2
recCTFPurePhase = recCTFPurePhase / ns**2 * 2
print(recCTFPurePhase.shape)
def line_search(minf, gamma, u, fu, d, fd):
    """Backtracking line search for the step size gamma.

    Starting from the supplied ``gamma``, the step is halved until the
    objective at the trial point ``(u + gamma*d, fu + gamma*fd)`` is no
    larger than at ``(u, fu)``, or until ``gamma`` drops to 1e-12.

    Parameters
    ----------
    minf : callable
        Objective, called as ``minf(u, fu)``.
    gamma : float
        Initial (largest) step size to try.
    u, fu : current iterate and its forward image.
    d, fd : search direction and its forward image.

    Returns
    -------
    The accepted step size, or 0 if no step above 1e-12 was accepted.
    """
    # The objective at the current point does not change while backtracking,
    # so evaluate it once instead of once per halving.
    current = minf(u, fu)
    while current - minf(u + gamma*d, fu + gamma*fd) < 0 and gamma > 1e-12:
        gamma *= 0.5
    if gamma <= 1e-12:  # direction not found
        gamma = 0
    return gamma
def cg_holo(data, init, piter):
    """Conjugate gradients method for holography.

    Minimizes ``|| abs(fwd(psi)) - sqrt(data) ||^2`` over the complex object
    ``psi`` using Dai-Yuan search directions and a backtracking line search.

    Parameters
    ----------
    data : array of shape (len(distances), n)
        Measured intensities.
    init : complex array of length n
        Starting guess; it is copied, not modified.
    piter : int
        Maximum number of iterations.

    Returns
    -------
    The last iterate ``psi``.

    Notes
    -----
    Iteration stops early when the line search returns 0. Continuing would
    leave ``psi`` unchanged, make ``grad - grad0`` exactly zero, and turn the
    Dai-Yuan update into a division by zero whose NaNs would then be added
    to ``psi``.
    """
    sqrt_data = np.sqrt(data)  # constant over the whole run

    # minimization functional (depends only on the forward image)
    def minf(psi, fpsi):
        return np.linalg.norm(np.abs(fpsi) - sqrt_data)**2

    psi = init.copy()
    gamma = 1  # init gamma as a large value
    for i in range(piter):
        fpsi = fwd(psi)
        # NOTE(review): the 1/(8*n) factor is an empirical gradient scaling
        grad = adj(fpsi - sqrt_data*np.exp(1j*np.angle(fpsi)))/n/8
        # Dai-Yuan direction
        if i == 0:
            d = -grad
        else:
            beta = np.linalg.norm(grad)**2 / np.sum(np.conj(d)*(grad - grad0))
            d = -grad + beta*d
        grad0 = grad
        # line search
        fd = fwd(d)
        gamma = line_search(minf, gamma, psi, fpsi, d, fd)
        psi = psi + gamma*d
        # Report the objective at the updated iterate. The line search already
        # treats fwd as linear, so fwd(psi) is fpsi + gamma*fd here.
        err = minf(psi, fpsi + gamma*fd)
        print(f'{i}) {gamma=}, err={err}')
        if gamma == 0:
            # no descent step found: psi can no longer change
            break
    return psi
# ---------------------------------------------------------------------------
# Iterative reconstruction started from the CTF result, then plotting
# ---------------------------------------------------------------------------
# unit-amplitude start with the CTF phase (f*0 only contributes zeros)
init = f*0+1*np.exp(1j*recCTFPurePhase)
rec = cg_holo(data, init, 96)

fig, axs = plt.subplots(2, 3, figsize=(24, 8))
fig.suptitle(f'Source size {ns} pixels, feature size {s} pixels')

# --- top row: object, simulated data, CTF reconstruction -------------------
axs[0, 0].plot(np.angle(f))
axs[0, 0].set_title('test1: init phase object')

axs[0, 1].plot(np.abs(g[0])**2, 'r')
axs[0, 1].plot(np.abs(g[1])**2, 'g')
axs[0, 1].set_title('data for different distances')

axs[0, 2].plot(recCTFPurePhase, 'r')
axs[0, 2].set_title('CTF reconstruction')

# --- bottom row: forward model with the Fresnel kernels replaced by 1 ------
# This overwrites the module-level kernels fP and the data g in place, so it
# must stay after the reconstruction above.
fP[:] = 1
g = fwd(f)/ns

# NOTE(review): the title says "abs" but the phase of f is plotted - confirm.
axs[1, 0].plot(np.angle(f))
axs[1, 0].set_title('test2 (no Fresnel): init abs object')

# NOTE(review): the title says "abs" but the normalized imaginary part is
# plotted - confirm.
axs[1, 1].plot(np.imag(g[0])/np.max(np.imag(g[0]))*0.01, 'r')
axs[1, 1].set_title('data without Fresnel propagation (abs)')

# Normalize both reconstructions to the range [0, 0.01] for comparison.
# The CG phase is shifted by its own minimum; the original subtracted the
# minimum of the CTF result here, which left the CG curve unnormalized.
recphase = np.angle(rec)
recphase -= np.min(recphase)
recphase = recphase/np.max(recphase)*0.01
axs[1, 2].plot(recphase, 'b')

recCTFPurePhase -= np.min(recCTFPurePhase)
recCTFPurePhase = recCTFPurePhase/np.max(recCTFPurePhase)*0.01
axs[1, 2].plot(recCTFPurePhase, 'r')
axs[1, 2].set_title('cg after ctf')

plt.show()
# def line_search(minf, gamma, u, fu, d, fd):
# """ Line search for the step sizes gamma"""
# while(minf(u, fu)-minf(u+gamma*d, fu+gamma*fd) < 0 and gamma > 1e-12):
# gamma *= 0.5
# if(gamma <= 1e-12): # direction not found
# #print('no direction')
# gamma = 0
# return gamma
# def cg_holo(data, init, piter):
# """Conjugate gradients method for holography"""
# # minimization functional
# def minf(psi, fpsi):
# f = np.linalg.norm(np.abs(fpsi)-np.sqrt(data))**2
# # f = np.linalg.norm(np.abs(fpsi)**2-data)**2
# return f
# psi = init.copy()
# gamma = 1# init gamma as a large value
# for i in range(piter):
# fpsi = fwd(psi)
# grad = adj(fpsi-np.sqrt(data)*np.exp(1j*np.angle(fpsi)))/n
# # Dai-Yuan direction
# if i == 0:
# d = -grad
# else:
# d = -grad+np.linalg.norm(grad)**2 / \
# ((np.sum(np.conj(d)*(grad-grad0))))*d
# grad0 = grad
# # line search
# fd = fwd(d)
# gamma = line_search(minf, gamma, psi, fpsi, d, fd)
# psi = psi + gamma*d
# print(f'{i}) {gamma=}, err={minf(psi,fpsi)}')
# return psi
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment