Skip to content

Instantly share code, notes, and snippets.

@jamesgregson
Created September 13, 2019 15:41
Show Gist options
  • Save jamesgregson/c3a80a41c8cea808072375272ef3fc65 to your computer and use it in GitHub Desktop.
Save jamesgregson/c3a80a41c8cea808072375272ef3fc65 to your computer and use it in GitHub Desktop.
Fast Sweeping Method in 2D and 3D using PyBind11
/*cppimport
<%
setup_pybind11(cfg)
%>
*/
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
namespace py = pybind11;
#include <cmath>
#include <iostream>
#include <algorithm>
/// Solve the 2D eikonal equation |grad T| = 1 (unit grid spacing) in place
/// with the fast sweeping method (Gauss-Seidel sweeps in 4 orderings).
///
/// @param msk  2D mask indexed msk(x,y). Cells where it is truthy are
///             recomputed; all other cells keep their current value and act
///             as fixed seeds / boundary values.
/// @param dis  2D distance field indexed dis(x,y), updated in place.
///             dis.shape(0) is the x extent, dis.shape(1) the y extent.
///             Cells not yet reached should be initialised to >= maxval.
/// @param num_iterations  number of passes over all four sweep orderings.
/// @param maxval  values >= maxval are treated as "not yet reached". A cell
///             with no reached neighbour is set to maxval (same convention as
///             fast_sweeping_method_3d).
template< typename Mask, typename Image >
void fast_sweeping_method_2d( Mask& msk, Image& dis, const int num_iterations, double maxval=1e10 ){
    const int ny = dis.shape(1);
    const int nx = dis.shape(0);
    const int num_sweeps = 4;
    // (start, step) per sweep: the four combinations of +/- x and +/- y order.
    const int sx[][2] = {{0,1},{nx-1,-1},{nx-1,-1},{0,1}};
    const int sy[][2] = {{0,1},{0,1},{ny-1,-1},{ny-1,-1}};

    // Smallest neighbouring value along x. The lookup is one-sided at the
    // borders. When the axis has a single cell there is no neighbour, and the
    // lookup returns maxval ("unreached") instead of indexing out of bounds.
    auto xmin = [&]( int x, int y ) -> double {
        const bool lo = x > 0;
        const bool hi = x < nx-1;
        if( lo && hi ) return std::min( dis(x-1,y), dis(x+1,y) );
        if( lo ) return dis(x-1,y);
        if( hi ) return dis(x+1,y);
        return maxval;
    };
    // Same as xmin, along y.
    auto ymin = [&]( int x, int y ) -> double {
        const bool lo = y > 0;
        const bool hi = y < ny-1;
        if( lo && hi ) return std::min( dis(x,y-1), dis(x,y+1) );
        if( lo ) return dis(x,y-1);
        if( hi ) return dis(x,y+1);
        return maxval;
    };

    for( int iter=0; iter<num_iterations; ++iter ){
        for( int s=0; s<num_sweeps; ++s ){
            for( int x=sx[s][0]; x>=0 && x<nx; x+=sx[s][1] ){
                for( int y=sy[s][0]; y>=0 && y<ny; y+=sy[s][1] ){
                    if( !msk(x,y) ){
                        continue;
                    }
                    const double a = xmin(x,y);
                    const double b = ymin(x,y);
                    if( a >= maxval && b >= maxval ){
                        // Nothing reached this cell yet: stay at maxval rather
                        // than drifting upward by sqrt(2)/2 every sweep.
                        dis(x,y) = maxval;
                    } else if( std::abs(a-b) >= 1.0 ){
                        // Only the smaller neighbour is upwind.
                        dis(x,y) = std::min(a,b) + 1.0;
                    } else {
                        // Both axes contribute: larger root of (u-a)^2 + (u-b)^2 = 1.
                        dis(x,y) = (a + b + std::sqrt(2.0 - (a-b)*(a-b))) / 2.0;
                    }
                }
            }
        }
    }
}
/// Solve the 3D eikonal equation |grad T| = 1 (unit grid spacing) in place
/// with the fast sweeping method (Gauss-Seidel sweeps in 8 orderings).
///
/// @param msk  3D mask indexed msk(x,y,z). Truthy cells are recomputed; all
///             other cells keep their current value (fixed seeds / boundary).
/// @param dis  3D distance field indexed dis(x,y,z), updated in place, with
///             extents dis.shape(0), dis.shape(1), dis.shape(2) for x, y, z.
/// @param num_iterations  number of passes over all eight sweep orderings.
/// @param maxval  values >= maxval count as "not yet reached". A cell is set
///             to maxval when it has no reached neighbour, or when its
///             quadratic has no real root.
template< typename Mask, typename Image >
void fast_sweeping_method_3d( Mask& msk, Image& dis, const int num_iterations, double maxval=1e10 ){
    const int nz = dis.shape(2);
    const int ny = dis.shape(1);
    const int nx = dis.shape(0);
    const int num_sweeps = 8;
    // (start, step) per sweep: all eight combinations of +/- order per axis.
    const int sx[][2] = {{0,1},{nx-1,-1},{nx-1,-1},{0,1},{0,1},{nx-1,-1},{nx-1,-1},{0,1}};
    const int sy[][2] = {{0,1},{0,1},{ny-1,-1},{ny-1,-1},{0,1},{0,1},{ny-1,-1},{ny-1,-1}};
    const int sz[][2] = {{0,1},{0,1},{0,1},{0,1},{nz-1,-1},{nz-1,-1},{nz-1,-1},{nz-1,-1}};

    // Larger root of sum_{i<=dim} (u - T[i])^2 = 1. This is the upwind update
    // using the dim+1 smallest neighbour values T[0..dim], with T sorted
    // ascending.
    // https://github.com/jvgomez/fast_methods/blob/master/doc/files/phd_thesis.pdf
    auto solve_eikonal = [maxval]( int dim, const double* T ) -> double {
        if( dim == 0 ){
            return T[0] + 1.0;
        }
        double sum  = T[0] + T[1];
        double sum2 = T[0]*T[0] + T[1]*T[1];
        if( dim >= 2 ){
            sum  += T[2];
            sum2 += T[2]*T[2];
        }
        const double qa = double(dim + 1);  // one (u - T[i])^2 term per contributing axis
        const double qb = -2.0*sum;
        const double qc = sum2 - 1.0;
        const double disc = qb*qb - 4.0*qa*qc;
        return disc < 0.0 ? maxval : (-qb + std::sqrt(disc))/(2.0*qa);
    };

    // Smallest neighbouring value along each axis. The lookup is one-sided at
    // the borders. When an axis has a single cell there is no neighbour, and
    // the lookup returns maxval ("unreached") instead of indexing out of
    // bounds.
    auto xmin = [&]( int x, int y, int z ) -> double {
        const bool lo = x > 0;
        const bool hi = x < nx-1;
        if( lo && hi ) return std::min( dis(x-1,y,z), dis(x+1,y,z) );
        if( lo ) return dis(x-1,y,z);
        if( hi ) return dis(x+1,y,z);
        return maxval;
    };
    auto ymin = [&]( int x, int y, int z ) -> double {
        const bool lo = y > 0;
        const bool hi = y < ny-1;
        if( lo && hi ) return std::min( dis(x,y-1,z), dis(x,y+1,z) );
        if( lo ) return dis(x,y-1,z);
        if( hi ) return dis(x,y+1,z);
        return maxval;
    };
    auto zmin = [&]( int x, int y, int z ) -> double {
        const bool lo = z > 0;
        const bool hi = z < nz-1;
        if( lo && hi ) return std::min( dis(x,y,z-1), dis(x,y,z+1) );
        if( lo ) return dis(x,y,z-1);
        if( hi ) return dis(x,y,z+1);
        return maxval;
    };

    for( int iter=0; iter<num_iterations; ++iter ){
        for( int s=0; s<num_sweeps; ++s ){
            for( int x=sx[s][0]; x>=0 && x<nx; x+=sx[s][1] ){
                for( int y=sy[s][0]; y>=0 && y<ny; y+=sy[s][1] ){
                    for( int z=sz[s][0]; z>=0 && z<nz; z+=sz[s][1] ){
                        if( !msk(x,y,z) ){
                            continue;
                        }
                        // Per-axis neighbour minima sorted ascending. t[3] = maxval
                        // is a sentinel so t[dim+1] is always valid below.
                        double t[4] = { xmin(x,y,z), ymin(x,y,z), zmin(x,y,z), maxval };
                        std::sort( t, t + 3 );
                        const int known = (t[0] < maxval) + (t[1] < maxval) + (t[2] < maxval);
                        if( known == 0 ){
                            dis(x,y,z) = maxval;
                            continue;
                        }
                        // Add axes one at a time until the solution no longer
                        // exceeds the next-larger neighbour; that neighbour
                        // would then not be upwind.
                        double d = maxval;
                        for( int dim=0; dim<known; ++dim ){
                            d = solve_eikonal( dim, t );
                            if( d < t[dim+1] ){
                                break;
                            }
                        }
                        dis(x,y,z) = d;
                    }
                }
            }
        }
    }
}
/// Python entry point: runs fast_sweeping_method_2d on NumPy arrays in place.
///
/// @param mask            2D bool array; true cells are recomputed.
/// @param distance        2D float64 array, same shape as mask; updated in place.
/// @param num_iterations  number of passes over all sweep orderings.
/// @param maxval          values >= maxval are treated as unreached.
/// @throws py::value_error if mask and distance shapes differ.
void fast_sweeping_method_2d_wrap( py::array_t<bool> &mask, py::array_t<double> &distance, int num_iterations, double maxval ){
    auto msk = mask.unchecked<2>();
    auto dis = distance.mutable_unchecked<2>();
    // The kernel indexes the mask using the distance array's extents, so a
    // smaller mask would be read out of bounds by the unchecked accessor.
    for( int i=0; i<2; ++i ){
        if( msk.shape(i) != dis.shape(i) ){
            throw py::value_error("mask and distance must have the same shape");
        }
    }
    fast_sweeping_method_2d( msk, dis, num_iterations, maxval );
}
/// Python entry point: runs fast_sweeping_method_3d on NumPy arrays in place.
///
/// @param mask            3D bool array; true cells are recomputed.
/// @param distance        3D float64 array, same shape as mask; updated in place.
/// @param num_iterations  number of passes over all sweep orderings.
/// @param maxval          values >= maxval are treated as unreached.
/// @throws py::value_error if mask and distance shapes differ.
void fast_sweeping_method_3d_wrap( py::array_t<bool> &mask, py::array_t<double> &distance, int num_iterations, double maxval ){
    auto msk = mask.unchecked<3>();
    auto dis = distance.mutable_unchecked<3>();
    // The kernel indexes the mask using the distance array's extents, so a
    // smaller mask would be read out of bounds by the unchecked accessor.
    for( int i=0; i<3; ++i ){
        if( msk.shape(i) != dis.shape(i) ){
            throw py::value_error("mask and distance must have the same shape");
        }
    }
    fast_sweeping_method_3d( msk, dis, num_iterations, maxval );
}
PYBIND11_MODULE( fsm, m) {
    m.doc() = "Fast sweeping method for the eikonal equation on 2D and 3D grids.";
    // `distance` is written in place. With the default (converting) array_t
    // caster, a non-float64 array or a Python list would be converted into a
    // temporary copy. The solver would then write into that copy, and the
    // caller would silently get nothing back. noconvert() turns that case
    // into a TypeError.
    m.def("fast_sweeping_method_2d", &fast_sweeping_method_2d_wrap,
          "Update a 2D float64 distance array in place; cells where mask is True are recomputed.",
          py::arg("mask"), py::arg("distance").noconvert(),
          py::arg("num_iterations"), py::arg("maxval"));
    m.def("fast_sweeping_method_3d", &fast_sweeping_method_3d_wrap,
          "Update a 3D float64 distance array in place; cells where mask is True are recomputed.",
          py::arg("mask"), py::arg("distance").noconvert(),
          py::arg("num_iterations"), py::arg("maxval"));
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment