Skip to content

Instantly share code, notes, and snippets.

@smsharma
Last active April 15, 2023 13:11
Show Gist options
  • Save smsharma/9b17db5b2635973539d58129383c5f1f to your computer and use it in GitHub Desktop.
Save smsharma/9b17db5b2635973539d58129383c5f1f to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "de1c7418-b1b2-4601-bccd-a7c6ace95481",
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as np\n",
"\n",
"import numpy as onp\n",
"\n",
"from scipy.interpolate import interp1d"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "b9d65511-b151-48c4-a873-89ec7186ab32",
"metadata": {},
"outputs": [],
"source": [
"def interp1d_jax(x, y, kind='linear', assume_sorted=False):\n",
" if kind != 'linear':\n",
" raise NotImplementedError('Only linear interpolation is supported.')\n",
"\n",
" if not assume_sorted:\n",
" sorted_indices = np.argsort(x)\n",
" x = x[sorted_indices]\n",
" y = y[sorted_indices]\n",
"\n",
" def interpolate(x_new):\n",
" if np.ndim(x_new) != 1:\n",
" raise ValueError(\"x_new should be a 1D array\")\n",
"\n",
" x_min = x[0]\n",
" x_max = x[-1]\n",
"\n",
" out_of_bounds = (x_new < x_min) | (x_new > x_max)\n",
"\n",
" def in_bounds(x_new):\n",
" indices = np.searchsorted(x, x_new, side='right') - 1\n",
" indices = np.clip(indices, 0, x.shape[0] - 2)\n",
"\n",
" x0, x1 = x[indices], x[indices + 1]\n",
" y0, y1 = y[indices], y[indices + 1]\n",
"\n",
" t = (x_new - x0) / (x1 - x0)\n",
" return (1 - t) * y0 + t * y1\n",
"\n",
" def out_of_bounds_func(x_new):\n",
" return np.zeros_like(x_new)\n",
"\n",
" return np.where(out_of_bounds, out_of_bounds_func(x_new), in_bounds(x_new))\n",
"\n",
" return interpolate"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "1f119e54-a929-49cf-814b-ad22ee016475",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Array([1.5000001, 1.5551683, 1.6164234, 1.6844373, 1.7599555, 1.8438063,\n",
" 1.9369088, 2.0402837, 2.155065 , 2.2825105, 2.4240181, 2.581139 ,\n",
" 2.755596 , 2.9493022, 3.1643806, 3.4031904, 3.6683497, 3.9627657,\n",
" 4.2896667, 4.652636 , 5.055655 , 5.503141 , 6. ], dtype=float32)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = np.linspace(0, 10, 100)\n",
"y = np.linspace(1, 6, 100)\n",
"\n",
"x_test = np.logspace(0, 1, 23)\n",
"\n",
"# Test Jax version\n",
"interp1d_jax(x, y)(x_test)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "aeebab8d-c88b-4894-9b42-28204349364d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Array(True, dtype=bool)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Make sure it matches scipy version\n",
"np.allclose(interp1d_jax(x, y)(x_test), interp1d(x, y)(x_test))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "4e84779b-9f2a-404b-8e5c-743ed0637a57",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Array(True, dtype=bool)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Make sure jit works\n",
"np.allclose(jax.jit(interp1d_jax(x, y))(x_test), interp1d_jax(x, y)(x_test))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a8a6a7f-bd47-45b0-9696-ba33003e6f32",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment