@bougui505
Created October 1, 2020 07:53
Iterative Closest Point (ICP) implementation with least squares fit (lstsq) in Pytorch
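#!/usr/bin/env python
# -*- coding: UTF8 -*-

# Author: Guillaume Bouvier -- guillaume.bouvier@pasteur.fr
# https://research.pasteur.fr/en/member/guillaume-bouvier/
# 2020-10-01 09:51:45 (UTC+0200)

import sys

import torch


def icp(coords, coords_ref, device, n_iter):
    """
    Iterative Closest Point

    coords and coords_ref are (N, 3) and (M, 3) tensors; device is unused here.
    """
    for t in range(n_iter):
        # Pairwise distances between the two centered point clouds
        cdist = torch.cdist(coords - coords.mean(axis=0),
                            coords_ref - coords_ref.mean(axis=0))
        # Closest reference point for each point of coords
        mindists, argmins = torch.min(cdist, axis=1)
        # Least squares fit: torch.lstsq(B, A) minimizes ||A X - B||.
        # torch.lstsq is deprecated in later PyTorch releases
        # (see torch.linalg.lstsq, which takes its arguments as (A, B)).
        X, _ = torch.lstsq(coords_ref[argmins], coords)
        # The first 3 rows of X hold the fitted 3x3 linear map
        coords = coords.mm(X[:3])
        # The remaining rows encode the residuals of the fit
        rmsd = torch.sqrt((X[3:]**2).sum(axis=1).mean())
        print_progress(f'{t+1}/{n_iter}: {rmsd}')
    return coords


def print_progress(instr):
    # Overwrite the same terminal line at each iteration
    sys.stdout.write(f'{instr}\r')
    sys.stdout.flush()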
@ShichengChen

thanks for sharing

@smiles724

This implementation is out of date: torch.lstsq is deprecated, and you should use torch.linalg.lstsq instead.
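Following up on the comment above, here is a minimal sketch of the same loop written against torch.linalg.lstsq. The function name icp_linalg, the default n_iter, and the returned RMSD value are assumptions made for illustration and are not part of the gist. Two differences are worth noting: torch.linalg.lstsq takes its arguments as (A, B), the reverse of torch.lstsq(B, A), and its solution is a plain (3, 3) matrix without the extra residual rows, so the RMSD here is computed directly from the fitted coordinates and may differ slightly from the value printed by the original. As in the gist, the fit is an unconstrained 3x3 linear map, not a strict rotation.

```python
import torch


def icp_linalg(coords, coords_ref, n_iter=10):
    """ICP-style loop using torch.linalg.lstsq (illustrative sketch).

    coords: (N, 3) tensor to align; coords_ref: (M, 3) reference tensor.
    Returns the transformed coords and the RMSD of the last fit.
    """
    rmsd = None
    for _ in range(n_iter):
        # Nearest reference point for each moving point, on centered clouds
        cdist = torch.cdist(coords - coords.mean(dim=0),
                            coords_ref - coords_ref.mean(dim=0))
        argmins = cdist.argmin(dim=1)
        target = coords_ref[argmins]
        # Solve min_X ||coords @ X - target||; argument order is (A, B)
        X = torch.linalg.lstsq(coords, target).solution  # shape (3, 3)
        coords = coords @ X
        # RMSD between the fitted points and their matched reference points
        rmsd = torch.sqrt(((coords - target) ** 2).sum(dim=1).mean())
    return coords, rmsd
```

A call such as aligned, rmsd = icp_linalg(coords, coords_ref, n_iter=20) should behave like the original loop on inputs of full column rank; the progress printing was left out to keep the sketch short.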
