Skip to content

Instantly share code, notes, and snippets.

@ivan-pi
Created December 8, 2020 13:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ivan-pi/661c0884069baced35c71e6d5b6fe3ce to your computer and use it in GitHub Desktop.
Save ivan-pi/661c0884069baced35c71e6d5b6fe3ce to your computer and use it in GitHub Desktop.
Ball tree in Fortran following the one by Jake Vanderplas (see https://gist.github.com/jakevdp/5216193). Construction appears to work, querying is broken!
!> Bank of fixed-size max-heaps used to collect nearest neighbours.
!>
!> One heap is kept per query point (one row per point). Each heap has a
!> fixed number of slots and keeps the smallest distances pushed so far,
!> together with an integer tag for every stored distance.
module nheap_mod
  implicit none
  private

  public :: nheap

  !> Working real kind (double precision).
  integer, parameter :: wp = kind(1.0d0)

  !> `n_pts` independent max-heaps with `n_nbrs` slots each.
  !>
  !> Both components are allocated by `init` with 0-based bounds in both
  !> dimensions. Row `i` is the heap of point `i`; column 0 is the heap
  !> root, i.e. the largest distance currently stored for that point.
  type :: nheap
    real(wp), allocatable :: distances(:,:)
    integer, allocatable :: indices(:,:)
  contains
    procedure :: init => nheap_init
    procedure :: largest => nheap_largest
    procedure :: push => nheap_push
    procedure :: get_arrays => nheap_get_arrays
  end type nheap

contains

  !> Allocate storage for `n_pts` heaps of `n_nbrs` slots each.
  !>
  !> Every distance slot starts at `huge(1.0_wp)` and every index slot at 0.
  !> Calling `init` again on an object that is already allocated discards
  !> the previous contents instead of failing in `allocate`.
  subroutine nheap_init(self, n_pts, n_nbrs)
    class(nheap), intent(inout) :: self
    integer, intent(in) :: n_pts, n_nbrs

    if (allocated(self%distances)) deallocate(self%distances)
    if (allocated(self%indices)) deallocate(self%indices)

    allocate(self%distances(0:n_pts-1, 0:n_nbrs-1))
    allocate(self%indices(0:n_pts-1, 0:n_nbrs-1))
    self%distances = huge(1.0_wp)
    self%indices = 0
  end subroutine nheap_init

  !> Return the root (largest stored distance) of the heap in row `row`.
  pure function nheap_largest(self, row) result(res)
    class(nheap), intent(in) :: self
    integer, intent(in) :: row
    real(wp) :: res

    res = self%distances(row, 0)
  end function nheap_largest

  !> Offer the pair (`val`, `i_val`) to the heap in row `row`.
  !>
  !> If `val` is larger than the current root the heap is left untouched.
  !> Otherwise the root is replaced and the new value is sifted down until
  !> the max-heap property holds again.
  subroutine nheap_push(self, row, val, i_val)
    class(nheap), intent(inout) :: self
    integer, intent(in) :: row, i_val
    real(wp), intent(in) :: val
    integer :: sz, i, ic1, ic2, i_swap

    sz = size(self%distances, dim=2)

    ! A heap without slots cannot store anything (and has no root to read).
    if (sz < 1) return

    ! Values larger than the current maximum do not belong in the heap.
    if (val > self%distances(row, 0)) return

    ! Sift down from the root. The new value is written only once, at its
    ! final position, after larger children have been moved up.
    i = 0
    do
      ic1 = 2*i + 1
      ic2 = ic1 + 1
      if (ic1 >= sz) exit

      ! Pick the larger child. The bounds test is kept in its own branch
      ! because Fortran does not guarantee short-circuit evaluation.
      if (ic2 >= sz) then
        i_swap = ic1
      else if (self%distances(row, ic1) >= self%distances(row, ic2)) then
        i_swap = ic1
      else
        i_swap = ic2
      end if

      if (.not. (self%distances(row, i_swap) > val)) exit

      self%distances(row, i) = self%distances(row, i_swap)
      self%indices(row, i) = self%indices(row, i_swap)
      i = i_swap
    end do

    self%distances(row, i) = val
    self%indices(row, i) = i_val
  end subroutine nheap_push

  !> Copy the heap contents into `distances` and `indices`.
  !>
  !> When `sort` is true every row is sorted by increasing distance and the
  !> matching indices are permuted alongside; otherwise the rows are
  !> returned in heap order. The output arrays have the same shape as the
  !> stored arrays and are 0-based.
  subroutine nheap_get_arrays(self, sort, distances, indices)
    class(nheap), intent(in) :: self
    logical, intent(in) :: sort
    real(wp), intent(out) :: distances(0:size(self%distances,1)-1, &
                                       0:size(self%distances,2)-1)
    integer, intent(out) :: indices(0:size(self%indices,1)-1, &
                                    0:size(self%indices,2)-1)
    real(wp), allocatable :: row_dist(:)
    integer, allocatable :: row_idx(:)
    integer :: n_pts, n_nbrs, i

    distances = self%distances
    indices = self%indices
    if (.not. sort) return

    n_pts = size(self%distances, 1)
    n_nbrs = size(self%distances, 2)
    allocate(row_dist(n_nbrs))
    allocate(row_idx(n_nbrs))

    ! Rows are strided in memory, so each one is sorted through a
    ! contiguous scratch copy.
    do i = 0, n_pts - 1
      row_dist = distances(i, :)
      row_idx = indices(i, :)
      call quicksort(row_dist, row_idx)
      distances(i, :) = row_dist
      indices(i, :) = row_idx
    end do
  end subroutine nheap_get_arrays

  ! quicksort.f -*-f90-*-
  ! Author: t-nissie, some tweaks by 1AdAstra1
  ! License: GPLv3
  ! Gist: https://gist.github.com/t-nissie/479f0f16966925fa29ea
  !!
  !> Sort `a` into increasing order in place, applying the same swaps to
  !> `perm` so that `perm(i)` keeps following `a(i)`.
  !>
  !> Arrays with fewer than two elements are returned unchanged.
  recursive subroutine quicksort(a, perm)
    real(wp), intent(inout) :: a(:)
    integer, intent(inout) :: perm(size(a))
    ! A named constant rather than an initialised local: an initialiser
    ! on a local variable would give it the SAVE attribute.
    integer, parameter :: first = 1
    real(wp) :: pivot, tmp_a
    integer :: last, i, j, tmp_p

    last = size(a, 1)
    if (last < 2) return

    ! Same element as (first+last)/2, written so the sum cannot overflow.
    pivot = a(first + (last - first)/2)
    i = first
    j = last
    do
      do while (a(i) < pivot)
        i = i + 1
      end do
      do while (pivot < a(j))
        j = j - 1
      end do
      if (i >= j) exit

      tmp_a = a(i)
      a(i) = a(j)
      a(j) = tmp_a
      tmp_p = perm(i)
      perm(i) = perm(j)
      perm(j) = tmp_p

      i = i + 1
      j = j - 1
    end do

    if (first < i - 1) call quicksort(a(first:i-1), perm(first:i-1))
    if (j + 1 < last) call quicksort(a(j+1:last), perm(j+1:last))
  end subroutine quicksort

end module nheap_mod
!> Ball tree for k-nearest-neighbour queries under the Euclidean metric.
!>
!> The tree is stored as an implicit binary tree in flat, 0-based arrays:
!> the children of node `i` are nodes `2*i + 1` and `2*i + 2`. All
!> sample/feature arrays are laid out as (point, feature) and are 0-based.
module btree_mod
  use, intrinsic :: iso_fortran_env, only: output_unit
  use nheap_mod, only: nheap
  implicit none
  private

  public :: wp
  public :: btree, btree_init, btree_query
  public :: btree_verbose

  !> Working real kind (double precision).
  integer, parameter :: wp = kind(1.0d0)

  !> Leaf size used by `btree_init` when the caller does not pass one.
  integer, parameter :: default_leaf_size = 40

  !> When true (the default) the build and query routines print their
  !> progress trace to standard output; set to .false. to silence it.
  logical :: btree_verbose = .true.

  type :: btree
    real(wp), allocatable :: data(:,:)           ! n_samples, n_features
    integer :: leaf_size
    integer :: n_samples, n_features
    integer :: n_levels, n_nodes
    integer, allocatable :: idx_array(:)         ! n_samples
    real(wp), allocatable :: node_radius(:)      ! n_nodes
    integer, allocatable :: node_idx_start(:)    ! n_nodes
    integer, allocatable :: node_idx_end(:)      ! n_nodes
    logical, allocatable :: node_is_leaf(:)      ! n_nodes
    real(wp), allocatable :: node_centroids(:,:) ! n_nodes, n_features
  contains
    procedure :: rdist
    procedure :: min_rdist
    procedure :: recursive_build
    procedure :: query_recursive
  end type btree

contains

  !> Build a ball tree over the rows of `data`.
  !>
  !> `data` is copied into the tree. `leaf_size` defaults to 40. The number
  !> of levels is 1 + floor(log2(max(1, (n_samples - 1)/leaf_size))) with
  !> integer division, and the tree has 2**n_levels - 1 nodes.
  !>
  !> Stops with `error stop 1` when `leaf_size` is smaller than 1 or when
  !> `data` contains no points.
  function btree_init(data, leaf_size) result(self)
    real(wp), intent(in) :: data(0:,0:)
    integer, intent(in), optional :: leaf_size
    type(btree) :: self
    integer :: i, t

    self%leaf_size = default_leaf_size
    if (present(leaf_size)) self%leaf_size = leaf_size
    if (self%leaf_size < 1) then
      ! would otherwise divide by zero when sizing the tree
      write(*,*) "leaf_size must be at least 1"
      error stop 1
    end if

    self%n_samples = size(data, dim=1)
    self%n_features = size(data, dim=2)
    if (self%n_samples < 1) then
      ! an empty root node has no centroid (0/0)
      write(*,*) "training data must contain at least one point"
      error stop 1
    end if

    ! Explicit bounds so the copy is 0-based regardless of how the
    ! processor treats bounds in sourced allocation.
    allocate(self%data(0:self%n_samples-1, 0:self%n_features-1), source=data)

    ! Integer arithmetic throughout: a floating-point log(x)/log(2) may
    ! land just below an integer for exact powers of two.
    t = max(1, (self%n_samples - 1)/self%leaf_size)
    self%n_levels = 1 + floor_log2(t)
    self%n_nodes = 2**self%n_levels - 1

    allocate(self%idx_array(0:self%n_samples-1))
    self%idx_array(:) = [(i, i = 0, self%n_samples - 1)]

    if (btree_verbose) then
      print *, lbound(self%data,dim=1), ubound(self%data,dim=1)
      print *, "[btree_init] leaf_size = ", self%leaf_size
      print *, "[btree_init] n_samples = ", self%n_samples
      print *, "[btree_init] n_features = ", self%n_features
      print *, "[btree_init] n_levels = ", self%n_levels
      print *, "[btree_init] n_nodes = ", self%n_nodes
      print *, self%idx_array
    end if

    ! per-node storage
    allocate(self%node_radius(0:self%n_nodes-1), source=0.0_wp)
    allocate(self%node_idx_start(0:self%n_nodes-1), source=0)
    allocate(self%node_idx_end(0:self%n_nodes-1), source=0)
    allocate(self%node_is_leaf(0:self%n_nodes-1), source=.false.)
    allocate(self%node_centroids(0:self%n_nodes-1, 0:self%n_features-1), &
             source=0.0_wp)

    call self%recursive_build(0, 0, self%n_samples)
  end function btree_init

  !> Initialise node `i_node` over `idx_array(idx_start:idx_end-1)` and,
  !> unless it is a leaf, split it and build both children.
  !>
  !> A node is a leaf when its first child index falls outside the
  !> allocated node range, or when it holds fewer than two points.
  recursive subroutine recursive_build(self, i_node, idx_start, idx_end)
    class(btree), intent(inout) :: self
    integer, intent(in) :: i_node
    integer, intent(in) :: idx_start, idx_end
    integer :: n_mid, n_points

    if (btree_verbose) then
      print *, "i_node,idx_start,idx_end", i_node, idx_start, idx_end
    end if
    call init_node(self, i_node, idx_start, idx_end)

    n_points = idx_end - idx_start
    if ((2*i_node + 1) >= self%n_nodes) then
      self%node_is_leaf(i_node) = .true.
      if (n_points > 2*self%leaf_size) then
        write(*,*) "Internal: memory layout is flawed: not enough nodes allocated"
      end if
    else if (n_points < 2) then
      write(*,*) "Internal: memory layout is flawed: too many nodes allocated"
      self%node_is_leaf(i_node) = .true.
    else
      ! split the node at the median position and recurse into both halves
      self%node_is_leaf(i_node) = .false.
      n_mid = (idx_end + idx_start)/2
      call partition_indices(self%data, self%idx_array, idx_start, idx_end, n_mid)
      call self%recursive_build(2*i_node + 1, idx_start, n_mid)
      call self%recursive_build(2*i_node + 2, n_mid, idx_end)
    end if
  end subroutine recursive_build

  !> Compute centroid, radius and index range of node `i_node`, which
  !> covers the points `idx_array(idx_start:idx_end-1)`.
  subroutine init_node(self, i_node, idx_start, idx_end)
    type(btree), intent(inout) :: self
    integer, intent(in) :: i_node, idx_start, idx_end
    integer :: i, j
    real(wp) :: acc, sq_radius, sq_dist

    ! centroid: per-feature mean of the node's points
    do j = 0, self%n_features - 1
      acc = 0.0_wp
      do i = idx_start, idx_end - 1
        acc = acc + self%data(self%idx_array(i), j)
      end do
      self%node_centroids(i_node, j) = acc/real(idx_end - idx_start, wp)
    end do
    if (btree_verbose) print *, "node_centroid = ", self%node_centroids(i_node,:)

    ! radius: largest distance from the centroid to any point of the node
    sq_radius = 0.0_wp
    sq_dist = 0.0_wp
    do i = idx_start, idx_end - 1
      sq_dist = self%rdist(self%node_centroids, i_node, self%data, self%idx_array(i))
      sq_radius = max(sq_radius, sq_dist)
    end do
    if (btree_verbose) print *, "sq_radius, sq_dist", sq_radius, sq_dist

    self%node_radius(i_node) = sqrt(sq_radius)
    self%node_idx_start(i_node) = idx_start
    self%node_idx_end(i_node) = idx_end

    if (btree_verbose) then
      print *, "node_radius", self%node_radius(i_node)
      print *, "node_idx_start", self%node_idx_start(i_node)
      print *, "node_idx_end", self%node_idx_end(i_node)
    end if
  end subroutine init_node

  !> Squared Euclidean distance between row `i1` of `x1` and row `i2` of
  !> `x2`, taken over the first `n_features` columns.
  !>
  !> Both arrays are assumed-shape. `x1` receives the node centroids
  !> (n_nodes rows) as well as the training data (n_samples rows); an
  !> explicit-shape declaration sized for one of them would address the
  !> other with the wrong column stride.
  pure function rdist(self, x1, i1, x2, i2) result(d)
    class(btree), intent(in) :: self
    real(wp), intent(in) :: x1(0:,0:)
    real(wp), intent(in) :: x2(0:,0:)
    integer, intent(in) :: i1, i2
    real(wp) :: d, diff
    integer :: k

    d = 0.0_wp
    do k = 0, self%n_features - 1
      diff = x1(i1, k) - x2(i2, k)
      d = d + diff*diff
    end do
  end function rdist

  !> Lower bound on the squared distance between row `j` of `x` and any
  !> point inside the ball of node `i_node`:
  !> max(0, |x_j - centroid| - radius)**2.
  pure function min_rdist(self, i_node, x, j) result(res)
    class(btree), intent(in) :: self
    integer, intent(in) :: i_node
    real(wp), intent(in) :: x(0:,0:)
    integer, intent(in) :: j
    real(wp) :: d, res

    d = self%rdist(self%node_centroids, i_node, x, j)
    res = (max(0.0_wp, sqrt(d) - self%node_radius(i_node)))**2
  end function min_rdist

  !> Reorder `idx_array(idx_start:idx_end-1)` so that, along the feature
  !> with the largest spread, every point before position `split_index`
  !> is not larger than the point at `split_index` and every point after
  !> it is not smaller.
  pure subroutine partition_indices(data, idx_array, idx_start, idx_end, split_index)
    real(wp), intent(in) :: data(0:,0:)
    integer, intent(inout) :: idx_array(0:size(data,dim=1)-1)
    integer, intent(in) :: idx_start, idx_end, split_index
    integer :: n_features, split_dim
    real(wp) :: max_spread, max_val, min_val, val, pivot
    integer :: i, j, left, right, midindex, tmp

    ! find the feature with the largest spread inside this index range
    n_features = size(data, dim=2)
    split_dim = 0
    max_spread = 0.0_wp
    do j = 0, n_features - 1
      max_val = -huge(1.0_wp)
      min_val = huge(1.0_wp)
      do i = idx_start, idx_end - 1
        val = data(idx_array(i), j)
        max_val = max(max_val, val)
        min_val = min(min_val, val)
      end do
      if ((max_val - min_val) > max_spread) then
        max_spread = max_val - min_val
        split_dim = j
      end if
    end do

    ! Repeated partitioning around the right-most element, narrowing the
    ! window until the pivot lands exactly on split_index.
    left = idx_start
    right = idx_end - 1
    do
      midindex = left
      ! The inner loop only swaps positions below `right`, so the pivot
      ! value is invariant and can be read once.
      pivot = data(idx_array(right), split_dim)
      do i = left, right - 1
        if (data(idx_array(i), split_dim) < pivot) then
          tmp = idx_array(i)
          idx_array(i) = idx_array(midindex)
          idx_array(midindex) = tmp
          midindex = midindex + 1
        end if
      end do
      tmp = idx_array(midindex)
      idx_array(midindex) = idx_array(right)
      idx_array(right) = tmp

      if (midindex == split_index) then
        exit
      else if (midindex < split_index) then
        left = midindex + 1
      else
        right = midindex - 1
      end if
    end do
  end subroutine partition_indices

  !> floor(log2(n)) for n >= 1, computed by repeated halving; returns 0
  !> for n <= 1.
  pure function floor_log2(n) result(p)
    integer, intent(in) :: n
    integer :: p
    integer :: m

    p = 0
    m = n
    do while (m > 1)
      m = m/2
      p = p + 1
    end do
  end function floor_log2

  !> Find the `k` nearest training points for every row of `x`.
  !>
  !> On return `distances(i,:)` holds Euclidean distances and
  !> `indices(i,:)` the 0-based training-point indices for query row `i`.
  !> Rows are sorted by increasing distance unless `sort_results` is
  !> present and false.
  !>
  !> Stops with `error stop 1` when the feature counts differ, when `k`
  !> is smaller than 1, or when `k` exceeds the number of training points.
  subroutine btree_query(self, x, k, sort_results, distances, indices)
    type(btree), intent(in) :: self
    real(wp), intent(in) :: x(0:,0:)
    integer, intent(in) :: k
    logical, intent(in), optional :: sort_results
    real(wp), intent(out) :: distances(0:size(x,1)-1,0:k-1)
    integer, intent(out) :: indices(0:size(x,1)-1,0:k-1)
    logical :: sort_results_
    type(nheap) :: heap
    integer :: i
    real(wp) :: sq_dist_LB

    if (size(x,2) /= self%n_features) then
      write(*,*) "query data dimension must match training data dimension"
      error stop 1
    end if
    if (k < 1) then
      ! a heap without slots has no root to compare against
      write(*,*) "k must be at least 1"
      error stop 1
    end if
    if (size(self%data,1) < k) then
      write(*,*) "k must be less than or equal to the number of training points"
      error stop 1
    end if

    sort_results_ = .true.
    if (present(sort_results)) sort_results_ = sort_results

    call heap%init(size(x,1), k)
    if (btree_verbose) then
      print *, shape(heap%distances)
      print *, shape(heap%indices)
    end if

    do i = 0, size(x,1) - 1
      sq_dist_LB = self%min_rdist(0, x, i)
      if (btree_verbose) then
        write(*,'(A,I0,A,F8.4)') "sq_dist_LB(",i,") = ", sq_dist_LB
      end if
      call self%query_recursive(0, x, i, heap, sq_dist_LB)
    end do

    if (btree_verbose) then
      do i = 0, size(x,1) - 1
        print *, heap%indices(i,:)
      end do
    end if

    ! the heap stores squared distances
    call heap%get_arrays(sort_results_, distances, indices)
    distances = sqrt(distances)
  end subroutine btree_query

  !> Visit node `i_node` for query row `i_pt` of `x`, updating `heap`.
  !>
  !> `sq_dist_LB` is the caller-computed lower bound on the squared
  !> distance from the query point to the node.
  recursive subroutine query_recursive(self, i_node, x, i_pt, heap, sq_dist_LB)
    class(btree), intent(in) :: self
    integer, intent(in) :: i_node
    real(wp), intent(in) :: x(0:,0:)
    integer, intent(in) :: i_pt
    type(nheap), intent(inout) :: heap
    real(wp), intent(in) :: sq_dist_LB
    real(wp) :: dist_pt, sq_dist_LB_1, sq_dist_LB_2
    integer :: i, i1, i2

    if (sq_dist_LB > heap%largest(i_pt)) then
      ! Case 1: the node cannot contain anything closer than what the
      ! heap already holds - prune it.
      if (btree_verbose) print *, "i_node ", i_node, "Case 1"

    else if (self%node_is_leaf(i_node)) then
      ! Case 2: leaf node - test every point it contains.
      if (btree_verbose) print *, "i_node ", i_node, "Case 2"
      do i = self%node_idx_start(i_node), self%node_idx_end(i_node) - 1
        dist_pt = self%rdist(self%data, self%idx_array(i), x, i_pt)
        if (dist_pt < heap%largest(i_pt)) then
          call heap%push(i_pt, dist_pt, self%idx_array(i))
        end if
      end do

    else
      ! Case 3: inner node - descend into both children, closest first.
      if (btree_verbose) print *, "i_node ", i_node, "Case 3"
      i1 = 2*i_node + 1
      i2 = i1 + 1
      sq_dist_LB_1 = self%min_rdist(i1, x, i_pt)
      sq_dist_LB_2 = self%min_rdist(i2, x, i_pt)
      if (btree_verbose) then
        print *, sq_dist_LB_1, sq_dist_LB_2
        flush(output_unit)
      end if

      if (sq_dist_LB_1 <= sq_dist_LB_2) then
        call self%query_recursive(i1, x, i_pt, heap, sq_dist_LB_1)
        call self%query_recursive(i2, x, i_pt, heap, sq_dist_LB_2)
      else
        call self%query_recursive(i2, x, i_pt, heap, sq_dist_LB_2)
        call self%query_recursive(i1, x, i_pt, heap, sq_dist_LB_1)
      end if
    end if
  end subroutine query_recursive

end module btree_mod
!> Demo driver: builds a ball tree over 100 random 2-D points lying near
!> a parabola, dumps the tree to text files and queries the 5 nearest
!> neighbours of every point.
program main
  use btree_mod, only: btree, btree_init, wp, btree_query
  implicit none

  integer, parameter :: n = 100   ! number of points
  integer, parameter :: d = 2     ! number of features
  integer, parameter :: ls = 15   ! leaf size passed to btree_init
  integer, parameter :: k = 5     ! neighbours requested per point

  real(wp) :: x(0:n-1,0:d-1)
  real(wp) :: r(0:n-1,0:k-1)
  integer :: ri(0:n-1,0:k-1)
  type(btree) :: bt
  integer :: i, funit, ios
  character(len=256) :: msg

  ! Points uniform in [-1,1) along the first feature; the second feature
  ! is x0**2 plus noise in [-0.1,0.1).
  call random_number(x)
  x = 2.0_wp*x - 1.0_wp
  x(:,1) = x(:,1)*0.1_wp
  x(:,1) = x(:,1) + x(:,0)**2

  bt = btree_init(x, ls)
  print *, "leaf_size: ", bt%leaf_size
  print *, "nsamples, nfeatures: ", bt%n_samples, bt%n_features
  print *, "nlevels, nnodes: ", bt%n_levels, bt%n_nodes

  ! one line per point: coordinates, then the tree's index-array entry
  open(newunit=funit, file="ball.txt", status="replace", action="write", &
       iostat=ios, iomsg=msg)
  if (ios /= 0) then
    write(*,*) "cannot open ball.txt: ", trim(msg)
    error stop 1
  end if
  do i = 0, n - 1
    write(funit,*) x(i,:), bt%idx_array(i)
  end do
  close(funit)

  ! one line per node: id, leaf flag (1/0), index range, radius, centroid
  open(newunit=funit, file="ball_nodes.txt", status="replace", action="write", &
       iostat=ios, iomsg=msg)
  if (ios /= 0) then
    write(*,*) "cannot open ball_nodes.txt: ", trim(msg)
    error stop 1
  end if
  do i = 0, bt%n_nodes - 1
    ! `1x` instead of the non-standard bare `X`; the unlimited repeat
    ! covers any number of centroid components.
    write(funit,'(i0,1x,i0,1x,i0,1x,i0,1x,f16.10,*(1x,f16.10))') &
        i, merge(1, 0, bt%node_is_leaf(i)), &
        bt%node_idx_start(i), bt%node_idx_end(i), bt%node_radius(i), &
        bt%node_centroids(i,:)
  end do
  close(funit)

  ! query the tree with its own training points
  call btree_query(bt, x, k, sort_results=.true., &
                   distances=r, &
                   indices=ri)
  do i = 0, n - 1
    print *, ri(i,:)
  end do
end program main
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment