MPI parallel nearest neighbor search on a sphere
"""find nearest neighbors for points on a sphere when locations to be compared
are distributed across MPI tasks"""
from __future__ import print_function # Python 3 compatible print function
# to run: mpirun -np 4 python parallel_nn_sphere.py
# requires mpi4py ('conda install mpi4py' in anaconda python)
from mpi4py import MPI
import numpy as np
import sys
# generate lat,lon values for random points on the surface of a unit sphere
def random_sphere(npts):
    # generate random points on a sphere,
    # so that every small area on the sphere is expected
    # to have the same number of points.
    # http://mathworld.wolfram.com/SpherePointPicking.html
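    # (added note, not in the original gist) why this gives uniform density in
    # area: the area element on the unit sphere is cos(lat) dlat dlon, so a
    # uniform distribution requires sin(lat) to be uniform on [-1,1]; here
    # sin(lats) = sin(arccos(2v-1) - pi/2) = -cos(arccos(2v-1)) = 1 - 2v,
    # which is uniform on [-1,1] when v is uniform on [0,1].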
    u = np.random.uniform(0.,1.,size=npts)
    v = np.random.uniform(0.,1.,size=npts)
    lons = 2.*np.pi*u
    lats = np.arccos(2*v-1) - np.pi/2.
    return lons, lats
# great circle distance
def gcdist(lon1,lat1,lon2,lat2):
    # compute great circle distance in radians between (lon1,lat1) and
    # (lon2,lat2).
    # lon,lat pairs given in radians - returned distance is in radians.
    # uses Haversine formula
    dlon = lon2 - lon1
    dlat = lat2 - lat1
    a = (np.sin(dlat/2))**2 + np.cos(lat1) * np.cos(lat2) * (np.sin(dlon/2))**2
    # a can fall slightly outside [0,1] due to roundoff error,
    # which would make the distance NaN.
    a = a.clip(0.,1.)
    return 2.0 * np.arctan2( np.sqrt(a), np.sqrt(1-a) )
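# quick sanity check of gcdist (added illustration, not part of the original gist):
# a quarter great circle is pi/2 radians, antipodal points are pi radians apart.
assert np.allclose(gcdist(0.,0.,np.pi/2.,0.), np.pi/2.)
assert np.allclose(gcdist(0.,0.,np.pi,0.), np.pi)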
# function to find nearest neighbors
# this could be replaced by a more efficient kd-tree (see the sketch after this function)
def find_neighbors(lon1,lat1,lon2,lat2,radius):
    # lon1,lat1 is a single point
    # lon2,lat2 are vectors of lons/lats
    # return indices of lon2,lat2 within radius.
    r = gcdist(lon1,lat1,lon2,lat2)
    return np.where(r <= radius)[0]
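# (added sketch, not part of the original gist) one way to swap in a kd-tree,
# assuming scipy is available: convert lon/lat to 3-d Cartesian points on the
# unit sphere and convert the great-circle search radius to the equivalent
# chord length, then scipy.spatial.cKDTree can be queried directly. For many
# queries the tree over oblons,oblats could be built once per task and reused.
# This function is illustrative only and is not called below.
def find_neighbors_kdtree(lon1,lat1,lon2,lat2,radius):
    from scipy.spatial import cKDTree # optional dependency, assumed installed
    def to_xyz(lons,lats):
        lons = np.atleast_1d(lons); lats = np.atleast_1d(lats)
        return np.column_stack((np.cos(lats)*np.cos(lons),
                                np.cos(lats)*np.sin(lons),
                                np.sin(lats)))
    tree = cKDTree(to_xyz(lon2,lat2))
    # chord length corresponding to a great-circle angle 'radius' (radians)
    chord = 2.*np.sin(radius/2.)
    return np.asarray(tree.query_ball_point(to_xyz(lon1,lat1)[0], chord), dtype=np.int64)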
# function to perform distributed nearest neighbor search
def find_distributed_neighbors(lon,lat,oblons,oblats,radius,comm=MPI.COMM_WORLD):
    # find neighbors within radius of lon,lat in oblons,oblats
    # oblons, oblats are distributed across MPI tasks.
    # returns oblons_close_all, oblats_close_all which contain
    # all neighbors. Must be called by all MPI tasks.
    # create arrays needed for MPI
    rank = comm.rank
    nprocs = comm.Get_size()
    lonstmp = np.empty(nprocs, np.float64)
    latstmp = np.empty(nprocs, np.float64)
    recvcounts = np.empty(nprocs,np.int64)
    # share this task's state variable location with all other tasks
    # NOTE: Allgather only works if there is the same number of state
    # locations on each task.
    comm.Allgather(lon,lonstmp)
    comm.Allgather(lat,latstmp)
    # lonstmp now contains a vector of length nprocs with the i'th value
    # of the state variable longitude location for each task.
    # now loop over the locations in lonstmp,latstmp
    # j is the task number that this location belongs to
    for j in range(nprocs):
        # check for missing value, if found set neighbors arrays empty
        # and continue loop.
        # if there are not the same number of points assigned to each task,
        # the arrays can be padded with nans.
        if np.isnan(lonstmp[j]) or np.isnan(latstmp[j]):
            oblons_close_all_tmp = np.array([],np.float64)
            oblats_close_all_tmp = np.array([],np.float64)
            if rank == j:
                oblons_close_all=oblons_close_all_tmp
                oblats_close_all=oblats_close_all_tmp
            continue
        # find the nearest neighbor ob locations for lonstmp[j],latstmp[j]
        # on this task
        indices = find_neighbors(lonstmp[j],latstmp[j],oblons,oblats,radius)
        oblons_close = oblons[indices]
        oblats_close = oblats[indices]
        # recvcounts is the number of nearest neighbors found on each task
        # ncount is the number of nearest neighbors on this task
        ncount = np.asarray(indices.size, np.int64)
        # gather the per-task counts into recvcounts on all tasks.
        comm.Allgather(ncount,recvcounts)
        # oblons_close_all_tmp, oblats_close_all_tmp hold all the nearest
        # neighbors found on all tasks. They only need to be
        # allocated on the task responsible for this state variable.
        if rank==j:
            ncount_all = recvcounts.sum()
            oblons_close_all_tmp=np.empty(ncount_all,np.float64)
            oblats_close_all_tmp=np.empty(ncount_all,np.float64)
        else:
            oblons_close_all_tmp=None
            oblats_close_all_tmp=None
        # displs is the 'displacement index vector' for Gatherv
        displs = np.zeros(nprocs,np.int64)
        for nrank in range(1,nprocs):
            displs[nrank]=displs[nrank-1]+recvcounts[nrank-1]
        # gather all nearest neighbors on the task responsible for this state
        # variable (rank=j). (a simpler, pickle-based alternative is noted
        # after this function.)
        comm.Gatherv([oblons_close,recvcounts[rank],MPI.DOUBLE],
                     [oblons_close_all_tmp,tuple(recvcounts),tuple(displs),MPI.DOUBLE],root=j)
        comm.Gatherv([oblats_close,recvcounts[rank],MPI.DOUBLE],
                     [oblats_close_all_tmp,tuple(recvcounts),tuple(displs),MPI.DOUBLE],root=j)
        # save result on rank j
        if rank==j:
            oblons_close_all=oblons_close_all_tmp
            oblats_close_all=oblats_close_all_tmp
    return oblons_close_all, oblats_close_all
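# (added note, not part of the original gist) the recvcounts/displs bookkeeping
# above could be avoided with mpi4py's lowercase, pickle-based gather, at the
# cost of serializing the arrays, e.g. inside the loop over j:
#   gathered = comm.gather(oblons_close, root=j)
#   if rank == j: oblons_close_all_tmp = np.concatenate(gathered)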
# get MPI task info
comm = MPI.COMM_WORLD
rank = comm.rank # The process ID (integer 0-3 for 4-process run)
nprocs = comm.Get_size() # total number of MPI tasks
# total number of state variable locations, distributed evenly over tasks
npts = 1000
if npts % nprocs:
    if rank==0: sys.stdout.write('npts must be divisible by nprocs, exiting ...')
    raise SystemExit
npts_pertask = npts // nprocs
xlons, xlats = random_sphere(npts_pertask)
# add a missing value lon/lat pair on root task
if rank==0: xlons[-1]=np.nan; xlats[-1]=np.nan
# total number of observation locations, distributed evenly over tasks.
nobs = 1000
if nobs % nprocs:
    if rank==0: sys.stdout.write('nobs must be divisible by nprocs, exiting ...')
    raise SystemExit
nobs_pertask = nobs // nprocs
oblons, oblats = random_sphere(nobs_pertask)
# Allgather to get all ob locations on all tasks (for debugging)
check_result = True
if check_result:
    oblons_all = np.empty(nobs,np.float64)
    oblats_all = np.empty(nobs,np.float64)
    comm.Allgather(oblons,oblons_all)
    comm.Allgather(oblats,oblats_all)
# nearest neighbor search radius (radians)
radius = 0.25
# measure walltime in this loop
t1 = MPI.Wtime()
# loop over state variables on each task
for i in range(npts_pertask):
    # find all neighbors for xlons[i],xlats[i] on this task, considering
    # oblons,oblats across all MPI tasks.
    oblons_close_all,oblats_close_all = find_distributed_neighbors(xlons[i],xlats[i],oblons,oblats,radius)
    # check result
    if check_result and oblons_close_all.size > 0: # non-empty neighbors array
        # find correct answer by searching all ob locations on each task
        # (this is just for checking the answer, the whole point of this approach
        # is to avoid having a global array of ob locations on each task)
        indices = find_neighbors(xlons[i],xlats[i],oblons_all,oblats_all,radius)
        oblons_close_all_check = oblons_all[indices]
        oblats_close_all_check = oblats_all[indices]
        difflons = np.abs(np.sort(oblons_close_all_check)-np.sort(oblons_close_all))
        difflats = np.abs(np.sort(oblats_close_all_check)-np.sort(oblats_close_all))
        if difflons.max() > 1.e-10 or difflats.max() > 1.e-10:
            print('incorrect result on rank',rank)
# print out mean wall clock time spent in above loop.
# should be nearly constant with number of MPI tasks.
# So, this approach doesn't speed up the search, but it does reduce the memory overhead
# by eliminating the need for global arrays.
t = MPI.Wtime() - t1
tmean = np.array(0.,np.float64)
comm.Reduce(np.array(t,np.float64),tmean,op=MPI.SUM,root=0)
if rank==0: print('total time=',tmean/nprocs)