Created
August 17, 2021 13:46
-
-
Save mattiaspaul/0b13c6581ef91d6c758795f8fbde1fcc to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time
from argparse import ArgumentParser

import nibabel as nib
import numpy as np
import scipy.ndimage
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.ndimage.interpolation import map_coordinates
from scipy.ndimage.interpolation import zoom as zoom
def gpu_usage():
    """Print the current and peak CUDA memory allocated by PyTorch, in GB."""
    bytes_to_gb = 1e-9
    current_gb = torch.cuda.memory_allocated() * bytes_to_gb
    peak_gb = torch.cuda.max_memory_allocated() * bytes_to_gb
    print('gpu usage (current/max): {:.2f} / {:.2f} GB'.format(current_gb, peak_gb))
def _neighbourhood_patches(features, grid_shape):
    """Stack each voxel's zero-padded 3x3x3 neighbourhood into the channel axis.

    Args:
        features: tensor of shape (1, C, h, w, d).
        grid_shape: the spatial size (h, w, d) of ``features`` as a tuple.

    Returns:
        Tensor of shape (1, C * 27, h, w, d).
    """
    padded = F.pad(features, (1, 1, 1, 1, 1, 1))
    windows = padded.unfold(2, 3, 1).unfold(3, 3, 1).unfold(4, 3, 1)
    return windows.flatten(5).permute(0, 1, 5, 2, 3, 4).reshape(1, -1, *grid_shape)


def adamreg(input_seg_fixed, input_seg_moving, output_field, output_warped, grid_sp, lambda_weight):
    """Register two label volumes by Adam-optimising a coarse displacement grid.

    Both label maps are one-hot encoded, average-pooled by ``grid_sp`` and
    expanded with their 3x3x3 neighbourhoods. A displacement grid at the
    pooled resolution is then fitted for 150 Adam steps (lr=1) against a
    squared-difference data term plus a diffusion regulariser. The field used
    in the loss is the raw grid smoothed by two 5x5x5 average-pooling passes.

    Requires a CUDA device.

    Args:
        input_seg_fixed: path of the fixed label volume (loaded with nibabel).
        input_seg_moving: path of the moving label volume; must have the same
            3-D shape as the fixed volume.
        output_field: destination passed to ``np.savez_compressed``. The field
            is stored as one unnamed float16 array (key ``arr_0``) with the
            three displacement components stacked on axis 0, resampled to
            half of the input resolution.
        output_warped: optional path; when not ``None`` the moving labels are
            warped with the full-resolution field and saved as NIfTI with an
            identity affine.
        grid_sp: pooling factor / grid spacing (int or numeric string);
            ``None`` selects 4.
        lambda_weight: regularisation weight (float or numeric string);
            ``None`` selects 0.75.

    Raises:
        ValueError: if ``grid_sp`` is not a positive integer, ``lambda_weight``
            is negative, the volumes differ in shape, or the pooled grid has
            fewer than 2 samples along any axis.
    """
    # Validate the cheap scalar options first so bad input fails before any
    # volume is loaded or GPU memory is allocated.
    grid_sp = 4 if grid_sp is None else int(grid_sp)
    if grid_sp < 1:
        raise ValueError('grid_sp must be a positive integer, got {}'.format(grid_sp))
    lambda_weight = .75 if lambda_weight is None else float(lambda_weight)  # sad: 10, ssd:0.75
    if lambda_weight < 0:
        raise ValueError('lambda_weight must be non-negative, got {}'.format(lambda_weight))

    seg_fixed = torch.from_numpy(nib.load(input_seg_fixed).get_fdata()).float()
    seg_moving = torch.from_numpy(nib.load(input_seg_moving).get_fdata()).float()
    if seg_fixed.shape != seg_moving.shape:
        raise ValueError('fixed and moving volumes must have the same shape, got {} and {}'.format(
            tuple(seg_fixed.shape), tuple(seg_moving.shape)))
    H, W, D = seg_fixed.shape
    grid_shape = (H // grid_sp, W // grid_sp, D // grid_sp)
    if min(grid_shape) < 2:
        # A single-sample axis would make the (n - 1) / 2 normalisation below zero.
        raise ValueError('volume {} is too small for grid_sp={}'.format((H, W, D), grid_sp))

    # Use one class count for both volumes: letting F.one_hot infer it per
    # volume yields mismatched channel counts whenever the largest label
    # differs, which breaks the feature subtraction in the loss.
    num_classes = int(max(seg_fixed.max(), seg_moving.max())) + 1

    # One-hot label features, downsampled by average pooling and expanded
    # with their 3x3x3 neighbourhoods.
    torch.cuda.synchronize()
    with torch.no_grad():
        patches = []
        for seg in (seg_fixed, seg_moving):
            one_hot = F.one_hot(seg.cuda().view(1, H, W, D).long(), num_classes)
            one_hot = one_hot.float().permute(0, 4, 1, 2, 3).contiguous()
            pooled = F.avg_pool3d(one_hot, grid_sp, stride=grid_sp)
            patches.append(_neighbourhood_patches(pooled, grid_shape))
        patch_fix, patch_mov = patches

    # Optimisable displacement grid: the weight of a bias-free Conv3d is used
    # purely as a (1, 3, h, w, d) parameter tensor, initialised to zero.
    net = nn.Sequential(nn.Conv3d(3, 1, grid_shape, bias=False))
    net[0].weight.data[:] = 0
    net.cuda()
    optimizer = torch.optim.Adam(net.parameters(), lr=1)

    torch.cuda.synchronize()
    t0 = time.time()

    # NOTE(review): the identity grid is built with align_corners=False but
    # sampled below with align_corners=True; kept as in the original, confirm
    # whether this mismatch is intended.
    grid0 = F.affine_grid(torch.eye(3, 4).unsqueeze(0).cuda(), (1, 1) + grid_shape, align_corners=False)
    identity_grid = grid0.view(-1, 3).float()
    # Converts coarse-voxel displacements to grid_sample's normalised units.
    scale = torch.tensor([(n - 1) / 2 for n in grid_shape]).cuda().unsqueeze(0)

    # Adam optimisation with diffusion regularisation; the double average
    # pooling smooths the raw grid before it is used.
    for _ in range(150):
        optimizer.zero_grad()
        smoothed = F.avg_pool3d(net[0].weight, 5, stride=1, padding=2)
        disp_sample = F.avg_pool3d(smoothed, 5, stride=1, padding=2).permute(0, 2, 3, 4, 1)
        reg_loss = lambda_weight * ((disp_sample[0, :, 1:, :] - disp_sample[0, :, :-1, :]) ** 2).mean() + \
            lambda_weight * ((disp_sample[0, 1:, :, :] - disp_sample[0, :-1, :, :]) ** 2).mean() + \
            lambda_weight * ((disp_sample[0, :, :, 1:] - disp_sample[0, :, :, :-1]) ** 2).mean()
        # flip(1) reverses the component order to the one grid_sample expects.
        grid_disp = identity_grid + (disp_sample.reshape(-1, 3) / scale).flip(1).float()
        patch_mov_sampled = F.grid_sample(patch_mov, grid_disp.view(1, *grid_shape, 3),
                                          align_corners=True, mode='bilinear')
        # NOTE(review): the factor 12 is a fixed scaling of the data term
        # carried over unchanged; its origin is not visible here.
        sampled_cost = (patch_mov_sampled - patch_fix).pow(2).mean(1) * 12
        loss = sampled_cost.mean()
        (loss + reg_loss).backward()
        optimizer.step()

    torch.cuda.synchronize()
    t1 = time.time()
    print(t1-t0,'sec (optim)')

    # disp_sample is the smoothed field evaluated in the last iteration, i.e.
    # before the final optimizer step (same as the original behaviour).
    fitted_grid = disp_sample.permute(0, 4, 1, 2, 3).detach()
    # Multiply by grid_sp to express displacements in full-resolution voxels.
    disp_hr = F.interpolate(fitted_grid * grid_sp, size=(H, W, D), mode='trilinear', align_corners=False)
    disp_field = disp_hr.cpu().float().numpy()

    # Write the field at half resolution as float16.
    half_res = [zoom(disp_field[0, c], 1 / 2, order=2).astype('float16') for c in range(3)]
    np.savez_compressed(output_field, np.stack(half_res, 0))

    if output_warped is not None:
        identity = np.stack(np.meshgrid(np.arange(H), np.arange(W), np.arange(D), indexing='ij'), 0)
        # Nearest-neighbour (order=0): interpolating label IDs linearly would
        # create values that are not valid labels.
        moving_warped = map_coordinates(seg_moving.numpy(), identity + disp_field[0], order=0)
        nib.save(nib.Nifti1Image(moving_warped, np.eye(4)), output_warped)
if __name__ == "__main__":
    # Command-line entry point. Every option name matches a keyword argument
    # of adamreg(), so the parsed namespace is forwarded unchanged. Options
    # that are not supplied arrive as None and adamreg() picks its defaults.
    # The image inputs (--input_img_fixed / --input_img_moving) are disabled:
    # only the label volumes are registered.
    cli_options = (
        # (flag, required, help text)
        ("--input_seg_fixed", True, "path to input fixed labels"),
        ("--input_seg_moving", True, "path to input moving labels"),
        ("--output_field", True, "path to output displacement file"),
        ("--output_warped", False, "path to output warped image nifti"),
        ("--grid_sp", False, "integer value for grid_spacing (default = 4) "),
        ("--lambda_weight", False,
         "floating point value for regularisation weight (default = .75) "),
    )

    parser = ArgumentParser()
    for flag, is_required, help_text in cli_options:
        parser.add_argument(flag, type=str, required=is_required, help=help_text)

    cli_args = parser.parse_args()
    adamreg(**vars(cli_args))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment