Skip to content

Instantly share code, notes, and snippets.

@nikitinvv
Created June 22, 2022 20:42
Show Gist options
  • Save nikitinvv/3f787d2326346e9b9247863ce78e66a4 to your computer and use it in GitHub Desktop.
Save nikitinvv/3f787d2326346e9b9247863ce78e66a4 to your computer and use it in GitHub Desktop.
import numpy as np
import sys
import h5py
import cupy as cp # subpixel shifts on gpu
#import numpy as cp # subpixel shifts on cpu
## usage
### python merge_helical.py /data/helical/Coal_NaBr_075.h5 0.062480474851608875 3201
### where 0.062480474851608875 is the shift in pixels between two consecutive projections
### 3201 - number of angles per 360 deg interval (<360 if only 1 rotation)
def copy_attributes(in_object, out_object):
    """Copy every HDF5 attribute of *in_object* onto *out_object*.

    Each attribute is assigned by name on the destination's ``attrs``.
    """
    src_attrs = in_object.attrs
    dst_attrs = out_object.attrs
    for attr_name, attr_value in src_attrs.items():
        dst_attrs[attr_name] = attr_value
def _report(operation, key, obj):
    """Print a one-line log entry of the form ``"<operation> <kind> : <key>"``.

    ``kind`` is the lower-cased class name of *obj* (text after the last
    dot, if any).
    """
    kind = type(obj).__name__.rpartition(".")[2].lower()
    print(f"{operation} {kind} : {key}")
def h5py_compatible_attributes(in_object):
    """Return True if all attributes of *in_object* can be read.

    Every ``(name, value)`` pair is materialised eagerly, so any read error
    surfaces here rather than later during the copy.

    Returns False when reading raises an ``Exception``. Interrupts such as
    ``KeyboardInterrupt`` / ``SystemExit`` propagate instead of being
    swallowed, which the previous bare ``except:`` did.
    """
    try:
        # Force obtaining every attribute so that an error may appear now.
        list(in_object.attrs.items())
    except Exception:
        return False
    return True
def copy_h5(in_object, out_object, filter_data=None, log=False):
    """Recursively copy the HDF5 tree under *in_object* into *out_object*.

    How each member is handled:

    * Groups are re-created and recursed into.
    * Datasets are re-created from the source data. No compression or
      chunking options are passed to ``create_dataset``.
    * Attributes of re-created objects are copied with ``copy_attributes``.
    * Datatypes, and objects whose attributes cannot be read (see
      ``h5py_compatible_attributes``), are copied verbatim with
      ``in_object.copy``.

    Parameters
    ----------
    in_object, out_object : h5py.File or h5py.Group
        Source and destination containers.
    filter_data : container of str, optional
        Member names to skip. The filter is matched by name at every depth,
        so ``'data'`` skips a member called ``data`` in any group.
        ``None`` (default) skips nothing.
    log : bool
        Print one line per copied member.

    Raises
    ------
    TypeError
        If a member is neither a group, a dataset nor a datatype.
    """
    # None instead of a mutable list default; an empty tuple filters nothing.
    if filter_data is None:
        filter_data = ()
    for key, in_obj in in_object.items():
        if key in filter_data:
            continue
        if isinstance(in_obj, h5py.Datatype) or not h5py_compatible_attributes(in_obj):
            # Datatypes and objects with non-understandable attributes are
            # copied identically by h5py itself.
            if log:
                _report("Copied", key, in_obj)
            in_object.copy(key, out_object)
            continue
        if isinstance(in_obj, h5py.Group):
            out_obj = out_object.create_group(key)
            copy_h5(in_obj, out_obj, filter_data, log)
        elif isinstance(in_obj, h5py.Dataset):
            out_obj = out_object.create_dataset(key, data=in_obj)
        else:
            # Raising a plain string is itself a TypeError in Python 3;
            # raise a real exception so the message is not lost.
            raise TypeError(f"Invalid object type {type(in_obj)}")
        if log:
            _report("Copied", key, in_obj)
        copy_attributes(in_obj, out_obj)
def apply_shift_subpixel(data, shifts, pad=1):
    """Shift each projection along axis 1 (z) by a possibly fractional amount.

    The shift is applied as a phase ramp ``exp(-2*pi*i*f*s)`` in the Fourier
    domain. It is therefore circular within the padded frame: content pushed
    past one edge wraps to the other. ``pad`` zero rows above and below the
    data give shifted content room before it wraps.

    Runs on the GPU when ``cp`` is CuPy, or on the CPU when ``cp`` is NumPy.

    Parameters
    ----------
    data : array of shape (ntheta, nz, n)
        Projections to shift.
    shifts : array-like of length ntheta (or 1)
        Per-projection shift in pixels along z. Positive values move content
        towards larger z indices.
    pad : int
        Number of zero rows added on each side along z. May be 0.

    Returns
    -------
    Real array of shape (ntheta, nz + 2*pad, n)
        Shifted, padded projections.
    """
    ntheta, nz, n = data.shape
    nzp = nz + 2 * pad
    # Zero padding along z. An explicit stop index keeps pad=0 working;
    # the former ``pad:-pad`` slice is empty (0:0) in that case.
    tmp = cp.zeros([ntheta, nzp, n], dtype='float32')
    tmp[:, pad:pad + nz] = data
    # Phase ramp along z, broadcast over angles (axis 0) and x (last axis).
    freq_z = cp.fft.fftfreq(nzp).astype('float32').reshape([nzp, 1])
    shift_col = cp.asarray(shifts).reshape(-1, 1, 1)
    ramp = cp.exp(-2 * np.pi * 1j * (freq_z * shift_col))
    # Use the same FFT backend for both transforms. Pass the output size
    # explicitly: without it irfft2 returns 2*(n//2) columns, i.e. n-1 for
    # odd n.
    return cp.fft.irfft2(ramp * cp.fft.rfft2(tmp), s=(nzp, n))
if __name__ == "__main__":
    # Merge a helical scan into one taller dataset:
    #   1. flat-field correct every projection and take -log;
    #   2. shift it vertically by its accumulated offset
    #      (index * pixel_shift);
    #   3. accumulate it into output slot (index % ntheta_out) of
    #      <input>_merged.h5.
    import os  # only the script entry point needs it

    fname = sys.argv[1]  # input file name
    pixel_shift = float(sys.argv[2])  # shift in pixels between two consecutive projections
    # Number of angles in the resulting file, i.e. the number of angles in a
    # 360 deg interval (assuming regular sampling in angles).
    ntheta_out = int(sys.argv[3])
    # Optional: number of angles processed per chunk (for large datasets).
    ptheta = int(sys.argv[4]) if len(sys.argv) > 4 else 32
    # Optional: zero rows added above/below each projection for sub-pixel shifts.
    pad = int(sys.argv[5]) if len(sys.argv) > 5 else 1
    # 'scan.h5' -> 'scan_merged.h5'. Works for any extension length.
    fname_out = os.path.splitext(fname)[0] + '_merged.h5'

    with h5py.File(fname, 'r') as fid, h5py.File(fname_out, 'w') as fid_out:
        # Copy everything except the members that are rewritten below.
        filter_data = ['data', 'data_white', 'data_dark', 'theta']
        copy_h5(fid, fid_out, filter_data, log=True)

        data = fid['exchange/data']
        ntheta, nz, n = data.shape

        # Accumulated vertical shift of every projection, split into two parts:
        #   - integer part: the output row where the projection is placed;
        #   - fractional part: applied by apply_shift_subpixel.
        # np.int32 truncates toward zero.
        shifts = np.float32(np.arange(ntheta) * pixel_shift)
        ishifts_all = np.int32(shifts)
        max_ishift = int(np.abs(ishifts_all).max()) if ntheta > 0 else 0

        # Output height = data + padding + largest shift. Never smaller than
        # the largest integer shift actually used, so every window fits.
        nz_out = nz + max(abs(int(np.ceil(pixel_shift * ntheta))), max_ishift) + 2 * pad
        data_out = fid_out.create_dataset('exchange/data', [ntheta_out, nz_out, n], dtype='float32', fillvalue=0)

        # Angles of the first ntheta_out projections.
        theta_out = fid['exchange/theta'][:ntheta_out]
        fid_out.create_dataset('exchange/theta', data=theta_out)

        # Output data is already -log(flat-field corrected). dark=0 / white=1
        # presumably make a downstream flat-field step a no-op; confirm.
        fid_out.create_dataset('exchange/data_dark', data=np.zeros([1, nz_out, n]), dtype='float32')
        fid_out.create_dataset('exchange/data_white', data=np.ones([1, nz_out, n]), dtype='float32')

        # Median over the input dark and flat stacks. Read with [:] first:
        # older h5py versions do not return an array from Dataset.astype().
        data_dark = np.median(fid['exchange/data_dark'][:].astype('float32'), axis=0)
        data_white = np.median(fid['exchange/data_white'][:].astype('float32'), axis=0)

        # Shift data by chunks of ptheta angles.
        for st in range(0, ntheta, ptheta):
            end = min(ntheta, st + ptheta)
            print(f'Processing angle chunk {st=},{end=}')
            # Flat-field correction.
            data_chunk = (data[st:end] - data_dark) / (data_white - data_dark)
            data_chunk = cp.array(data_chunk)
            # -log, then sanitize:
            #   NaN (e.g. log of a negative value, or 0/0) -> 6.0
            #   +/-inf (log of 0 or of inf)               -> 0
            # 6.0 presumably caps opaque pixels; confirm.
            data_chunk = -cp.log(data_chunk)
            data_chunk[cp.isnan(data_chunk)] = 6.0
            data_chunk[cp.isinf(data_chunk)] = 0

            # Integer + fractional shifts for this chunk.
            ishifts = ishifts_all[st:end]
            fshifts = np.float32(shifts[st:end] - ishifts)
            if pixel_shift >= 0:
                # Stage moving up (or not at all): window starts at row ishift.
                # Zero previously fell into the other branch and produced stz=-1.
                stz = ishifts
                endz = ishifts + nz + 2 * pad
            else:
                # Stage moving down: window ends at row nz_out + ishift
                # (ishift <= 0). The former extra "-1" could push stz below 0.
                endz = nz_out + ishifts
                stz = endz - nz - 2 * pad

            data_chunk = apply_shift_subpixel(data_chunk, fshifts, pad)
            if not isinstance(data_chunk, np.ndarray):
                data_chunk = data_chunk.get()  # device -> host when cp is CuPy
            for kk in range(end - st):
                # Angles ntheta_out apart share an output slot and are summed.
                data_out[(kk + st) % ntheta_out, int(stz[kk]):int(endz[kk])] += data_chunk[kk]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment