Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save carlthome/a8c7571a3cc87808a91d23ebe5cbc45b to your computer and use it in GitHub Desktop.
Save carlthome/a8c7571a3cc87808a91d23ebe5cbc45b to your computer and use it in GitHub Desktop.
Zero-copy data sharing between JAX and TensorFlow via DLPack
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"import jax.dlpack\n",
"\n",
"# Move a random TF tensor into JAX through a DLPack capsule (capsule consumed inline).\n",
"tf_arr = tf.random.uniform(shape=(10,))\n",
"jax_arr = jax.dlpack.from_dlpack(tf.experimental.dlpack.to_dlpack(tf_arr))\n",
"\n",
"# The JAX view must hold exactly the same values as the TF source.\n",
"np.testing.assert_array_equal(tf_arr, jax_arr)"
],
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
]
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 2,
"source": [
"import jax.numpy as jnp"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 3,
"source": [
"def tf_to_jax(arr):\n",
"    \"\"\"Convert a TensorFlow tensor to a JAX array by exchanging a DLPack capsule.\n",
"\n",
"    Uses the ``arr`` argument (not any global) so the helper converts whatever it is given.\n",
"    \"\"\"\n",
"    return jax.dlpack.from_dlpack(tf.experimental.dlpack.to_dlpack(arr))\n",
"\n",
"def jax_to_tf(arr):\n",
"    \"\"\"Convert a JAX array to a TensorFlow tensor by exchanging a DLPack capsule.\"\"\"\n",
"    return tf.experimental.dlpack.from_dlpack(jax.dlpack.to_dlpack(arr))\n",
"\n",
"# Round trip JAX -> TF -> JAX; later cells compare values and buffer pointers.\n",
"jax_arr = jnp.arange(20.)\n",
"tf_arr = jax_to_tf(jax_arr)\n",
"jax_arr2 = tf_to_jax(tf_arr)"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 4,
"source": [
"# Values must survive the JAX -> TF -> JAX round trip unchanged.\n",
"(jax_arr == jax_arr2).all()"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DeviceArray(True, dtype=bool)"
]
},
"metadata": {},
"execution_count": 4
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 5,
"source": [
"# Equal buffer addresses mean both JAX arrays are backed by the same memory.\n",
"src_ptr = jax_arr.unsafe_buffer_pointer()\n",
"roundtrip_ptr = jax_arr2.unsafe_buffer_pointer()\n",
"src_ptr == roundtrip_ptr"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"True"
]
},
"metadata": {},
"execution_count": 5
}
],
"metadata": {}
}
],
"metadata": {
"language_info": {
"name": "python",
"version": "3.9.5",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3.9.5 64-bit ('tf': conda)"
},
"interpreter": {
"hash": "0b497c359b89e58986a3a03ab9378dba9a3d04bd39a71c4238cdd4a97bc8ad5f"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment