Skip to content

Instantly share code, notes, and snippets.

@skariel
Created February 15, 2015 19:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save skariel/57f667992746fa20aaa5 to your computer and use it in GitHub Desktop.
Save skariel/57f667992746fa20aaa5 to your computer and use it in GitHub Desktop.
module nbody
using OctTrees
import OctTrees: modify, stop_cond, getx, gety, getz
# A point mass in 3-D space: position (_x, _y, _z) and mass _m.
# Subtypes OctTrees' AbstractPoint3D so particles can be inserted in an OctTree.
immutable Particle <: AbstractPoint3D
    _x::Float64   # x coordinate
    _y::Float64   # y coordinate
    _z::Float64   # z coordinate
    _m::Float64   # mass
end

# Zero-mass particle located at the origin.
Particle() = Particle(0.0, 0.0, 0.0, 0.0)

# Coordinate accessors; these add methods to the getx/gety/getz functions
# imported from OctTrees.
function getx(p::Particle)
    return p._x
end

function gety(p::Particle)
    return p._y
end

function getz(p::Particle)
    return p._z
end
# Complete simulation state: the octree, the particles, per-particle
# velocity/acceleration components, and squared force-calculation parameters.
type World
    tree::OctTree{Particle}         # rebuilt from `particles` by buildtree
    particles::Vector{Particle}
    vx::Vector{Float64}             # presumably velocity components; only
    vy::Vector{Float64}             # zero-initialised in this file
    vz::Vector{Float64}
    ax::Vector{Float64}             # acceleration components written by
    ay::Vector{Float64}             # calc_accel / calculate_accel_on_particle
    az::Vector{Float64}
    n::Int64                        # number of particles
    opening_alpha2::Float64         # square of the tree opening parameter
    opening_excluded_frac2::Float64 # square of the excluded-fraction parameter
    smth2::Float64                  # squared smoothing length added to r^2
end
# Create a World of `n` equal-mass particles (mass 1/n each; total mass 1)
# whose coordinates are independent standard-normal draws (x, then y, then z).
# Velocities and accelerations start at zero.
#   smth                  -- smoothing length; stored squared as `smth2`
#   opening_excluded_frac -- stored squared as `opening_excluded_frac2`
#   opening_alpha         -- tree opening parameter; stored squared
# The OctTree is created with n = trunc(4.1n); presumably this is its
# preallocated node capacity (TODO confirm against OctTrees).
function worldnormal(n::Int64; smth=0.000001, opening_excluded_frac=0.0, opening_alpha=0.7)
    mass = 1.0 / n
    particles = [Particle(randn(), randn(), randn(), mass) for k in 1:n]
    tree = OctTree(Particle; n=trunc(Integer, 4.1 * n))
    return World(tree, particles,
                 zeros(n), zeros(n), zeros(n),   # vx, vy, vz
                 zeros(n), zeros(n), zeros(n),   # ax, ay, az
                 n,
                 opening_alpha^2,
                 opening_excluded_frac^2,
                 smth * smth)
end
# Create a World of `n` equal-mass particles (mass 1/n each) distributed
# uniformly inside the unit sphere. Velocities and accelerations start at
# zero; the keyword parameters are stored squared, as in worldnormal.
function worldspherical(n::Int64; smth=0.0, opening_excluded_frac=0.6, opening_alpha=0.7)
    mass = 1.0 / n
    particles = Particle[]
    # Rejection sampling: draw x, y, z uniformly in [-1, 1) and keep the point
    # only if it lies strictly inside the unit sphere.
    while length(particles) < n
        x = 2.0 * rand() - 1.0
        y = 2.0 * rand() - 1.0
        z = 2.0 * rand() - 1.0
        if x*x + y*y + z*z < 1.0
            push!(particles, Particle(x, y, z, mass))
        end
    end
    tree = OctTree(Particle; n=trunc(Integer, 4.1 * n))
    return World(tree, particles,
                 zeros(n), zeros(n), zeros(n),   # vx, vy, vz
                 zeros(n), zeros(n), zeros(n),   # ax, ay, az
                 n,
                 opening_alpha^2,
                 opening_excluded_frac^2,
                 smth * smth)
end
# Merge particle `p` into node `q`: the node's point becomes the combined
# centre of mass of (q.point, p), carrying their total mass. Presumably
# invoked by OctTrees' insert! when called with `Modify` (see buildtree).
@inline function modify(q::OctTreeNode{Particle}, p::Particle)
    old = q.point
    mtot = old._m + p._m
    cx = (old._x * old._m + p._x * p._m) / mtot
    cy = (old._y * old._m + p._y * p._m) / mtot
    cz = (old._z * old._m + p._z * p._m) / mtot
    q.point = Particle(cx, cy, cz, mtot)
end
# Rebuild w.tree from scratch for the current particle positions.
# The root cell is a cube (same extent on x, y and z) centred at
# (mid, mid, mid). Here `mid` is the midpoint between the smallest and largest
# coordinate over all three axes. The cube's half-width is 1.05 times half
# that span, which leaves a 5% margin around the particles.
function buildtree(w::World)
    clear!(w.tree)
    lo = 1.0e30
    hi = -1.0e30
    for k in 1:w.n
        @inbounds p = w.particles[k]
        for c in (p._x, p._y, p._z)
            if c < lo
                lo = c
            end
            if c > hi
                hi = c
            end
        end
    end
    halfspan = 0.5 * (hi - lo)
    mid = 0.5 * (hi + lo)
    initnode!(w.tree.head, halfspan * 1.05, mid, mid, mid)
    # Modify => each insertion updates centre of mass / mass via `modify`
    insert!(w.tree, w.particles, Modify)
end
# Mutable accumulator passed through a tree walk (`map(w.tree, data)`):
# stop_cond adds each node's contribution into ax/ay/az for the target
# particle located at (px, py, pz).
type DataToCalculateAccelOnParticle
    ax::Float64   # accumulated x acceleration
    ay::Float64   # accumulated y acceleration
    az::Float64   # accumulated z acceleration
    px::Float64   # target particle x
    py::Float64   # target particle y
    pz::Float64   # target particle z
    w::World      # read for smth2 and the opening parameters
end
# Per-node callback used by `map(w.tree, data)` (see calc_accel and
# calculate_accel_on_particle). Return value, per the inline comments below:
#   true  -> this node is finished (empty, or its contribution has been added
#            to data.ax/ay/az); do not descend into it
#   false -> the node must be opened further (its children visited)
# The contribution of a mass m at offset d = (node - target) is the softened
# gravitational acceleration (units with G = 1):
#     a += m * d / (|d|^2 + smth2)^(3/2)
@inline function stop_cond(q::OctTreeNode{Particle}, data::DataToCalculateAccelOnParticle)
isemptyleaf(q) && return true # empty node, nothing to do
if isleaf(q)
# we have a single particle in the node
# Skip the target particle itself, detected by exact coordinate equality.
# NOTE(review): any other particle at exactly the same position is skipped too.
q.point._x == data.px &&
q.point._y == data.py &&
q.point._z == data.pz && return true
const dx = q.point._x - data.px
const dx2 = dx*dx
const dy = q.point._y - data.py
const dy2 = dy*dy
const dz = q.point._z - data.pz
const dz2 = dz*dz
const dr2 = dx2+dy2+dz2+data.w.smth2
const dr = sqrt(dr2)
#const denom = (dx2+dy2+dz2+data.w.smth2)^1.5/q.point._m
# dr2*dr == dr2^1.5, computed without a pow call
const denom = dr2*dr/q.point._m
data.ax += dx/denom
data.ay += dy/denom
data.az += dz/denom
return true
end
# here q is divided. Check we are not too close to the cell
# lx is the cell side length: q.r is the half-width (the root is initialised
# with a half-span radius in buildtree).
const lx = 2.0*q.r
const lx2 = lx*lx
const dqx = q.midx - data.px
const dqx2 = dqx*dqx
const dqy = q.midy - data.py
const dqy2 = dqy*dqy
const dqz = q.midz - data.pz
const dqz2 = dqz*dqz
# Comparing squares: |dq| < opening_excluded_frac * lx on every axis means the
# target is near the cell centre, so the cell is always opened.
const fac = data.w.opening_excluded_frac2*lx2
dqx2 < fac && dqy2 < fac && dqz2 < fac && return false # we need to further open the node
# For a divided node, q.point holds the centre of mass and total mass of its
# contents (accumulated by `modify` during insertion).
const dx = q.point._x - data.px
const dx2 = dx*dx
const dy = q.point._y - data.py
const dy2 = dy*dy
const dz = q.point._z - data.pz
const dz2 = dz*dz
const r2 = dx2 + dy2 + dz2
# Opening-angle test: open when (side / distance-to-COM)^2 > opening_alpha^2.
# If r2 == 0 the ratio is Inf and the node is opened.
lx2/r2 > data.w.opening_alpha2 && return false # we need to further open the node
# consider the node, no further need to open it
const dr2 = r2+data.w.smth2
const dr = sqrt(dr2)
const denom = dr2*dr/q.point._m
#const denom = (r2+data.w.smth2)^1.5/q.point._m
data.ax += dx/denom
data.ay += dy/denom
data.az += dz/denom
return true
end
# Compute the tree-approximated acceleration on particle `particle_ix` and
# store its components in w.ax/w.ay/w.az at that index.
# Assumes w.tree was already built for the current positions (e.g. via
# buildtree); this function does not rebuild it.
@inline function calculate_accel_on_particle(w::World, particle_ix::Int64)
    @inbounds target = w.particles[particle_ix]
    acc = DataToCalculateAccelOnParticle(0.0, 0.0, 0.0,
                                         target._x, target._y, target._z, w)
    # Tree walk; stop_cond accumulates each accepted node's contribution.
    map(w.tree, acc)
    @inbounds w.ax[particle_ix] = acc.ax
    @inbounds w.ay[particle_ix] = acc.ay
    @inbounds w.az[particle_ix] = acc.az
end
# Rebuild the tree, then compute the tree-approximated acceleration of every
# particle, writing the components into w.ax/w.ay/w.az.
# A single accumulator object is reused for all particles, so no accumulator
# is allocated per particle.
function calc_accel(w::World)
    buildtree(w)
    acc = DataToCalculateAccelOnParticle(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, w)
    @inbounds for k in 1:w.n
        target = w.particles[k]
        # reset the sums and retarget the accumulator at particle k
        acc.ax, acc.ay, acc.az = 0.0, 0.0, 0.0
        acc.px, acc.py, acc.pz = target._x, target._y, target._z
        map(w.tree, acc)
        w.ax[k], w.ay[k], w.az[k] = acc.ax, acc.ay, acc.az
    end
end
# Direct O(N^2) summation of the softened gravitational acceleration (G = 1)
# on each particle index in `ixs`, due to all other particles:
#     a_i = sum_{j != i} m_j * (x_j - x_i) / (|x_j - x_i|^2 + smth2)^(3/2)
# This is the exact reference used to check the tree code (stop_cond).
# Returns a tuple (ax, ay, az); each is a vector with the components for
# the entries of `ixs`, in the order given.
function calc_accel_brute_force(w::World, ixs=1:w.n)
    ax = zeros(w.n)
    ay = zeros(w.n)
    az = zeros(w.n)
    for i in ixs
        @inbounds p_i = w.particles[i]
        # local sums: a repeated index in `ixs` must not double-accumulate
        sx = 0.0
        sy = 0.0
        sz = 0.0
        for j in 1:w.n
            i == j && continue
            @inbounds pj = w.particles[j]
            dx = pj._x - p_i._x
            dy = pj._y - p_i._y
            dz = pj._z - p_i._z
            r2 = dx*dx + dy*dy + dz*dz
            dr2 = r2 + w.smth2
            dr = sqrt(dr2)
            # The acceleration on i scales with the mass of the *source*
            # particle j (stop_cond likewise divides by the attracting node's mass).
            denom = dr2 * dr / pj._m
            sx += dx / denom
            sy += dy / denom
            sz += dz / denom
        end
        @inbounds ax[i] = sx
        @inbounds ay[i] = sy
        @inbounds az[i] = sz
    end
    return ax[ixs], ay[ixs], az[ixs]
end
using Base.Test
# Self-test of tree construction on a 10000-particle normal World. Checks:
#   - every particle lies strictly inside the root cell
#   - the number of full leaves, and of leaves with non-negligible mass, equals n
#   - the mass summed over used non-empty nodes is ~1 (each particle has mass 1/n)
#   - children of a divided node have half the parent's radius
#   - a divided node's mass and centre of mass equal those of its 8 children
#   - divided nodes are flagged empty and are not leaves
# Prints a message at the end if all checks succeed.
function test()
w = worldnormal(10000)
buildtree(w)
# testing all points are inside tree boundaries:
for p in w.particles
@test p._x > w.tree.head.midx - w.tree.head.r
@test p._x < w.tree.head.midx + w.tree.head.r
@test p._y > w.tree.head.midy - w.tree.head.r
@test p._y < w.tree.head.midy + w.tree.head.r
@test p._z > w.tree.head.midz - w.tree.head.r
@test p._z < w.tree.head.midz + w.tree.head.r
end
# testing number of full leaves
# NOTE(review): this loop and the construction loop below iterate all of
# w.tree.nodes, not just the first number_of_nodes_used (as the mass sum does).
tot_not_empty = 0
tot_massive_leafs = 0
for n in w.tree.nodes
if isfullleaf(n)
tot_not_empty += 1
end
if n.point._m > 1.e-13 && isleaf(n)
tot_massive_leafs += 1
end
end
@test tot_not_empty == w.n
@test tot_massive_leafs == w.n
# Divided nodes are asserted below to be is_empty, so this sums leaf masses.
total_mass = 0.0
for n in w.tree.nodes[1:w.tree.number_of_nodes_used]
if !n.is_empty
total_mass += n.point._m
end
end
@test_approx_eq_eps total_mass 1.0 1.e-4
# testing tree construction
for n in w.tree.nodes
# testing children radius is 0.5*r
if n.is_divided
@test_approx_eq n.r/2 n.lxlylz.r
@test_approx_eq n.r/2 n.lxlyhz.r
@test_approx_eq n.r/2 n.lxhylz.r
@test_approx_eq n.r/2 n.lxhyhz.r
@test_approx_eq n.r/2 n.hxlylz.r
@test_approx_eq n.r/2 n.hxlyhz.r
@test_approx_eq n.r/2 n.hxhylz.r
@test_approx_eq n.r/2 n.hxhyhz.r
end
# testing mass in children is mass in parent
if n.is_divided
parent_mass = n.point._m
children_mass =
n.lxlylz.point._m +
n.lxlyhz.point._m +
n.lxhylz.point._m +
n.lxhyhz.point._m +
n.hxlylz.point._m +
n.hxlyhz.point._m +
n.hxhylz.point._m +
n.hxhyhz.point._m
@test_approx_eq_eps parent_mass children_mass 1.e-4
end
# testing all divided nodes are empty
if n.is_divided
@test n.is_empty
@test !isleaf(n)
@test !isemptyleaf(n)
@test !isfullleaf(n)
else
@test isleaf(n)
# NOTE(review): tautological check (the condition and the test are the
# same predicate); the intended property is unclear - confirm and replace.
if isfullleaf(n)
@test isfullleaf(n)
end
end
# testing center of mass in children builds com in parent
# (mass-weighted mean of children positions divided by the parent's mass)
if n.is_divided
parent_x = n.point._x
parent_y = n.point._y
parent_z = n.point._z
children_x = (
n.lxlylz.point._x*n.lxlylz.point._m +
n.lxlyhz.point._x*n.lxlyhz.point._m +
n.lxhylz.point._x*n.lxhylz.point._m +
n.lxhyhz.point._x*n.lxhyhz.point._m +
n.hxlylz.point._x*n.hxlylz.point._m +
n.hxlyhz.point._x*n.hxlyhz.point._m +
n.hxhylz.point._x*n.hxhylz.point._m +
n.hxhyhz.point._x*n.hxhyhz.point._m
)/n.point._m
children_y = (
n.lxlylz.point._y*n.lxlylz.point._m +
n.lxlyhz.point._y*n.lxlyhz.point._m +
n.lxhylz.point._y*n.lxhylz.point._m +
n.lxhyhz.point._y*n.lxhyhz.point._m +
n.hxlylz.point._y*n.hxlylz.point._m +
n.hxlyhz.point._y*n.hxlyhz.point._m +
n.hxhylz.point._y*n.hxhylz.point._m +
n.hxhyhz.point._y*n.hxhyhz.point._m
)/n.point._m
children_z = (
n.lxlylz.point._z*n.lxlylz.point._m +
n.lxlyhz.point._z*n.lxlyhz.point._m +
n.lxhylz.point._z*n.lxhylz.point._m +
n.lxhyhz.point._z*n.lxhyhz.point._m +
n.hxlylz.point._z*n.hxlylz.point._m +
n.hxlyhz.point._z*n.hxlyhz.point._m +
n.hxhylz.point._z*n.hxhylz.point._m +
n.hxhyhz.point._z*n.hxhyhz.point._m
)/n.point._m
@test_approx_eq_eps parent_x children_x 1.e-4
@test_approx_eq_eps parent_y children_y 1.e-4
@test_approx_eq_eps parent_z children_z 1.e-4
end
end
println("*** All tests passed! ***")
end
# Runs the tree self-test (building a 10000-particle World) every time this
# module is loaded.
test()
# Results of comparing tree-code accelerations with brute force for a random
# sample of particles (filled by test_acc).
type TestAcc
    ax_tree::Vector{Float64}  # x acceleration from the tree code
    a_tree::Vector{Float64}   # tree acceleration magnitude (as computed in test_acc)
    ax_bf::Vector{Float64}    # x acceleration from direct summation
    a_bf::Vector{Float64}     # brute-force acceleration magnitude
    ferr::Vector{Float64}     # sorted relative error, in percent
    fbelow::Vector{Float64}   # percentage rank for each entry of ferr
end
# Compare tree-code and brute-force accelerations on `nout` randomly chosen
# particles of an `n`-particle uniform-sphere World.
# Returns a TestAcc with the x components, the acceleration magnitudes, the
# sorted relative errors |a_tree - a_bf| / |a_bf| (in percent), and `fbelow`.
# fbelow[k] = 100k/nout, so (ferr[k], fbelow[k]) traces the cumulative error
# distribution. Requires nout <= n.
function test_acc(n, nout)
    ixs = randperm(n)[1:nout]
    w = worldspherical(n)
    buildtree(w)
    for i in ixs
        calculate_accel_on_particle(w, i)
    end
    ax_tree = w.ax[ixs]
    ay_tree = w.ay[ixs]
    az_tree = w.az[ixs]
    # magnitude uses all three components (previously ay was counted twice)
    a_tree = sqrt(ax_tree.^2 + ay_tree.^2 + az_tree.^2)
    println("calculating BF...")
    ax_bf, ay_bf, az_bf = calc_accel_brute_force(w, ixs)
    println("Done!\n")
    a_bf = sqrt(ax_bf.^2 + ay_bf.^2 + az_bf.^2)
    dax = ax_tree - ax_bf
    day = ay_tree - ay_bf
    daz = az_tree - az_bf
    ferr = sqrt(dax.^2 + day.^2 + daz.^2) ./ a_bf .* 100.0
    sort!(ferr)
    # collect(...) yields the index vector explicitly; `[1:nout]` relied on
    # deprecated range concatenation.
    fbelow = collect(1:nout) ./ nout .* 100.0
    return TestAcc(
        ax_tree,
        a_tree,
        ax_bf,
        a_bf,
        ferr,
        fbelow
    )
end
# Placeholder for a time-step computation: not implemented yet; it returns
# `nothing` for any World.
calc_dt(w::World) = nothing
end # module
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment