@mattiaspaul
Created August 17, 2021 13:46

import torch
import torch.nn as nn
import torch.nn.functional as F
import nibabel as nib
import time
import numpy as np
import scipy.ndimage
from scipy.ndimage import zoom
from scipy.ndimage import map_coordinates
from argparse import ArgumentParser


def gpu_usage():
    print('gpu usage (current/max): {:.2f} / {:.2f} GB'.format(
        torch.cuda.memory_allocated()*1e-9, torch.cuda.max_memory_allocated()*1e-9))


def adamreg(input_seg_fixed, input_seg_moving, output_field, output_warped, grid_sp, lambda_weight):
    seg_fixed = torch.from_numpy(nib.load(input_seg_fixed).get_fdata()).float()
    seg_moving = torch.from_numpy(nib.load(input_seg_moving).get_fdata()).float()
    H, W, D = seg_fixed.shape
    if grid_sp is None:
        grid_sp = 4
    else:
        grid_sp = int(grid_sp)

    torch.cuda.synchronize()
    t0 = time.time()
    # compute one-hot label 'descriptors' (used here in place of MIND) and downsample (using average pooling)
    with torch.no_grad():
        mindssc_fix = F.one_hot(seg_fixed.cuda().view(1, H, W, D).long()).float().permute(0, 4, 1, 2, 3).contiguous()
        mindssc_mov = F.one_hot(seg_moving.cuda().view(1, H, W, D).long()).float().permute(0, 4, 1, 2, 3).contiguous()
        mind_fix = F.avg_pool3d(mindssc_fix, grid_sp, stride=grid_sp)
        mind_mov = F.avg_pool3d(mindssc_mov, grid_sp, stride=grid_sp)

    # extract 3x3x3 patches of the pooled descriptors, flattened into the channel dimension
    with torch.no_grad():
        patch_mind_fix = nn.Flatten(5,)(F.pad(mind_fix, (1, 1, 1, 1, 1, 1)).unfold(2, 3, 1).unfold(3, 3, 1).unfold(4, 3, 1)).permute(0, 1, 5, 2, 3, 4).reshape(1, -1, H//grid_sp, W//grid_sp, D//grid_sp)
        patch_mind_mov = nn.Flatten(5,)(F.pad(mind_mov, (1, 1, 1, 1, 1, 1)).unfold(2, 3, 1).unfold(3, 3, 1).unfold(4, 3, 1)).permute(0, 1, 5, 2, 3, 4).reshape(1, -1, H//grid_sp, W//grid_sp, D//grid_sp)
        #print(patch_mind_fix.shape)

    # create optimisable displacement grid
    net = nn.Sequential(nn.Conv3d(3, 1, (H//grid_sp, W//grid_sp, D//grid_sp), bias=False))
    net[0].weight.data[:] = 0
    net.cuda()
    optimizer = torch.optim.Adam(net.parameters(), lr=1)
    torch.cuda.synchronize()
    t0 = time.time()
    grid0 = F.affine_grid(torch.eye(3, 4).unsqueeze(0).cuda(), (1, 1, H//grid_sp, W//grid_sp, D//grid_sp), align_corners=False)
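
    # Note (sketch, not from the original gist): the Conv3d above is never applied as a
    # convolution; its weight of shape (1, 3, H//grid_sp, W//grid_sp, D//grid_sp) simply
    # serves as the trainable control-point displacement field. An equivalent, more explicit
    # parameterisation would be:
    #   disp_param = nn.Parameter(torch.zeros(1, 3, H//grid_sp, W//grid_sp, D//grid_sp).cuda())
    #   optimizer = torch.optim.Adam([disp_param], lr=1)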

    # run Adam optimisation with diffusion regularisation and B-spline smoothing
    if lambda_weight is None:
        lambda_weight = 0.75  # sad: 10, ssd: 0.75
    else:
        lambda_weight = float(lambda_weight)
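
    # Note (descriptive comment, not from the original gist): each iteration below minimises
    #   12 * mean((warped_patches - fixed_patches)**2)
    #   + lambda_weight * mean of squared finite differences of the smoothed field along H, W and D,
    # i.e. an SSD data term on the pooled one-hot patches plus diffusion regularisation.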
    for it in range(150):
        optimizer.zero_grad()
        # smooth the control-point field with two 5x5x5 average-pooling passes
        disp_sample = F.avg_pool3d(F.avg_pool3d(net[0].weight, 5, stride=1, padding=2), 5, stride=1, padding=2).permute(0, 2, 3, 4, 1)
        reg_loss = lambda_weight*((disp_sample[0, :, 1:, :]-disp_sample[0, :, :-1, :])**2).mean()\
            + lambda_weight*((disp_sample[0, 1:, :, :]-disp_sample[0, :-1, :, :])**2).mean()\
            + lambda_weight*((disp_sample[0, :, :, 1:]-disp_sample[0, :, :, :-1])**2).mean()
        # convert displacements (in control-grid voxels) to normalised grid_sample coordinates and warp the moving patches
        scale = torch.tensor([(H//grid_sp-1)/2, (W//grid_sp-1)/2, (D//grid_sp-1)/2]).cuda().unsqueeze(0)
        grid_disp = grid0.view(-1, 3).cuda().float() + ((disp_sample.view(-1, 3))/scale).flip(1).float()
        patch_mov_sampled = F.grid_sample(patch_mind_mov.float(), grid_disp.view(1, H//grid_sp, W//grid_sp, D//grid_sp, 3).cuda(), align_corners=True, mode='bilinear')
        sampled_cost = (patch_mov_sampled-patch_mind_fix).pow(2).mean(1)*12
        loss = sampled_cost.mean()
        (loss+reg_loss).backward()
        optimizer.step()

    torch.cuda.synchronize()
    t1 = time.time()
    print(t1-t0, 'sec (optim)')

    # scale the fitted low-resolution field to voxel units and upsample to full resolution
    fitted_grid = disp_sample.permute(0, 4, 1, 2, 3).detach()
    disp_hr = F.interpolate(fitted_grid*grid_sp, size=(H, W, D), mode='trilinear', align_corners=False)
    disp_field = disp_hr.cpu().float().numpy()

    # convert field to half-resolution float16 npz
    x = disp_field[0, 0, :, :, :]
    y = disp_field[0, 1, :, :, :]
    z = disp_field[0, 2, :, :, :]
    x1 = zoom(x, 1/2, order=2).astype('float16')
    y1 = zoom(y, 1/2, order=2).astype('float16')
    z1 = zoom(z, 1/2, order=2).astype('float16')
    # write out field
    np.savez_compressed(output_field, np.stack((x1, y1, z1), 0))

    # optionally warp the moving segmentation with the full-resolution field
    if output_warped is not None:
        identity = np.stack(np.meshgrid(np.arange(H), np.arange(W), np.arange(D), indexing='ij'), 0)
        moving_warped = map_coordinates(seg_moving.numpy(), identity + disp_field[0], order=1)
        nib.save(nib.Nifti1Image(moving_warped, np.eye(4)), output_warped)
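

# The helper below is a sketch that is NOT part of the original gist: it shows one way to
# consume the half-resolution field written by adamreg (stored under 'arr_0' with shape
# (3, ~H/2, ~W/2, ~D/2) in full-resolution voxel units), upsample it back to the image grid
# and warp a moving label map. The file names, the function name and the nearest-neighbour
# interpolation (order=0) are illustrative assumptions.
def apply_saved_field(field_npz, moving_nifti, out_nifti):
    disp_half = np.load(field_npz)['arr_0'].astype('float32')
    moving = nib.load(moving_nifti).get_fdata()
    H, W, D = moving.shape
    # upsample each displacement component to the full image size
    factors = [np.array([H, W, D]) / np.array(disp_half[c].shape) for c in range(3)]
    disp = np.stack([zoom(disp_half[c], factors[c], order=2) for c in range(3)], 0)
    # dense identity grid + displacement, then resample the labels
    identity = np.stack(np.meshgrid(np.arange(H), np.arange(W), np.arange(D), indexing='ij'), 0)
    warped = map_coordinates(moving, identity + disp, order=0)
    nib.save(nib.Nifti1Image(warped.astype(np.float32), np.eye(4)), out_nifti)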


if __name__ == "__main__":
    parser = ArgumentParser()
    #parser.add_argument("--input_img_fixed",
    #                    type=str, required=True,
    #                    help="path to input fixed nifti")
    #parser.add_argument("--input_img_moving",
    #                    type=str, required=True,
    #                    help="path to input moving nifti")
    parser.add_argument("--input_seg_fixed",
                        type=str, required=True,
                        help="path to input fixed labels")
    parser.add_argument("--input_seg_moving",
                        type=str, required=True,
                        help="path to input moving labels")
    parser.add_argument("--output_field",
                        type=str, required=True,
                        help="path to output displacement file")
    parser.add_argument("--output_warped",
                        type=str, required=False,
                        help="path to output warped image nifti")
    parser.add_argument("--grid_sp",
                        type=str, required=False,
                        help="integer value for grid spacing (default = 4)")
    parser.add_argument("--lambda_weight",
                        type=str, required=False,
                        help="floating point value for regularisation weight (default = 0.75)")
    adamreg(**vars(parser.parse_args()))
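
# Example invocation (a sketch; the script filename and all file paths below are placeholder
# assumptions, while the flag names and defaults come from the argument parser above):
#   python adamreg_seg.py --input_seg_fixed fixed_seg.nii.gz \
#                         --input_seg_moving moving_seg.nii.gz \
#                         --output_field disp_half.npz \
#                         --output_warped warped_seg.nii.gz \
#                         --grid_sp 4 --lambda_weight 0.75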