Last active
June 23, 2020 19:28
-
-
Save PhilipVinc/3574a8adf88a4cff13a7c154afac9f1e to your computer and use it in GitHub Desktop.
jax and mpi. Interpret underscores in file names as slashes (subpaths)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from . import mpi_xla_bridge | |
from jax.lib import xla_client | |
# Hand every capsule collected by the Cython bridge to XLA so that CPU
# custom calls can look the kernels up by name.
for target_name, capsule in mpi_xla_bridge.cpu_custom_call_targets.items():
    xla_client.register_cpu_custom_call_target(target_name, capsule)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# cython: language_level=2 | |
# distutils: language = c++ | |
cimport mpi4py.MPI as MPI | |
cimport mpi4py.libmpi as libmpi | |
from cpython.pycapsule cimport PyCapsule_New | |
from libc.stdio cimport printf | |
from libc.stdint cimport int32_t, int64_t | |
cdef void sum_inplace_mpi_f32(void* out_ptr, void** data_ptr) nogil:
    # Custom-call kernel: element-wise sum of a float32 buffer over all ranks
    # of MPI_COMM_WORLD.
    #
    #   out_ptr     -- buffer that receives the reduced values
    #   data_ptr[0] -- points to an int32 holding the number of elements
    #   data_ptr[1] -- points to the local float32 input buffer
    cdef int32_t count = (<int32_t*> data_ptr[0])[0]
    cdef float* send_buf = <float*> data_ptr[1]
    cdef float* recv_buf = <float*> out_ptr

    libmpi.MPI_Allreduce(
        send_buf, recv_buf, count,
        libmpi.MPI_FLOAT, libmpi.MPI_SUM, libmpi.MPI_COMM_WORLD)
cdef void sum_inplace_mpi_f64(void* out_ptr, void** data_ptr) nogil:
    # Custom-call kernel: element-wise sum of a float64 buffer over all ranks
    # of MPI_COMM_WORLD.
    #
    #   out_ptr     -- buffer that receives the reduced values
    #   data_ptr[0] -- points to an int32 holding the number of elements
    #   data_ptr[1] -- points to the local float64 input buffer
    cdef int32_t nitems = (<int32_t*>(data_ptr[0]))[0]
    cdef double* x = <double*>(data_ptr[1])
    cdef double* out = <double*>(out_ptr)

    # No manual copy of x[0] into out[0]: MPI_Allreduce fills all `nitems`
    # entries of `out`, and touching element 0 is invalid for an empty buffer.
    libmpi.MPI_Allreduce(
        x, out, nitems,
        libmpi.MPI_DOUBLE, libmpi.MPI_SUM, libmpi.MPI_COMM_WORLD)
cdef void sum_inplace_mpi_c64(void* out_ptr, void** data_ptr) nogil:
    # Custom-call kernel: element-wise sum of a complex64 buffer over all
    # ranks of MPI_COMM_WORLD.
    #
    #   out_ptr     -- buffer that receives the reduced values
    #   data_ptr[0] -- points to an int32 holding the number of complex elements
    #   data_ptr[1] -- points to the local complex64 input buffer
    cdef int32_t nitems = (<int32_t*>(data_ptr[0]))[0]
    cdef float complex* x = <float complex*>(data_ptr[1])
    cdef float complex* out = <float complex*>(out_ptr)

    # The buffer is reduced as 2 * nitems plain floats, i.e. real and
    # imaginary parts are summed independently.
    # No manual copy of x[0] into out[0]: MPI_Allreduce fills the whole
    # output, and touching element 0 is invalid for an empty buffer.
    libmpi.MPI_Allreduce(
        x, out, nitems * 2,
        libmpi.MPI_FLOAT, libmpi.MPI_SUM, libmpi.MPI_COMM_WORLD)
cdef void sum_inplace_mpi_c128(void* out_ptr, void** data_ptr) nogil:
    # Custom-call kernel: element-wise sum of a complex128 buffer over all
    # ranks of MPI_COMM_WORLD.
    #
    #   out_ptr     -- buffer that receives the reduced values
    #   data_ptr[0] -- points to an int32 holding the number of complex elements
    #   data_ptr[1] -- points to the local complex128 input buffer
    cdef int32_t nitems = (<int32_t*>(data_ptr[0]))[0]
    cdef double complex* x = <double complex*>(data_ptr[1])
    cdef double complex* out = <double complex*>(out_ptr)

    # The buffer is reduced as 2 * nitems plain doubles, i.e. real and
    # imaginary parts are summed independently.
    # No manual copy of x[0] into out[0]: MPI_Allreduce fills the whole
    # output, and touching element 0 is invalid for an empty buffer.
    libmpi.MPI_Allreduce(
        x, out, nitems * 2,
        libmpi.MPI_DOUBLE, libmpi.MPI_SUM, libmpi.MPI_COMM_WORLD)
# Maps kernel name (bytes) -> PyCapsule wrapping the kernel's function pointer.
cpu_custom_call_targets = {}


cdef register_custom_call_target(fn_name, void* fn):
    # Wrap the raw function pointer `fn` in a capsule tagged
    # "xla._CUSTOM_CALL_TARGET" and store it under `fn_name`.
    cdef const char* capsule_tag = "xla._CUSTOM_CALL_TARGET"
    capsule = PyCapsule_New(fn, capsule_tag, NULL)
    cpu_custom_call_targets[fn_name] = capsule
# Populate cpu_custom_call_targets with one capsule per reduction kernel.
register_custom_call_target(b"sum_inplace_mpi_f32", <void*> sum_inplace_mpi_f32)
register_custom_call_target(b"sum_inplace_mpi_f64", <void*> sum_inplace_mpi_f64)
register_custom_call_target(b"sum_inplace_mpi_c64", <void*> sum_inplace_mpi_c64)
register_custom_call_target(b"sum_inplace_mpi_c128", <void*> sum_inplace_mpi_c128)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as _np | |
import jax | |
@sum_inplace.register(jax.interpreters.xla.DeviceArray)
def sum_inplace_jax(x):
    """Sum the contents of the jax array ``x`` across MPI ranks, in place.

    The device buffer backing ``x`` is exposed to MPI as a numpy array and
    reduced with ``Allreduce(IN_PLACE, ..., op=SUM)``, so ``x`` itself holds
    the result afterwards. When ``_n_nodes == 1`` nothing is done.

    Args:
        x: the array to reduce (dispatch is registered for ``DeviceArray``).

    Returns:
        The same object ``x``.
    """
    if _n_nodes != 1:
        # This only works on cpus; it should be made to work for gpus too.
        # TODO: unsafe_buffer_pointer is considered not yet definitive interface
        #
        # Reading the pointer directly is the fast path. A slower alternative
        # that should work more often, but might copy, is
        #     _x = jax.xla._force(x.block_until_ready())
        #     ptr = _x.device_buffer.unsafe_buffer_pointer()
        # Depending on future changes in jaxlib we might have to switch to it;
        # see Google/jax #2123 and #1009.
        buffer_address = x.block_until_ready().device_buffer.unsafe_buffer_pointer()

        # Native numpy is used because jax's numpy does not have ctypeslib.
        ndpointer_type = _np.ctypeslib.ndpointer(x.dtype, shape=x.shape)

        # Wrap the jax data into a standard numpy array which MPI can handle.
        host_view = ndpointer_type(buffer_address).contents
        _MPI_comm.Allreduce(_MPI.IN_PLACE, host_view.reshape(-1), op=_MPI.SUM)

    return x
from jax import core | |
from jax import abstract_arrays | |
from jax.lib import xla_client | |
from jax.interpreters import xla | |
_ops = xla_client.ops | |
## The underlying jax primitive
sum_inplace_p = core.Primitive("sum_inplace_mpi")


def sum_inplace_jax_primitive(x):
    """Apply the ``sum_inplace_mpi`` primitive to ``x`` (e.g. a traced value in an AST)."""
    bound = sum_inplace_p.bind(x)
    return bound


# Implementation that executes the primitive when it is not under any
# transformation.
sum_inplace_p.def_impl(sum_inplace_jax)
# def sum_inplace_impl(x): | |
# return sum_inplace_jax(x) | |
# sum_inplace_p.def_impl(sum_inplace_impl) | |
def sum_inplace_abstract_eval(xs):
    """Abstract evaluation rule used during AST construction.

    Only shape and dtype are evaluated: the result is a ``ShapedArray`` with
    the same shape and dtype as the input ``xs``.
    """
    out_shape = xs.shape
    out_dtype = xs.dtype
    return abstract_arrays.ShapedArray(out_shape, out_dtype)


sum_inplace_p.def_abstract_eval(sum_inplace_abstract_eval)
# Helper functions
def _constant_s32_scalar(c, x):
    """Emit ``x`` as an int32 constant in the XLA builder ``c``."""
    as_int32 = _np.int32(x)
    return _ops.Constant(c, as_int32)
def _unpack_builder(c):
    """Return ``c._builder`` when that attribute exists, otherwise ``c`` itself.

    If ``c`` is a ComputationBuilder object this extracts the underlying
    XlaBuilder; anything without a ``_builder`` attribute is passed through.
    """
    try:
        return c._builder
    except AttributeError:
        return c
# This function compiles the operation
def sum_inplace_xla_encode(c, x):
    """Translate the ``sum_inplace_mpi`` primitive into an XLA CustomCall.

    The call targets one of the ``sum_inplace_mpi_*`` kernels, selected by
    the element type of ``x``. The kernel receives two operands: the total
    number of elements as an int32 constant, followed by ``x``. The output
    has the same element type and dimensions as ``x``.

    Args:
        c: XLA builder, or a ComputationBuilder wrapping one.
        x: the XLA operand to reduce.

    Returns:
        The XLA op produced by ``CustomCall``.

    Raises:
        TypeError: if the element type of ``x`` is not float32, float64,
            complex64 or complex128.
    """
    c = _unpack_builder(c)
    x_shape = c.GetShape(x)
    dtype = x_shape.element_type()
    dims = x_shape.dimensions()

    # Total number of elements in the array. Starting from 1 makes a scalar
    # operand (dims == ()) count as one element instead of failing on dims[0].
    nitems = 1
    for extent in dims:
        nitems *= extent

    # those kernels have been loaded through cython.
    if dtype == _np.float32:
        kernel = b"sum_inplace_mpi_f32"
    elif dtype == _np.float64:
        kernel = b"sum_inplace_mpi_f64"
    elif dtype == _np.complex64:
        kernel = b"sum_inplace_mpi_c64"
    elif dtype == _np.complex128:
        kernel = b"sum_inplace_mpi_c128"
    else:
        # Fail with a clear message instead of an UnboundLocalError on `kernel`.
        raise TypeError(
            "sum_inplace_mpi does not support dtype {}; supported dtypes are "
            "float32, float64, complex64 and complex128".format(dtype)
        )

    return _ops.CustomCall(
        c,
        kernel,
        operands=(_constant_s32_scalar(c, nitems), x),
        shape=xla_client.Shape.array_shape(dtype, dims),
    )


# assign to the primitive the correct encoder
xla.backend_specific_translations["cpu"][sum_inplace_p] = sum_inplace_xla_encode
@sum_inplace.register(jax.interpreters.partial_eval.JaxprTracer)
@sum_inplace.register(jax.interpreters.ad.JVPTracer)
def sum_inplace_jax_jittracer(x):
    """Tracer overload of ``sum_inplace``.

    Returns ``x`` untouched when ``_n_nodes == 1``; otherwise routes ``x``
    through the ``sum_inplace_mpi`` jax primitive.
    """
    return x if _n_nodes == 1 else sum_inplace_jax_primitive(x)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from setuptools import setup | |
from setuptools.extension import Extension | |
import os | |
def mpi_includes():
    """Return the include directories needed to build against MPI and mpi4py.

    The MPI compiler wrapper recorded in mpi4py's build configuration is run
    with ``--showme:compile`` (an Open MPI wrapper option -- other MPI
    distributions may need a different flag). Every ``-I<dir>`` flag in its
    output contributes ``<dir>``; mpi4py's own include directory is appended
    last.

    Returns:
        list[str]: include directories, in the order reported by the wrapper.
    """
    import mpi4py

    config = mpi4py.get_config()
    cmd_compile = " ".join([config["mpicc"], "--showme:compile"])

    # Close the pipe deterministically instead of leaking it.
    with os.popen(cmd_compile) as out_stream:
        compile_flags = out_stream.read().strip()

    # Keep only real include flags: the wrapper may also print options such as
    # ``-pthread``, which must not be turned into bogus directories.
    include_dirs = [
        flag[2:]
        for flag in compile_flags.split()
        if flag.startswith("-I") and len(flag) > 2
    ]
    include_dirs.append(mpi4py.get_include())
    return include_dirs
# NOTE: mpi_includes() is evaluated while the arguments below are being built,
# i.e. before setup() gets to process ``setup_requires``; mpi4py therefore has
# to be importable already when this script runs.
setup(
    name="mypkg",
    packages=["mypkg", "mypkg.cython"],
    ext_modules=[
        Extension(
            # The dotted name must match the package that contains the source
            # file, so the compiled module is installed as
            # ``mypkg.cython.mpi_xla_bridge``.
            name="mypkg.cython.mpi_xla_bridge",
            sources=["mypkg/cython/mpi_xla_bridge.pyx"],
            include_dirs=mpi_includes(),
        ),
    ],
    setup_requires=["setuptools>=18.0", "cython>=0.21", "mpi4py>=3.0.1"],
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment