Last active
August 20, 2023 21:11
-
-
Save carlthome/a8c7571a3cc87808a91d23ebe5cbc45b to your computer and use it in GitHub Desktop.
Zero-copy data sharing between JAX and TensorFlow via DLPack
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"source": [ | |
"import numpy as np\n", | |
"import tensorflow as tf\n", | |
"import jax.dlpack\n", | |
"\n", | |
"# Export a TensorFlow tensor as a DLPack capsule and import it into JAX.\n", | |
"tf_arr = tf.random.uniform(shape=(10,))\n", | |
"dl_arr = tf.experimental.dlpack.to_dlpack(tf_arr)\n", | |
"jax_arr = jax.dlpack.from_dlpack(dl_arr)\n", | |
"\n", | |
"# The JAX array must hold exactly the values of the TensorFlow tensor.\n", | |
"np.testing.assert_array_equal(tf_arr, jax_arr)" | |
], | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" | |
] | |
} | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"source": [ | |
"import jax.numpy as jnp" | |
], | |
"outputs": [], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"source": [ | |
"def tf_to_jax(arr):\n", | |
"    \"\"\"Convert the TensorFlow tensor `arr` to a JAX array via a DLPack capsule.\"\"\"\n", | |
"    # Convert the argument itself, never a notebook-level global, so the\n", | |
"    # result does not depend on what the caller's variable is named.\n", | |
"    return jax.dlpack.from_dlpack(tf.experimental.dlpack.to_dlpack(arr))\n", | |
"\n", | |
"\n", | |
"def jax_to_tf(arr):\n", | |
"    \"\"\"Convert the JAX array `arr` to a TensorFlow tensor via a DLPack capsule.\"\"\"\n", | |
"    return tf.experimental.dlpack.from_dlpack(jax.dlpack.to_dlpack(arr))\n", | |
"\n", | |
"\n", | |
"# Round trip: JAX -> TensorFlow -> JAX.\n", | |
"jax_arr = jnp.arange(20.)\n", | |
"tf_arr = jax_to_tf(jax_arr)\n", | |
"jax_arr2 = tf_to_jax(tf_arr)" | |
], | |
"outputs": [], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"source": [ | |
"jnp.all(jax_arr == jax_arr2)" | |
], | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"DeviceArray(True, dtype=bool)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 4 | |
} | |
], | |
"metadata": {} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"source": [ | |
"jax_arr.unsafe_buffer_pointer() == jax_arr2.unsafe_buffer_pointer()" | |
], | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"True" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 5 | |
} | |
], | |
"metadata": {} | |
} | |
], | |
"metadata": { | |
"language_info": { | |
"name": "python", | |
"version": "3.9.5", | |
"mimetype": "text/x-python", | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"pygments_lexer": "ipython3", | |
"nbconvert_exporter": "python", | |
"file_extension": ".py" | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3.9.5 64-bit ('tf': conda)" | |
}, | |
"interpreter": { | |
"hash": "0b497c359b89e58986a3a03ab9378dba9a3d04bd39a71c4238cdd4a97bc8ad5f" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment