Skip to content

Instantly share code, notes, and snippets.

@podgorskiy
Last active February 26, 2020 07:32
Show Gist options
  • Save podgorskiy/fd6b2afad62918c4d2084a3128771a62 to your computer and use it in GitHub Desktop.
Save podgorskiy/fd6b2afad62918c4d2084a3128771a62 to your computer and use it in GitHub Desktop.
Random search for rotation matrix optimization
import numpy as np
from random import random
def optimize(f, t_scale, steps=10000, tolerance=1e-6, step_decay_exp=0.01,
             dim=3, verbose=True):
    """Find a rotation ``R`` and translation ``T`` minimising ``f(R, T)``.

    Greedy random search. Each iteration proposes a small random rotation,
    applied on the right of the current ``R``, and then a random perturbation
    of ``T``. Each proposal is kept only if it lowers the objective. The
    proposal size decays as ``exp(-i * step_decay_exp)``.

    Args:
        f: Objective ``f(R, T)``; lower is better. ``R`` is a ``(dim, dim)``
            matrix and ``T`` a ``(dim, 1)`` column vector.
        t_scale: Scale of the initial random translation and of the
            translation proposals.
        steps: Maximum number of iterations.
        tolerance: Stop as soon as the objective drops below this value.
        step_decay_exp: Exponential decay rate of the proposal size.
        dim: Dimension of the space (defaults to 3).
        verbose: If true, print the objective every time it improves.

    Returns:
        Tuple ``(R, T)`` with the best rotation (proper, det = +1) and the
        best translation found.

    Raises:
        ValueError: If ``dim < 2`` (a plane rotation needs two axes).
    """
    if dim < 2:
        raise ValueError("dim must be at least 2, got %r" % (dim,))

    def random_rotation(size):
        # Orthogonal factor of a zero-mean Gaussian matrix. One column is
        # negated when needed so the result is a proper rotation (det = +1).
        # The updates below all have det = +1, so a reflection as the start
        # could never turn into a rotation.
        a = np.random.normal(0.0, 1.0, (size, size))
        u, _, _ = np.linalg.svd(a, full_matrices=False)
        if np.linalg.det(u) < 0:
            u[:, -1] = -u[:, -1]
        return u

    def get_base_rotation(alpha, size):
        # Rotation by ``alpha`` in the plane of the first two axes. Kept in
        # float64: a float32 cast makes each accepted update only
        # approximately orthogonal, and that error accumulates in R.
        c, s = np.cos(alpha), np.sin(alpha)
        rot = np.eye(size)
        rot[0:2, 0:2] = [[c, -s], [s, c]]
        return rot

    R = random_rotation(dim)
    T = np.random.randn(dim, 1) * t_scale
    mse = f(R, T)

    for i in range(steps):
        # Also catches a starting point that already meets the tolerance.
        if mse < tolerance:
            break
        step = np.exp(-i * step_decay_exp)

        # Rotate by +/-step in a random plane. Conjugating the base plane
        # rotation by a random rotation randomises the plane.
        basis = random_rotation(dim)
        angle = step if random() > 0.5 else -step
        delta_r = np.matmul(basis.T, np.matmul(get_base_rotation(angle, dim), basis))
        R_new = np.matmul(R, delta_r)
        mse_new = f(R_new, T)
        if mse_new < mse:
            mse, R = mse_new, R_new
            if verbose:
                print(mse)
            if mse < tolerance:
                break

        T_new = T + np.random.randn(dim, 1) * step * t_scale
        mse_new = f(R, T_new)
        if mse_new < mse:
            mse, T = mse_new, T_new
            if verbose:
                print(mse)
    return R, T
# some test
if __name__ == "__main__":
    def random_rotation(size):
        """Return a random proper rotation matrix of shape (size, size).

        Uses the orthogonal factor of a zero-mean Gaussian matrix. One column
        is negated if necessary so that det = +1. ``optimize`` only composes
        its starting matrix with det = +1 updates, so a reflection ground
        truth would be reachable only if the start were a reflection too.
        """
        a = np.random.normal(0.0, 1.0, (size, size))
        u, _, _ = np.linalg.svd(a, full_matrices=False)
        if np.linalg.det(u) < 0:
            u[:, -1] = -u[:, -1]
        return u

    # Random rotation and translation (ground truth).
    R_gt = random_rotation(3)
    T_gt = np.random.randn(3, 1)
    # A bunch of 3-D column vectors (20 points).
    X = np.random.randn(3, 20)
    # The same vectors after applying the ground-truth transform.
    Xp = np.matmul(R_gt, X) + T_gt

    ################################
    # Input: X, Xp
    # Output: learned R, T
    ################################
    def func(R, T):
        # Frobenius norm of the residual of the candidate transform.
        return np.linalg.norm(np.matmul(R, X) + T - Xp)

    R, T = optimize(func, t_scale=3.0)
    print("R GT:\n", R_gt)
    print("R learned:\n", R)
    print()
    print("T GT:\n", T_gt)
    print("T learned:\n", T)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment