Skip to content

Instantly share code, notes, and snippets.

@cgarciae
Last active April 26, 2020 22:24
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cgarciae/a69fa609f8fcd0aacece92660b5c2315 to your computer and use it in GitHub Desktop.
Save cgarciae/a69fa609f8fcd0aacece92660b5c2315 to your computer and use it in GitHub Desktop.
using Base.Threads
using LoopVectorization
using BenchmarkTools
# Julia stand-in for NumPy's `None` / `np.newaxis`: indexing with a vector holding the
# zero-dimensional `CartesianIndex()` inserts a length-1 axis, so for a vector `v`,
# `v[:, None]` is an n×1 array and `v[None, :]` is a 1×n array.
const None = [CartesianIndex()]
"""
    distances(data1, data2)

Great-circle (haversine) distance matrix in kilometres, written as a direct,
step-by-step translation of the NumPy/JAX `distances_*` functions.

`data1` and `data2` are matrices whose first column holds latitudes and whose second
column holds longitudes, in degrees. Entry `(i, j)` of the result is the distance
between row `i` of `data1` and row `j` of `data2`, using an Earth radius of 6373 km.
For `Float32` input the result is a `Float32` matrix.
"""
function distances(data1, data2)
    rad1 = deg2rad.(data1)
    rad2 = deg2rad.(data2)
    # NumPy-like singleton axes: the first point set becomes n×1 columns and the
    # second 1×m rows, so every difference below broadcasts to an n×m matrix.
    phi1 = reshape(view(rad1, :, 1), :, 1)
    lam1 = reshape(view(rad1, :, 2), :, 1)
    phi2 = reshape(view(rad2, :, 1), 1, :)
    lam2 = reshape(view(rad2, :, 2), 1, :)
    dphi = phi1 .- phi2
    dlam = lam1 .- lam2
    hav = sin.(dphi ./ 2) .^ 2 .+ cos.(phi1) .* cos.(phi2) .* sin.(dlam ./ 2) .^ 2
    # abs keeps both sqrt arguments non-negative; the in-place write keeps hav's
    # element type.
    hav .= 2.0 .* 6373.0 .* atan.(sqrt.(abs.(hav)), sqrt.(abs.(1.0 .- hav)))
    return reshape(hav, size(data1, 1), size(data2, 1))
end
"""
    distances_threaded(data1, data2)

Haversine distance matrix (km, Earth radius 6373) between the rows of `data1` and
the rows of `data2`; column 1 holds latitude and column 2 longitude, in degrees.
Columns of the result are computed in parallel with `@threads`.

Returns a `Matrix{Float64}` of size `(size(data1, 1), size(data2, 1))`, where
entry `(i, j)` pairs row `i` of `data1` with row `j` of `data2`.
"""
function distances_threaded(data1, data2)
    lat1 = [deg2rad(data1[i, 1]) for i in 1:size(data1, 1)]
    lng1 = [deg2rad(data1[i, 2]) for i in 1:size(data1, 1)]
    lat2 = [deg2rad(data2[i, 1]) for i in 1:size(data2, 1)]
    lng2 = [deg2rad(data2[i, 2]) for i in 1:size(data2, 1)]
    data = Matrix{Float64}(undef, length(lat1), length(lat2))
    # One threaded pass: each iteration owns a whole column, so writes never overlap
    # between tasks and both steps touch the column while it is still in cache.
    @threads for j in eachindex(lat2)
        lat, lng = lat2[j], lng2[j]
        coslat = cos(lat)  # loop-invariant for this column
        col = view(data, :, j)
        @. col = sin((lat1 - lat) / 2)^2 + cos(lat1) * coslat * sin((lng1 - lng) / 2)^2
        # Done after the store so this step operates on Float64 values even when the
        # inputs are Float32.
        @. col = 2.0 * 6373.0 * atan(sqrt(abs(col)), sqrt(abs(1.0 - col)))
    end
    return data
end
"""
    distances_threaded_simd(data1, data2)

Haversine distance matrix (km, Earth radius 6373) between the rows of `data1` and
the rows of `data2`; column 1 holds latitude and column 2 longitude, in degrees.
Columns of the result are computed in parallel with `@threads`, and the haversine
term of each column is evaluated as a LoopVectorization `@avx` broadcast.

Returns a `Matrix{Float64}` of size `(size(data1, 1), size(data2, 1))`.
"""
function distances_threaded_simd(data1, data2) # @baggepinnen
    lat1 = [deg2rad(data1[i, 1]) for i in 1:size(data1, 1)]
    lng1 = [deg2rad(data1[i, 2]) for i in 1:size(data1, 1)]
    lat2 = [deg2rad(data2[i, 1]) for i in 1:size(data2, 1)]
    lng2 = [deg2rad(data2[i, 2]) for i in 1:size(data2, 1)]
    data = Matrix{Float64}(undef, length(lat1), length(lat2))
    # Each iteration owns one column of `data`, so tasks never write the same memory.
    @threads for j in eachindex(lat2)
        lat, lng = lat2[j], lng2[j]
        col = view(data, :, j)
        @avx col .= @. sin((lat1 - lat) / 2)^2 + cos(lat1) * cos(lat) * sin((lng1 - lng) / 2)^2
        # `@avx` vectorizes `for` loops and broadcasts; wrapping a scalar
        # `data[i] = ...` statement in it gives no speedup. Finish the column here
        # instead, while it is still in cache, rather than in a second whole-matrix
        # threaded pass.
        @. col = 2.0 * 6373.0 * atan(sqrt(abs(col)), sqrt(abs(1.0 - col)))
    end
    return data
end
"""
    distances_bcast(data1, data2)

Haversine distance matrix (km, Earth radius 6373) between the rows of `data1` and
the rows of `data2` (column 1 = latitude, column 2 = longitude, in degrees). It is
built from one fused broadcast of column vectors against row vectors.

Entry `(i, j)` pairs row `i` of `data1` with row `j` of `data2`. The element type
follows the broadcast, so `Float32` input gives `Float32` output.
"""
function distances_bcast(data1, data2) # @DNF
    r1 = deg2rad.(data1)
    r2 = deg2rad.(data2)
    # The first point set stays as columns; the second is adjointed into rows so the
    # broadcast below expands to a size(data1, 1) × size(data2, 1) matrix.
    plat, plng = view(r1, :, 1), view(r1, :, 2)
    qlat, qlng = view(r2, :, 1)', view(r2, :, 2)'
    h = @. sin((plat - qlat) / 2)^2 + cos(plat) * cos(qlat) * sin((plng - qlng) / 2)^2
    # In-place second pass; the integer constants keep the arithmetic in eltype(h).
    h .= 2 .* 6373 .* atan.(sqrt.(abs.(h)), sqrt.(abs.(1 .- h)))
    return h
end
"""
    distances_bcast_simd(data1, data2)

Same computation as `distances_bcast`: a haversine distance matrix in km with Earth
radius 6373. Column 1 of each input is latitude and column 2 is longitude, in
degrees. The difference is that the first broadcast is evaluated through
LoopVectorization's `@avx`.
"""
function distances_bcast_simd(data1, data2)
data1 = deg2rad.(data1)
data2 = deg2rad.(data2)
lat1 = @view data1[:, 1]
lng1 = @view data1[:, 2]
lat2 = @view data2[:, 1]
lng2 = @view data2[:, 2]
# Haversine term for every (row of data1, row of data2) pair: lat2'/lng2' are row
# vectors, so the result is size(data1, 1) × size(data2, 1).
@avx data = sin.((lat1 .- lat2') ./ 2).^2 .+ cos.(lat1) .* cos.(lat2') .* sin.((lng1 .- lng2') ./ 2).^2
# NOTE(review): this second pass is a plain broadcast, not `@avx`; confirm that
# LoopVectorization supports two-argument `atan` before vectorizing it too.
@. data = 2 * 6373 * atan(sqrt(abs(data)), sqrt(abs(1 - data)))
return data
end
using PyCall
# Python/NumPy and JAX reference implementations, executed through PyCall.
# Python is indentation-sensitive, so the function bodies inside this string must
# stay indented relative to their `def` lines.
py"""
import typing as tp
from jax import numpy as jnp
import jax
import numpy as np
import time


@jax.jit
def distances_jax(data1, data2):
    # data1, data2 are the data arrays with 2 cols and they hold
    # lat., lng. values in those cols respectively
    # Shadow the module-level `np` so the body below matches distances_np line for line.
    np = jnp
    data1 = np.deg2rad(data1)
    data2 = np.deg2rad(data2)
    lat1 = data1[:, 0]
    lng1 = data1[:, 1]
    lat2 = data2[:, 0]
    lng2 = data2[:, 1]
    diff_lat = lat1[:, None] - lat2
    diff_lng = lng1[:, None] - lng2
    d = (
        np.sin(diff_lat / 2) ** 2
        + np.cos(lat1[:, None]) * np.cos(lat2) * np.sin(diff_lng / 2) ** 2
    )
    data = 2 * 6373 * np.arctan2(np.sqrt(np.abs(d)), np.sqrt(np.abs(1 - d)))
    return data.reshape(data1.shape[0], data2.shape[0])


def distances_np(data1, data2):
    # data1, data2 are the data arrays with 2 cols and they hold
    # lat., lng. values in those cols respectively
    data1 = np.deg2rad(data1)
    data2 = np.deg2rad(data2)
    lat1 = data1[:, 0]
    lng1 = data1[:, 1]
    lat2 = data2[:, 0]
    lng2 = data2[:, 1]
    diff_lat = lat1[:, None] - lat2
    diff_lng = lng1[:, None] - lng2
    d = (
        np.sin(diff_lat / 2) ** 2
        + np.cos(lat1[:, None]) * np.cos(lat2) * np.sin(diff_lng / 2) ** 2
    )
    data = 2 * 6373 * np.arctan2(np.sqrt(np.abs(d)), np.sqrt(np.abs(1 - d)))
    return data.reshape(data1.shape[0], data2.shape[0])


a = np.random.uniform(-100, 100, size=(5000, 2)).astype(np.float32)
b = np.random.uniform(-100, 100, size=(5000, 2)).astype(np.float32)


def dist_np_test():
    return distances_np(a, b)


def dist_jax_test():
    # enforce eager evaluation
    return distances_jax(a, b).block_until_ready()
"""
# Benchmark inputs: 5000 (latitude, longitude) pairs drawn uniformly from
# [-100, 100) degrees, stored as Float32 to match the float32 arrays on the Python side.
a = Float32.([(rand() - 0.5) * 200 for i in 1:5000, j in 1:2])
b = Float32.([(rand() - 0.5) * 200 for i in 1:5000, j in 1:2])
# Each label is printed without a newline so @btime's timing lands on the same line.
print("distances")
@btime distances($a, $b)
print("distances_bcast")
@btime distances_bcast($a, $b)
print("distances_bcast_simd")
@btime distances_bcast_simd($a, $b)
print("distances_threaded")
@btime distances_threaded($a, $b)
print("distances_threaded_simd")
@btime distances_threaded_simd($a, $b)
print("dist_np_test")
@btime py"dist_np_test"()
print("dist_jax_test")
@btime py"dist_jax_test"()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment