Created
September 13, 2019 15:41
-
-
Save jamesgregson/c3a80a41c8cea808072375272ef3fc65 to your computer and use it in GitHub Desktop.
Fast Sweeping Method in 2D and 3D using PyBind11
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*cppimport
<%
setup_pybind11(cfg)
%>
*/
#include <pybind11/numpy.h> | |
#include <pybind11/pybind11.h> | |
namespace py = pybind11; | |
#include <cmath> | |
#include <iostream> | |
#include <algorithm> | |
/// Fast sweeping method for the 2D eikonal equation |grad(dis)| = 1 on a
/// unit-spaced grid: Gauss-Seidel updates repeated over four sweep orders.
///
/// @param msk            2D accessor; cells where msk(x,y) is true are updated,
///                       every other cell is left untouched (e.g. sources).
/// @param dis            2D accessor, extents shape(0) == nx and shape(1) == ny;
///                       read and overwritten in place.
/// @param num_iterations number of passes, each running all four sweeps.
/// @param maxval         values >= maxval are treated as "not reached yet"
///                       (same convention as fast_sweeping_method_3d): a cell
///                       whose neighbours are all unreached is set to maxval,
///                       and an axis of extent 1 contributes maxval.
template< typename Mask, typename Image >
void fast_sweeping_method_2d( Mask& msk, Image& dis, const int num_iterations, double maxval=1e10 ){
    const int nx = dis.shape(0);
    const int ny = dis.shape(1);
    const int num_sweeps = 4;
    // {start, step} per sweep: (+x,+y), (-x,+y), (-x,-y), (+x,-y)
    const int sx[][2] = {{0,1},{nx-1,-1},{nx-1,-1},{0,1}};
    const int sy[][2] = {{0,1},{0,1},{ny-1,-1},{ny-1,-1}};

    // Smallest neighbour value along x; maxval when there is no x neighbour.
    auto xmin = [&]( int x, int y ) -> double {
        if( nx < 2 ) return maxval;
        if( x == 0 ) return dis(x+1,y);
        if( x == nx-1 ) return dis(x-1,y);
        return std::min(dis(x-1,y), dis(x+1,y));
    };
    // Smallest neighbour value along y; maxval when there is no y neighbour.
    auto ymin = [&]( int x, int y ) -> double {
        if( ny < 2 ) return maxval;
        if( y == 0 ) return dis(x,y+1);
        if( y == ny-1 ) return dis(x,y-1);
        return std::min(dis(x,y-1), dis(x,y+1));
    };

    for( int iter=0; iter<num_iterations; ++iter ){
        for( int s=0; s<num_sweeps; ++s ){
            for( int x=sx[s][0]; x >= 0 && x < nx; x+=sx[s][1] ){
                for( int y=sy[s][0]; y >= 0 && y < ny; y+=sy[s][1] ){
                    if( !msk(x,y) ) continue;
                    const double a = xmin(x,y);
                    const double b = ymin(x,y);
                    if( std::min(a,b) >= maxval ){
                        // No reached neighbour. Without this guard, two +inf
                        // neighbours give inf-inf = NaN, which then spreads.
                        dis(x,y) = maxval;
                    } else if( std::abs(a-b) >= 1.0 ){
                        // One-sided update from the smaller neighbour.
                        dis(x,y) = std::min(a,b) + 1.0;
                    } else {
                        // Two-sided update: larger root of (d-a)^2 + (d-b)^2 = 1.
                        dis(x,y) = (a + b + std::sqrt(2.0 - (a-b)*(a-b)))/2.0;
                    }
                }
            }
        }
    }
}
/// Fast sweeping method for the 3D eikonal equation |grad(dis)| = 1 on a
/// unit-spaced grid: Gauss-Seidel updates repeated over eight sweep orders.
///
/// @param msk            3D accessor; cells where msk(x,y,z) is true are updated,
///                       every other cell is left untouched (e.g. sources).
/// @param dis            3D accessor, extents shape(0..2) == nx, ny, nz;
///                       read and overwritten in place.
/// @param num_iterations number of passes, each running all eight sweeps.
/// @param maxval         values >= maxval are treated as "not reached yet";
///                       a cell with no reached neighbour is set to maxval,
///                       and an axis of extent 1 contributes maxval.
template< typename Mask, typename Image >
void fast_sweeping_method_3d( Mask& msk, Image& dis, const int num_iterations, double maxval=1e10 ){
    const int nx = dis.shape(0);
    const int ny = dis.shape(1);
    const int nz = dis.shape(2);
    const int num_sweeps = 8;
    // {start, step} per sweep, covering all 8 octant orderings.
    const int sx[][2] = {{0,1},{nx-1,-1},{nx-1,-1},{0,1},{0,1},{nx-1,-1},{nx-1,-1},{0,1}};
    const int sy[][2] = {{0,1},{0,1},{ny-1,-1},{ny-1,-1},{0,1},{0,1},{ny-1,-1},{ny-1,-1}};
    const int sz[][2] = {{0,1},{0,1},{0,1},{0,1},{nz-1,-1},{nz-1,-1},{nz-1,-1},{nz-1,-1}};

    // Solves (dim+1)*d^2 - 2*sum(T)*d + sum(T^2) - 1 = 0 using the dim+1
    // smallest (sorted) neighbour values T[0..dim]; dim == 0 is the 1D update.
    // https://github.com/jvgomez/fast_methods/blob/master/doc/files/phd_thesis.pdf
    auto solve_eikonal = [maxval]( int dim, const double* T ) -> double {
        if( dim == 0 ){
            return T[0] + 1.0;
        }
        double sum = 0.0;
        double sum2 = 0.0;
        for( int i = 0; i <= dim; ++i ){
            sum += T[i];
            sum2 += T[i]*T[i];
        }
        const double qa = double(dim+1);   // zero-based dim, hence the +1
        const double qb = -2.0*sum;
        const double qc = sum2 - 1.0;
        const double disc = qb*qb - 4.0*qa*qc;
        return disc < 0.0 ? maxval : (-qb + std::sqrt(disc))/(2.0*qa);
    };

    // Smallest neighbour value along each axis; maxval when the axis has extent 1
    // (previously such thin axes were indexed out of bounds).
    auto xmin = [&]( int x, int y, int z ) -> double {
        if( nx < 2 ) return maxval;
        if( x == 0 ) return dis(x+1,y,z);
        if( x == nx-1 ) return dis(x-1,y,z);
        return std::min(dis(x-1,y,z), dis(x+1,y,z));
    };
    auto ymin = [&]( int x, int y, int z ) -> double {
        if( ny < 2 ) return maxval;
        if( y == 0 ) return dis(x,y+1,z);
        if( y == ny-1 ) return dis(x,y-1,z);
        return std::min(dis(x,y-1,z), dis(x,y+1,z));
    };
    auto zmin = [&]( int x, int y, int z ) -> double {
        if( nz < 2 ) return maxval;
        if( z == 0 ) return dis(x,y,z+1);
        if( z == nz-1 ) return dis(x,y,z-1);
        return std::min(dis(x,y,z-1), dis(x,y,z+1));
    };

    for( int iter=0; iter<num_iterations; ++iter ){
        for( int s=0; s<num_sweeps; ++s ){
            for( int x=sx[s][0]; x >= 0 && x < nx; x+=sx[s][1] ){
                for( int y=sy[s][0]; y >= 0 && y < ny; y+=sy[s][1] ){
                    for( int z=sz[s][0]; z >= 0 && z < nz; z+=sz[s][1] ){
                        if( !msk(x,y,z) ) continue;
                        // Per-axis minima sorted ascending; slot 3 is a maxval sentinel
                        // so T[dim+1] is always valid in the loop below.
                        double T[4] = { xmin(x,y,z), ymin(x,y,z), zmin(x,y,z), maxval };
                        std::sort( T, T+3 );
                        const int num_known = (T[0] < maxval) + (T[1] < maxval) + (T[2] < maxval);
                        if( num_known == 0 ){
                            dis(x,y,z) = maxval;
                            continue;
                        }
                        // Use one more axis only while the candidate is not smaller
                        // than the next neighbour value (upwind condition).
                        double d = maxval;
                        for( int dim=0; dim<num_known; ++dim ){
                            d = solve_eikonal( dim, T );
                            if( d < T[dim+1] ) break;
                        }
                        dis(x,y,z) = d;
                    }
                }
            }
        }
    }
}
/// Python entry point for the 2D solver; updates `distance` in place.
/// pybind11's unchecked<2>() / mutable_unchecked<2>() throw std::domain_error
/// (surfaced in Python as ValueError) if an array is not 2D or `distance` is
/// read-only; mismatched shapes raise ValueError here.
void fast_sweeping_method_2d_wrap( py::array_t<bool> &mask, py::array_t<double> &distance, int num_iterations, double maxval ){
    auto msk = mask.unchecked<2>();
    auto dis = distance.mutable_unchecked<2>();
    // The solver indexes msk with dis's extents, so a smaller mask would be
    // read out of bounds.
    if( msk.shape(0) != dis.shape(0) || msk.shape(1) != dis.shape(1) ){
        throw py::value_error("mask and distance must have the same shape");
    }
    fast_sweeping_method_2d( msk, dis, num_iterations, maxval );
}
/// Python entry point for the 3D solver; updates `distance` in place.
/// pybind11's unchecked<3>() / mutable_unchecked<3>() throw std::domain_error
/// (surfaced in Python as ValueError) if an array is not 3D or `distance` is
/// read-only; mismatched shapes raise ValueError here.
void fast_sweeping_method_3d_wrap( py::array_t<bool> &mask, py::array_t<double> &distance, int num_iterations, double maxval ){
    auto msk = mask.unchecked<3>();
    auto dis = distance.mutable_unchecked<3>();
    // The solver indexes msk with dis's extents, so a smaller mask would be
    // read out of bounds.
    if( msk.shape(0) != dis.shape(0) || msk.shape(1) != dis.shape(1) || msk.shape(2) != dis.shape(2) ){
        throw py::value_error("mask and distance must have the same shape");
    }
    fast_sweeping_method_3d( msk, dis, num_iterations, maxval );
}
PYBIND11_MODULE( fsm, m) {
    m.doc() = "Fast sweeping method for the eikonal equation |grad d| = 1 on 2D and 3D grids.";
    // `distance` is modified in place. noconvert() makes pybind11 reject inputs
    // that are not already float64 arrays (TypeError). Otherwise it would
    // silently solve on a converted temporary copy and discard the results.
    m.def("fast_sweeping_method_2d", &fast_sweeping_method_2d_wrap,
          "Update the 2D float64 array `distance` in place wherever `mask` is true.",
          py::arg("mask"), py::arg("distance").noconvert(),
          py::arg("num_iterations"), py::arg("maxval") = 1e10 );
    m.def("fast_sweeping_method_3d", &fast_sweeping_method_3d_wrap,
          "Update the 3D float64 array `distance` in place wherever `mask` is true.",
          py::arg("mask"), py::arg("distance").noconvert(),
          py::arg("num_iterations"), py::arg("maxval") = 1e10 );
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment