Last active
July 29, 2023 20:59
-
-
Save mattiaspaul/d314c22ac97d37c2cf05e99780bd54c4 to your computer and use it in GitHub Desktop.
AdamReg MIND
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#note: an advanced baseline with additional infos for registration settings | |
#for Learn2Reg 2021 and recommended pre-processing of lung scans can be found | |
#here https://github.com/multimodallearning/convexAdam | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import nibabel as nib | |
import time | |
import numpy as np | |
import scipy.ndimage | |
from scipy.ndimage.interpolation import zoom as zoom | |
from scipy.ndimage.interpolation import map_coordinates | |
from argparse import ArgumentParser | |
def pdist_squared(x):
    """Pairwise squared Euclidean distances between the columns of x.

    x: tensor of shape (B, C, N) holding N points of dimension C.
    Returns a (B, N, N) tensor of squared distances (clamped at 0).
    """
    # ||a-b||^2 = ||a||^2 + ||b||^2 - 2<a,b>, computed batch-wise
    sq_norms = (x ** 2).sum(dim=1).unsqueeze(2)          # (B, N, 1)
    gram = torch.bmm(x.permute(0, 2, 1), x)              # (B, N, N)
    d = sq_norms + sq_norms.permute(0, 2, 1) - 2.0 * gram
    d[d != d] = 0  # zero out NaNs from numerical noise
    return torch.clamp(d, 0.0)
def MINDSSC(img, radius=2, dilation=2):
    """Compute the 12-channel MIND-SSC descriptor of a 3D volume.

    See http://mpheinrich.de/pub/miccai2013_943_mheinrich.pdf for details
    on the self-similarity-context (SSC) variant of MIND.

    Fix: the shift kernels were hard-coded to `.cuda()` (crashing for CPU
    inputs and for non-float32 inputs); they are now created on the input's
    device/dtype, which is backward compatible for the original CUDA-float
    call path.

    Parameters
    ----------
    img : torch.Tensor
        Volume of shape (B, 1, H, W, D); assumes a 5D single-channel input.
    radius : int
        Patch radius of the SSD average pooling (kernel = 2*radius+1).
    dilation : int
        Dilation of the shifted-difference convolutions.

    Returns
    -------
    torch.Tensor of shape (B, 12, H, W, D), values in (0, 1].
    """
    device = img.device
    dtype = img.dtype
    # kernel size of the patch SSD pooling
    kernel_size = radius * 2 + 1
    # define start and end locations for the self-similarity pattern
    six_neighbourhood = torch.tensor([[0, 1, 1],
                                      [1, 1, 0],
                                      [1, 0, 1],
                                      [1, 1, 2],
                                      [2, 1, 1],
                                      [1, 2, 1]]).long()
    # squared distances between all pattern locations
    dist = pdist_squared(six_neighbourhood.t().unsqueeze(0)).squeeze(0)
    # comparison mask: each unordered pair exactly once, only pairs at
    # squared distance 2 -> the 12 valid SSC comparisons
    x, y = torch.meshgrid(torch.arange(6), torch.arange(6))
    mask = ((x > y).view(-1) & (dist == 2).view(-1))
    # build one-hot 3x3x3 kernels that extract the two shifted image copies
    idx_shift1 = six_neighbourhood.unsqueeze(1).repeat(1, 6, 1).view(-1, 3)[mask, :]
    idx_shift2 = six_neighbourhood.unsqueeze(0).repeat(6, 1, 1).view(-1, 3)[mask, :]
    mshift1 = torch.zeros(12, 1, 3, 3, 3, device=device, dtype=dtype)
    mshift1.view(-1)[torch.arange(12) * 27 + idx_shift1[:, 0] * 9 + idx_shift1[:, 1] * 3 + idx_shift1[:, 2]] = 1
    mshift2 = torch.zeros(12, 1, 3, 3, 3, device=device, dtype=dtype)
    mshift2.view(-1)[torch.arange(12) * 27 + idx_shift2[:, 0] * 9 + idx_shift2[:, 1] * 3 + idx_shift2[:, 2]] = 1
    rpad1 = nn.ReplicationPad3d(dilation)
    rpad2 = nn.ReplicationPad3d(radius)
    # compute patch-ssd between the two shifted copies of the image
    ssd = F.avg_pool3d(rpad2((F.conv3d(rpad1(img), mshift1, dilation=dilation) -
                              F.conv3d(rpad1(img), mshift2, dilation=dilation)) ** 2),
                       kernel_size, stride=1)
    # MIND equation: subtract channel-min, normalise by clamped mean variance
    mind = ssd - torch.min(ssd, 1, keepdim=True)[0]
    mind_var = torch.mean(mind, 1, keepdim=True)
    # clamp variance into [0.001, 1000] x global mean to avoid div-by-zero
    mind_var = torch.min(torch.max(mind_var, mind_var.mean() * 0.001), mind_var.mean() * 1000)
    mind /= mind_var
    mind = torch.exp(-mind)
    # permute to have same channel ordering as the original C++ code
    mind = mind[:, torch.tensor([6, 8, 1, 11, 2, 10, 0, 7, 9, 4, 5, 3]).long(), :, :, :]
    return mind
def mind_loss(x, y):
    """Mean squared error between the MIND-SSC descriptors of x and y."""
    desc_x = MINDSSC(x)
    desc_y = MINDSSC(y)
    # F.mse_loss with its default 'mean' reduction equals mean((a-b)**2)
    return F.mse_loss(desc_x, desc_y)
def gpu_usage():
    """Print current and peak CUDA memory allocation in gigabytes."""
    current_gb = torch.cuda.memory_allocated() * 1e-9
    peak_gb = torch.cuda.max_memory_allocated() * 1e-9
    print('gpu usage (current/max): {:.2f} / {:.2f} GB'.format(current_gb, peak_gb))
def inverse_consistency(disp_field1s, disp_field2s, iter=20):
    """Symmetrise two displacement fields by fixed-point iteration so each
    becomes (approximately) the inverse of the other.

    disp_field1s, disp_field2s: (B, 3, H, W, D) displacement fields in
    normalised grid_sample coordinates (channel-first).
    iter: number of fixed-point iterations.
    Returns the pair of inverse-consistent fields.
    """
    _, _, H, W, D = disp_field1s.size()
    with torch.no_grad():
        field1 = disp_field1s.clone()
        field2 = disp_field2s.clone()
        # identity sampling grid in channel-first layout, matching device/dtype
        identity = F.affine_grid(torch.eye(3, 4).unsqueeze(0), (1, 1, H, W, D),
                                 align_corners=False)
        identity = identity.permute(0, 4, 1, 2, 3).to(disp_field1s.device).to(disp_field1s.dtype)
        for _ in range(iter):
            prev1 = field1.clone()
            prev2 = field2.clone()
            # update each field towards half of (itself minus the other field
            # warped by it): the fixed point satisfies f1 = -f2 o (id + f1)
            field1 = 0.5 * (prev1 - F.grid_sample(prev2, (identity + prev1).permute(0, 2, 3, 4, 1), align_corners=False))
            field2 = 0.5 * (prev2 - F.grid_sample(prev1, (identity + prev2).permute(0, 2, 3, 4, 1), align_corners=False))
    return field1, field2
