Skip to content

Instantly share code, notes, and snippets.

@etienne87
Last active January 16, 2022 09:37
Show Gist options
  • Save etienne87/0e56cd7ccf6c684407bcb6d3ce6e1eb3 to your computer and use it in GitHub Desktop.
Save etienne87/0e56cd7ccf6c684407bcb6d3ce6e1eb3 to your computer and use it in GitHub Desktop.
Understanding spatial transforms in PyTorch (simulating two vessels)
import matplotlib.pyplot as plt
plt.switch_backend('tkagg')
import torch as th
import torch.nn.functional as F
import numpy as np
import cv2
from scipy.spatial.transform import Rotation
def compute_meshgrid(b, d, h, w, align_corners=False):
    """Return the identity sampling grid for a (b, 1, d, h, w) volume.

    Args:
        b: batch size.
        d, h, w: volume depth, height and width.
        align_corners: forwarded to ``F.affine_grid``. Defaults to False so the
            grid matches ``F.grid_sample(..., align_corners=False)``. Passing it
            explicitly avoids PyTorch's "default behavior has changed" warning.

    Returns:
        Tensor of shape (b, d, h, w, 3) holding normalized (x, y, z) sample
        coordinates.
    """
    # Identity 3x4 affine matrix [I | 0], one copy per batch item.
    theta = th.eye(4)[:3][None].repeat(b, 1, 1)
    return F.affine_grid(theta, (b, 1, d, h, w), align_corners=align_corners)
def affine(vecs, trans, scale, rot, order='fw'):
    """Apply a per-batch translation / scale / rotation to row vectors.

    ``vecs`` has shape (b, ..., 3); its trailing dimension is treated as a row
    vector. ``trans`` must hold 3 values per batch item, and ``scale`` must
    hold 1 value per batch item. ``rot`` is (b, 3, 3).

    order == 'bw':  out = ((vecs @ rot^T) + trans) * scale
    any other value ('fw'):  out = ((vecs / scale) - trans) @ rot

    The two orders are inverses of each other for orthogonal ``rot``.
    A new tensor is always returned; ``vecs`` is left untouched.
    """
    batch = len(vecs)
    n_inner = vecs.ndim - 2
    # Reshape so the per-sample parameters broadcast over every inner dimension.
    s = scale.view(batch, *([1] * (n_inner + 1)))
    t = trans.view(batch, *([1] * n_inner), 3)

    if order != 'bw':
        shifted = vecs / s - t
        return th.einsum('b...j,bjk->b...k', shifted, rot)

    rotated = th.einsum('b...j,bjk->b...k', vecs, rot.transpose(1, 2))
    return (rotated + t) * s
def stn(x, trans, scale, rot, padding_mode='zeros'):
    """Warp a 5-D volume batch with a per-sample translation/scale/rotation.

    Each output voxel p samples ``x`` at ``affine(p, trans, scale, rot, 'fw')``,
    expressed in the normalized grid frame of ``F.affine_grid`` /
    ``F.grid_sample`` with ``align_corners=False``. Consequently, content at
    input location q shows up at ``affine(q, trans, scale, rot, 'bw')`` in the
    output.

    Args:
        x: input of shape (b, c, d, h, w).
        trans: 3 translation values per sample in normalized (x, y, z) units.
        scale: one scale value per sample (e.g. shape (b, 1)).
        rot: (b, 3, 3) matrices (rotations in this file).
        padding_mode: forwarded to ``F.grid_sample``.

    The parameter tensors must live on the same device as ``x``.

    Returns:
        Tensor with the same shape as ``x``.
    """
    b, _, d, h, w = x.shape
    # The identity grid is created on the CPU, so move it next to x.
    mgrid = compute_meshgrid(b, d, h, w).to(x.device)
    # grid_sample requires the grid dtype to match the input dtype.
    grid = affine(mgrid, trans, scale, rot, 'fw').to(x.dtype)
    return F.grid_sample(x, grid, padding_mode=padding_mode, align_corners=False)
def generate_random_affine(num):
    """Sample ``num`` random (translation, scale, rotation) triplets.

    Returns:
        T: (num, 3) float32 translations, uniform in [-0.7, 0.7).
        S: (num, 1) float32 isotropic scales, uniform in [0.7, 1.3).
        R: (num, 3, 3) float32 rotation matrices from ``Rotation.random``.
    """
    # Rotations come from SciPy (NumPy RNG); T and S come from the torch RNG.
    matrices = Rotation.random(num).as_matrix()
    R = th.from_numpy(np.asarray(matrices, dtype=np.float32))
    T = th.empty((num, 3), dtype=th.float32).uniform_(-0.7, 0.7)
    S = th.empty((num, 1), dtype=th.float32).uniform_(0.7, 1.3)
    return T, S, R
def compose_tsr(trans, scale, rot, order='bw'):
    """Compose translation, isotropic scale and rotation into (b, 3, 4) matrices.

    With ``order='bw'`` (default) the result is ``[s * R | t]``, i.e. it maps a
    point p to ``s * R @ p + t``. The translation is not multiplied by the
    scale. With ``order='fw'`` it is the top three rows of the inverse of that
    4x4 homogeneous matrix.

    Args:
        trans: (b, 3) translations.
        scale: b scale values, e.g. shape (b, 1); only the first entry per
            sample is used.
        rot: (b, 3, 3) matrices (rotations in this file). Not modified.
        order: 'fw' to return the inverse transform; anything else gives 'bw'.

    Returns:
        (b, 3, 4) float32 affine matrices, as accepted by ``F.affine_grid``.
    """
    b = len(trans)
    s = scale.reshape(b, -1)[:, 0].view(b, 1, 1)
    T = th.zeros((b, 4, 4), dtype=th.float32, device=rot.device)
    # Scale the whole linear part (s * R), not just its diagonal, and never
    # write into the caller's ``rot`` tensor.
    T[:, :3, :3] = rot * s
    T[:, :3, 3] = trans
    T[:, 3, 3] = 1
    if order == 'fw':
        T = th.linalg.inv(T)
    return T[:, :3]
def test_stn_affine(d=64, h=64, w=64, radius=4, num=1):
    """Visually check ``stn`` and ``affine`` on a synthetic vessel.

    A tube is drawn into a (d, h, w) volume: one ring of ``radius`` pixels per
    z-slice, with its centre drifting slowly in x/y. The tube is warped by a
    random translation/scale/rotation through ``stn``. Both point clouds are
    plotted in 3-D, together with two arrows from the middle slice: one towards
    the last slice and one towards the first. The arrows are shown before and
    after augmentation.

    Args:
        d, h, w: volume depth, height and width in voxels.
        radius: ring radius in pixels passed to ``cv2.circle``.
        num: number of random augmentations to sample; only the first one is
            plotted.
    """
    vol = th.zeros((d, h, w), dtype=th.float32)

    # Draw the vessel, one ring per z-slice.
    offx, offy = -5, -4
    centers = []
    for i in range(d):
        offx += (i * 0.005) ** 2
        offy += (i * 0.003) ** 2
        center = (int(w // 2 + offx), int(h // 2 + offy))
        centers.append(center)
        # vol[i].numpy() shares memory with vol, so cv2 draws into the tensor.
        cv2.circle(vol[i].numpy(), center, radius, 1, 0)

    # Arrow base on the middle slice; tips at the last and first ring centres.
    # All points are (x, y, z) voxel coordinates.
    first, mid, last = centers[0], centers[d // 2], centers[-1]
    cx, cy, cz = mid[0], mid[1], d // 2
    base = th.tensor([cx, cy, cz], dtype=th.float32)
    vec = th.tensor([last[0] - cx, last[1] - cy, d // 2], dtype=th.float32)
    vecpr = th.tensor([first[0] - cx, first[1] - cy, -d // 2], dtype=th.float32)

    T, S, R = generate_random_affine(num)
    vol_batch = vol[None, None].repeat(num, 1, 1, 1, 1)
    vol_aug = stn(vol_batch, T, S, R)[0, 0]  # only the first sample is plotted

    z, y, x = th.where(vol > 0)
    z2, y2, x2 = th.where(vol_aug > 0)

    # stn samples the input at affine(p, 'fw') for each output point p, so
    # content at input point q ends up at affine(q, 'bw') in the output.
    # affine() works in the normalized grid frame (align_corners=False), where
    # voxel index i maps to (i - (n - 1) / 2) / (n / 2). Convert the arrow
    # points to that frame, move them, and convert back. Both the base and the
    # tips are moved.
    half = th.tensor([w / 2, h / 2, d / 2])
    grid_center = th.tensor([(w - 1) / 2, (h - 1) / 2, (d - 1) / 2])
    pts = th.stack([base, base + vec, base + vecpr])
    pts_norm = (pts - grid_center) / half
    pts_aug = affine(pts_norm[None], T[:1], S[:1], R[:1], 'bw')[0] * half + grid_center
    base2, tip2, tip_pr2 = pts_aug

    px, py, pz = vec.tolist()
    prx, pry, prz = vecpr.tolist()
    bx, by, bz = base2.tolist()
    vx2, vy2, vz2 = (tip2 - base2).tolist()
    vx3, vy3, vz3 = (tip_pr2 - base2).tolist()

    # Plot volume, transformed volume, direction & transformed direction.
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(x, y, z, marker='.', alpha=0.2, color='gray', s=3, label='volume')
    ax.scatter(x2, y2, z2, marker='.', alpha=0.1, color='gray', s=3, label='volume aug')
    ax.quiver(cx, cy, cz, px, py, pz, linewidths=3, color=['blue'], label='direction_gt')
    ax.quiver(cx, cy, cz, prx, pry, prz, linewidths=3, color=['orange'], label='direction_gt_prev')
    ax.quiver(bx, by, bz, vx2, vy2, vz2, linewidths=3, color=['blue'], label='direction_gt_aug')
    ax.quiver(bx, by, bz, vx3, vy3, vz3, linewidths=3, color=['orange'], label='direction_gt_prev_aug')
    ax.set_xlim3d(0, w)
    ax.set_ylim3d(0, h)
    ax.set_zlim3d(0, d)
    ax.set_xlabel('X Label')
    ax.set_ylabel('Y Label')
    ax.set_zlabel('Z Label')
    ax.legend()
    plt.show()
if __name__ == '__main__':
    # Script entry point: python-fire turns test_stn_affine's parameters
    # (d, h, w, radius, num) into command-line arguments.
    import fire

    fire.Fire(test_stn_affine)
@etienne87
Copy link
Author

vessel

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment