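# Created March 7, 2021 23:29
# Slerp (spherical linear interpolation) for quaternions in TensorFlow, plus
# rotation-matrix <-> quaternion conversion and safe-normalization helpers.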
def slerp(v0, v1, t=0.05):
    """Spherically interpolate between quaternions v0 and v1.

    Both inputs are normalized first. Because q and -q encode the same
    rotation, v1 is negated wherever its dot product with v0 is <= 0 so the
    interpolation follows the shorter arc.

    Args:
        v0: ``... x 4`` tensor of start quaternions.
        v1: tensor of end quaternions with the same shape as ``v0``.
        t: interpolation parameter. Either a scalar (Python float or scalar
            tensor) or a tensor that broadcasts against ``v0`` (e.g.
            ``... x 1``, one value per quaternion).

    Returns:
        A tensor with the shape of ``v0``:
        * where ``t < 1e-14`` (including negative t): ``v0``;
        * where ``t > 1 - 1e-14``: the (possibly sign-flipped) ``v1``;
        * elsewhere, where the sign-corrected dot product exceeds
          ``DOT_THRESHOLD`` (a module-level constant defined outside this
          function): the normalized linear interpolation;
        * otherwise: the normalized spherical interpolation.

    NOTE(review): the spherical branch is evaluated for every element, even
    where the linear branch is selected. ``acos`` has an infinite derivative
    at ``dot == 1``, so gradients through ``tf.where`` may become NaN when
    ``v0 == v1`` exactly. Verify before differentiating through this.
    """
    v0 = safe_normalize(v0)
    v1 = safe_normalize(v1)
    dot = tf.reduce_sum(v0 * v1, axis=-1, keepdims=True)

    # q and -q are the same rotation: flip v1 (and the dot product) wherever
    # the dot product is non-positive so we interpolate along the short arc.
    signflip = tf.where(tf.less_equal(dot, 0.), -1. * tf.ones_like(dot), tf.ones_like(dot))
    v1 *= signflip
    dot *= signflip

    # Normalized linear interpolation, used when v0 and v1 are nearly equal.
    linq = safe_normalize(v0 + t * (v1 - v0))

    # Spherical interpolation. The clip keeps acos in its domain despite
    # roundoff; the 1e-19 offset avoids dividing by sin(0).
    sdot = tf.clip_by_value(dot, -1.0, 1.0)
    theta_0 = tf.acos(sdot)
    theta = theta_0 * t
    sin_theta = tf.sin(theta)
    sin_theta_0 = tf.sin(theta_0)
    s0 = tf.cos(theta) - dot * sin_theta / (sin_theta_0 + 1e-19)
    s1 = sin_theta / (sin_theta_0 + 1e-19)
    sq = safe_normalize((s0 * v0) + (s1 * v1))

    # TF1's tf.where needs the condition and both branches to share one shape.
    # Broadcast the per-quaternion dot product and t up to v0's full shape.
    # Multiplying by ones also handles a scalar t; the old tf.concat([t]*4)
    # failed on the default t=0.05.
    ones = tf.ones_like(v0)
    dot_full = dot * ones
    t_full = t * ones

    slerpd = tf.where(tf.greater(dot_full, DOT_THRESHOLD), linq, sq)
    slerpdorv1 = tf.where(tf.greater(t_full, 1.0 - 1e-14), v1, slerpd)
    return tf.where(tf.less(t_full, 1e-14), v0, slerpdorv1)
def sftpluswparam(x, sharpness=100.0):
    """Sharpened softplus: ``log(1 + exp(sharpness * x)) / sharpness``.

    A smooth approximation of ``max(x, 0)`` that gets sharper as
    ``sharpness`` grows. The default of 100 matches the previously
    hard-coded constant.

    Computed with ``tf.nn.softplus``, which is numerically stable. The naive
    ``tf.log(1 + tf.exp(sharpness * x))`` form overflows to ``inf`` once
    ``sharpness * x`` passes exp's overflow point (about 88.7 in float32,
    i.e. x > ~0.89 at the default sharpness). It also loses precision for
    very negative inputs.

    Args:
        x: input tensor.
        sharpness: positive scale factor; larger values give a sharper corner.

    Returns:
        Tensor with the shape and dtype of ``x``.
    """
    return tf.nn.softplus(sharpness * x) / sharpness
def RotToQuat(axes_):
    """Convert rotation matrices to quaternions.

    Args:
        axes_: a ``... x 3 x 3`` tensor of rotation matrices. Rows are
            indexed by axis -2 and columns by axis -1.

    Returns:
        A ``... x 4`` tensor of quaternions ordered ``(w, x, y, z)``, one per
        input matrix.
        * ``w`` is always >= 0.
        * Each vector component takes its magnitude from the diagonal.
        * Its sign comes from the matching off-diagonal difference, e.g.
          ``sign(R[2,1] - R[1,2])`` for ``x``.

    NOTE(review): near 180-degree rotations (w ~ 0) those off-diagonal
    differences vanish. ``tf.sign`` then returns 0 or a roundoff-dependent
    sign, so the quaternion is unreliable in that regime. ``tf.sign`` also
    passes no gradient. Confirm callers avoid such inputs.
    """
    r00 = axes_[..., 0, 0]
    r11 = axes_[..., 1, 1]
    r22 = axes_[..., 2, 2]

    def _half_root(arg):
        # 0.5 * sqrt(|arg|): abs() absorbs small negative roundoff, and the
        # 1e-15 offset presumably keeps sqrt's gradient finite at zero.
        return 0.5 * tf.sqrt(1e-15 + tf.abs(arg))

    w = _half_root(1 + r00 + r11 + r22)
    x = tf.sign(axes_[..., 2, 1] - axes_[..., 1, 2]) * _half_root(1.0 + r00 - r11 - r22)
    y = tf.sign(axes_[..., 0, 2] - axes_[..., 2, 0]) * _half_root(1.0 - r00 + r11 - r22)
    z = tf.sign(axes_[..., 1, 0] - axes_[..., 0, 1]) * _half_root(1.0 - r00 - r11 + r22)
    return tf.stack([w, x, y, z], axis=-1)
def QuatToRot(q):
    """Convert quaternions to rotation matrices.

    Args:
        q: a ``... x 4`` tensor of quaternions ordered ``(w, x, y, z)``.
            NOTE(review): the formula below yields a proper rotation only for
            unit quaternions; callers presumably normalize first. Confirm.

    Returns:
        A ``[-1, 3, 3]`` tensor. All leading (batch) dimensions of ``q`` are
        flattened into one, so ``[N, 4]`` gives ``[N, 3, 3]``, a single
        ``[4]`` gives ``[1, 3, 3]``, and ``[A, B, 4]`` gives ``[A*B, 3, 3]``.
    """
    w = q[..., 0]
    x = q[..., 1]
    y = q[..., 2]
    z = q[..., 3]
    # Row-major entries of the standard unit-quaternion rotation matrix.
    entries = [
        1 - 2. * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w),
        2 * (x * y + z * w), 1 - 2. * (x * x + z * z), 2 * (y * z - x * w),
        2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2. * (x * x + y * y),
    ]
    return tf.reshape(tf.stack(entries, axis=-1), [-1, 3, 3])
def VectorsToOrient(v1, v2):
    """Build an orthonormal frame from two direction vectors.

    Args:
        v1, v2: ``... x 3`` tensors of direction vectors. They need not be
            unit length; both are normalized here. Their dtype (e.g. float32
            or float64) determines the dtype of the internal constants.

    Returns:
        A ``... x 3 x 3`` tensor whose rows (axis -2) are, in order:
        1. the unit bisector of v1 and v2 rotated by -pi/4 about the
           plane normal;
        2. the same bisector rotated by +pi/4;
        3. the unit normal of the v1-v2 plane.
        NOTE(review): this reading assumes ``TF_AxisAngleRotation(axis, v,
        angle)`` rotates ``v`` about ``axis``. Confirm against its definition.

    Degenerate inputs:
        * Parallel v1 and v2 have a zero cross product, so the normal falls
          back to +z via the 1e-19 nudge.
        * Antiparallel v1 and v2 give a zero bisector, because
          safe_normalize maps the zero vector to zero.
    """
    v1n = safe_normalize(v1)
    v2n = safe_normalize(v2)
    dtype = v1n.dtype
    # Tiny +z nudge so safe_normalize never sees an exactly-zero cross product.
    nudge = tf.constant([0., 0., 1e-19], dtype=dtype)
    v3 = safe_normalize(tf.cross(v1n, v2n) + nudge)
    # Unit bisector of v1 and v2. It lies in their plane, perpendicular to v3.
    v_av = safe_normalize((v1n + v2n) / 2.0)
    # Rotate the bisector by -pi/4 and +pi/4 about the normal. Under the
    # assumption above, the results are in-plane and pi/2 apart. They are
    # not v1 and v2 themselves unless v1 and v2 are perpendicular.
    first = TF_AxisAngleRotation(v3, v_av, tf.constant(-Pi / 4., dtype=dtype))
    second = TF_AxisAngleRotation(v3, v_av, tf.constant(Pi / 4., dtype=dtype))
    return tf.stack([first, second, v3], axis=-2)
def VectorsToAxisQs(v1, v2):
    """Return quaternions for the frames built from v1 and v2.

    Builds the frame with VectorsToOrient and converts it with RotToQuat.
    Every leading dimension is flattened, so the result is an ``N x 4``
    tensor of ``(w, x, y, z)`` quaternions.
    """
    frames = VectorsToOrient(v1, v2)
    quats = RotToQuat(frames)
    return tf.reshape(quats, [-1, 4])
def safe_normalize(x_):
    """Scale ``x_`` to unit length along its last axis.

    The scale factor is ``safe_inv_norm(x_)``:
    * norms are clipped to ``[1e-36, 1e36]``, so an all-zero vector stays
      all-zero instead of producing 0/0;
    * a NaN norm gives a scale of 0 (``x_ * 0`` is still NaN wherever
      ``x_`` itself holds NaN).

    NOTE(review): tf.norm's gradient is undefined at the zero vector, and the
    clip does not obviously prevent NaN gradients there. Verify if this is
    differentiated through zero vectors.
    """
    return x_ * safe_inv_norm(x_)
def safe_inv_norm(x_):
    """Return ``1 / ||x_||`` over the last axis (kept as a size-1 axis).

    The norm is clipped to ``[1e-36, 1e36]`` before inversion, so a zero
    vector yields 1e36 rather than inf. Entries whose clipped norm is NaN,
    or exactly 0, yield 0.
    """
    clipped = tf.clip_by_value(tf.norm(x_, axis=-1, keepdims=True), 1e-36, 1e36)
    bad = tf.logical_or(tf.equal(clipped, 0.), tf.is_nan(clipped))
    divisor = tf.where(bad, tf.ones_like(clipped), clipped)
    return tf.where(bad, tf.zeros_like(clipped), 1.0 / divisor)
def safe_norm(x_):
    """Return ``||x_||`` over the last axis (kept as a size-1 axis).

    The norm is clipped to ``[1e-36, 1e36]``, so a zero vector reports 1e-36
    rather than 0. A NaN norm is replaced by 0.
    """
    clipped = tf.clip_by_value(tf.norm(x_, axis=-1, keepdims=True), 1e-36, 1e36)
    bad = tf.logical_or(tf.equal(clipped, 0.), tf.is_nan(clipped))
    return tf.where(bad, tf.zeros_like(clipped), clipped)
# Top-level driver: build a fresh graph with a random coordinate variable
# and open a session on it.
# NOTE(review): batch_size and MaxNAtom are not defined in this view.
# Presumably they are module-level constants defined elsewhere; confirm.
with tf.Graph().as_default():
    # Random coordinates in [-5, 2): np.random.random draws from [0, 1),
    # scaled by 7 and shifted by -5. numpy's default float64 makes this a
    # float64 variable.
    xyzs = tf.Variable(np.random.random((batch_size,MaxNAtom,3))*7.0 - 5.0)
    init = tf.global_variables_initializer()
    # Created inside the `with` so the session is bound to this graph.
    # allow_soft_placement lets TF place an op on another device when the
    # requested one cannot run it.
    # NOTE(review): no sess.close() is visible here; confirm it is closed
    # (or used as a context manager) further down the file.
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