def adamreg(input_img_fixed,input_img_moving,output_field,output_warped,grid_sp,lambda_weight,inverse):
    """Deformable registration of two NIfTI volumes using MIND-SSC features
    and instance-wise Adam optimisation (AdamReg / convexAdam baseline).

    Parameters (all arrive as CLI strings or None):
      input_img_fixed / input_img_moving: paths to fixed and moving NIfTI
        images; presumably they share the same H x W x D grid — TODO confirm.
      output_field: path for the compressed half-resolution field (.npz).
      output_warped: optional path for the warped moving image (NIfTI).
      grid_sp: control-grid spacing in voxels (default 4).
      lambda_weight: diffusion-regularisation weight (default 0.75).
      inverse: truthy -> invert the estimated transform before saving.

    NOTE(review): requires a CUDA device (descriptors and optimisation run
    on GPU); writes files as a side effect and returns None.
    """
    fixed = torch.from_numpy(nib.load(input_img_fixed).get_fdata()).float()
    moving = torch.from_numpy(nib.load(input_img_moving).get_fdata()).float()
    H,W,D = fixed.shape
    # fill defaults and cast CLI string arguments
    if(inverse is None):
        inverse = 0
    else:
        inverse = int(inverse)
    if(grid_sp is None):
        grid_sp = 4
    else:
        grid_sp = int(grid_sp)
    #extract MIND patches
    torch.cuda.synchronize()
    t0 = time.time()
    #compute MIND descriptors and downsample (using average pooling)
    with torch.no_grad():
        mindssc_fix = MINDSSC(fixed.unsqueeze(0).unsqueeze(1).cuda(),2,2).half()#.cpu()
        mindssc_mov = MINDSSC(moving.unsqueeze(0).unsqueeze(1).cuda(),2,2).half()#.cpu()
        mind_fix = F.avg_pool3d(mindssc_fix,grid_sp,stride=grid_sp)
        mind_mov = F.avg_pool3d(mindssc_mov,grid_sp,stride=grid_sp)
    with torch.no_grad():
        # unfold 3x3x3 neighbourhoods of the 12-channel descriptors into
        # 12*27 channels per control point (patch-wise similarity)
        patch_mind_fix = nn.Flatten(5,)(F.pad(mind_fix,(1,1,1,1,1,1)).unfold(2,3,1).unfold(3,3,1).unfold(4,3,1)).permute(0,1,5,2,3,4).reshape(1,12*27,H//grid_sp,W//grid_sp,D//grid_sp)
        patch_mind_mov = nn.Flatten(5,)(F.pad(mind_mov,(1,1,1,1,1,1)).unfold(2,3,1).unfold(3,3,1).unfold(4,3,1)).permute(0,1,5,2,3,4).reshape(1,12*27,H//grid_sp,W//grid_sp,D//grid_sp)
    #print(patch_mind_fix.shape)
    #create optimisable displacement grid
    # the Conv3d weight (3, 1, H', W', D') is (ab)used as the low-resolution
    # displacement field itself; it is the only optimised parameter
    net = nn.Sequential(nn.Conv3d(3,1,(H//grid_sp,W//grid_sp,D//grid_sp),bias=False))
    net[0].weight.data[:] = 0
    net.cuda()
    optimizer = torch.optim.Adam(net.parameters(), lr=1)
    torch.cuda.synchronize()
    t0 = time.time()
    # identity sampling grid at control-grid resolution
    grid0 = F.affine_grid(torch.eye(3,4).unsqueeze(0).cuda(),(1,1,H//grid_sp,W//grid_sp,D//grid_sp),align_corners=False)
    #run Adam optimisation with diffusion regularisation and B-spline smoothing
    if(lambda_weight is None):
        lambda_weight = .75# sad: 10, ssd:0.75
    else:
        lambda_weight = float(lambda_weight)
    for iter in range(150):  # NOTE(review): 'iter' shadows the builtin
        optimizer.zero_grad()
        # two 5^3 average-pooling passes smooth the raw field (B-spline-like)
        disp_sample = F.avg_pool3d(F.avg_pool3d(net[0].weight,5,stride=1,padding=2),5,stride=1,padding=2).permute(0,2,3,4,1)
        # diffusion regulariser: squared finite differences along each axis
        reg_loss = lambda_weight*((disp_sample[0,:,1:,:]-disp_sample[0,:,:-1,:])**2).mean()+\
        lambda_weight*((disp_sample[0,1:,:,:]-disp_sample[0,:-1,:,:])**2).mean()+\
        lambda_weight*((disp_sample[0,:,:,1:]-disp_sample[0,:,:,:-1])**2).mean()
        #grid_disp = grid0.view(-1,3).cuda().float()+((disp_sample.view(-1,3))/torch.tensor([63/2,63/2,68/2]).unsqueeze(0).cuda()).flip(1)
        # convert voxel displacements into grid_sample's normalised [-1,1]
        # coordinates; flip(1) swaps axis order to grid_sample's convention
        scale = torch.tensor([(H//grid_sp-1)/2,(W//grid_sp-1)/2,(D//grid_sp-1)/2]).cuda().unsqueeze(0)
        grid_disp = grid0.view(-1,3).cuda().float()+((disp_sample.view(-1,3))/scale).flip(1).float()
        patch_mov_sampled = F.grid_sample(patch_mind_mov.float(),grid_disp.view(1,H//grid_sp,W//grid_sp,D//grid_sp,3).cuda(),align_corners=True,mode='bilinear')#,padding_mode='border')
        # SSD over the patch features; *12 rescales the channel mean
        sampled_cost = (patch_mov_sampled-patch_mind_fix).pow(2).mean(1)*12
        #sampled_cost = F.grid_sample(ssd2.view(-1,1,17,17,17).float(),disp_sample.view(-1,1,1,1,3)/disp_hw,align_corners=True,padding_mode='border')
        loss = sampled_cost.mean()
        (loss+reg_loss).backward()
        optimizer.step()
    torch.cuda.synchronize()
    t1 = time.time()
    print(t1-t0,'sec (optim)')
    # upsample the fitted low-resolution field to full image resolution;
    # *grid_sp converts control-grid-voxel units to image-voxel units
    fitted_grid = disp_sample.permute(0,4,1,2,3).detach()
    disp_hr = F.interpolate(fitted_grid*grid_sp,size=(H,W,D),mode='trilinear',align_corners=False)
    #disp = disp_hr.cpu().float().permute(0,2,3,4,1)/torch.Tensor([H-1,W-1,D-1]).view(1,1,1,1,3)*2
    #disp = disp.flip(4)
    if(inverse):
        # normalise to [-1,1], flip channel order for grid_sample, invert the
        # field via inverse-consistency iterations, then undo normalisation
        disp = disp_hr.cpu().float()/(torch.tensor([H-1,W-1,D-1]).view(1,3,1,1,1)/2)
        disp = disp.flip(1)
        disp,_ = inverse_consistency(-disp.cuda(),disp.cuda(),10)
        disp_hr = disp.flip(1).cpu()*(torch.tensor([H-1,W-1,D-1]).view(1,3,1,1,1)/2)
    disp_field = disp_hr.cpu().float().numpy()
    #convert field to half-resolution npz (Learn2Reg submission format)
    x = disp_field[0,0,:,:,:]
    y = disp_field[0,1,:,:,:]
    z = disp_field[0,2,:,:,:]
    x1 = zoom(x,1/2,order=2).astype('float16')
    y1 = zoom(y,1/2,order=2).astype('float16')
    z1 = zoom(z,1/2,order=2).astype('float16')
    #write out field
    np.savez_compressed(output_field,np.stack((x1,y1,z1),0))
    if(output_warped is not None):
        # warp the moving image with the full-resolution field and save it
        D, H, W = fixed.shape
        identity = np.stack(np.meshgrid(np.arange(D), np.arange(H), np.arange(W), indexing='ij'),0)
        moving_warped = map_coordinates(moving.numpy(), identity + disp_field[0], order=1)
        nib.save(nib.Nifti1Image(moving_warped,np.eye(4)),output_warped)
#    torch.save(disp.cpu().data,output_field)
if __name__ == "__main__":
    # Command-line entry point: collect paths and hyper-parameters, then run
    # the registration. Optional numeric arguments are kept as strings here;
    # adamreg() applies the defaults and casts them.
    parser = ArgumentParser()
    parser.add_argument("--input_img_fixed", type=str, required=True,
                        help="path to input fixed nifti")
    parser.add_argument("--input_img_moving", type=str, required=True,
                        help="path to input moving nifti")
    parser.add_argument("--output_field", type=str, required=True,
                        help="path to output displacement file")
    parser.add_argument("--output_warped", type=str, required=False,
                        help="path to output warped image nifti")
    parser.add_argument("--grid_sp", type=str, required=False,
                        help="integer value for grid_spacing (default = 4) ")
    parser.add_argument("--lambda_weight", type=str, required=False,
                        help="floating point value for regularisation weight (default = .75) ")
    parser.add_argument("--inverse", type=str, required=False,
                        help="integer value/boolean whether the transform should be inverted ")
    args = parser.parse_args()
    adamreg(**vars(args))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment