Skip to content

Instantly share code, notes, and snippets.

@cgarciae
Last active April 26, 2020 14:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cgarciae/7fcfc95709b1d94b27010c5e00db6690 to your computer and use it in GitHub Desktop.
Save cgarciae/7fcfc95709b1d94b27010c5e00db6690 to your computer and use it in GitHub Desktop.
Python distance functions (pairwise haversine distances, NumPy vs. JAX)
import typing as tp
from jax import numpy as jnp
import jax
import numpy as np
import time
@jax.jit
def _distances_jax(data1, data2):
# data1, data2 are the data arrays with 2 cols and they hold
# lat., lng. values in those cols respectively
np = jnp
data1 = np.deg2rad(data1)
data2 = np.deg2rad(data2)
lat1 = data1[:, 0]
lng1 = data1[:, 1]
lat2 = data2[:, 0]
lng2 = data2[:, 1]
diff_lat = lat1[:, None] - lat2
diff_lng = lng1[:, None] - lng2
d = (
np.sin(diff_lat / 2) ** 2
+ np.cos(lat1[:, None]) * np.cos(lat2) * np.sin(diff_lng / 2) ** 2
)
data = 2 * 6373 * np.arctan2(np.sqrt(np.abs(d)), np.sqrt(np.abs(1 - d)))
return np.array(data.reshape(data1.shape[0], data2.shape[0]))
def distances_jax(data1, data2):
    """Return the JAX-computed pairwise distance matrix as a NumPy array.

    Thin wrapper around the jit-compiled ``_distances_jax``; the result is
    converted with ``np.asarray`` so callers receive a NumPy array rather
    than a JAX array.
    """
    jax_result = _distances_jax(data1, data2)
    return np.asarray(jax_result)
def distances_np(data1, data2):
    """Compute pairwise haversine distances with NumPy.

    Args:
        data1: array of shape (n, 2); column 0 holds latitudes and column 1
            holds longitudes, both in degrees.
        data2: array of shape (m, 2) with the same column layout.

    Returns:
        ndarray of shape (n, m) whose entry [i, j] is the great-circle
        distance between ``data1[i]`` and ``data2[j]`` on a sphere of radius
        6373 (presumably kilometres). Identical points are at distance 0.
    """
    rad1 = np.deg2rad(data1)
    rad2 = np.deg2rad(data2)
    lat1, lng1 = rad1[:, 0], rad1[:, 1]
    lat2, lng2 = rad2[:, 0], rad2[:, 1]
    # Broadcasting an (n, 1) column against an (m,) row gives (n, m) matrices.
    diff_lat = lat1[:, None] - lat2
    diff_lng = lng1[:, None] - lng2
    d = (
        np.sin(diff_lat / 2) ** 2
        + np.cos(lat1[:, None]) * np.cos(lat2) * np.sin(diff_lng / 2) ** 2
    )
    # abs() keeps both sqrt arguments non-negative when rounding (or an
    # out-of-range latitude) pushes d slightly outside [0, 1].
    # No constant offset is added: the result matches distances_jax.
    return 2 * 6373 * np.arctan2(np.sqrt(np.abs(d)), np.sqrt(np.abs(1 - d)))
if __name__ == '__main__':
    # Rough benchmark: NumPy vs. JAX on a 5000 x 5000 pairwise problem.
    n_points = 5000
    # Columns are (lat, lng) in degrees. Sample each column from its valid
    # geographic range: latitudes must lie in [-90, 90].
    low = np.array([-90.0, -180.0])
    high = np.array([90.0, 180.0])
    a = np.random.uniform(low, high, size=(n_points, 2)).astype(np.float32)
    b = np.random.uniform(low, high, size=(n_points, 2)).astype(np.float32)

    # perf_counter is monotonic and high-resolution, unlike time.time.
    t0 = time.perf_counter()
    d = distances_np(a, b)
    print("time np", time.perf_counter() - t0)

    # Presumably forces JAX backend initialisation so it is not timed below.
    jnp.array([1])
    # Warm-up: the first call to the jitted function triggers compilation.
    distances_jax(a, b)
    t0 = time.perf_counter()
    d = distances_jax(a, b)
    print("time jax", time.perf_counter() - t0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment