JAX and MPI (gist by @PhilipVinc). Interpret underscores in the file names as slashes (subpaths).
# Register the Cython-compiled MPI kernels with XLA as CPU custom-call
# targets when this module is imported.
from . import mpi_xla_bridge
from jax.lib import xla_client

for name, fn in mpi_xla_bridge.cpu_custom_call_targets.items():
    xla_client.register_cpu_custom_call_target(name, fn)
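# A quick sanity check one could run after importing this module (hypothetical,
# not part of the original gist): every kernel registered by the Cython
# extension should appear in the target dictionary under its byte name.
#
#   assert b"sum_inplace_mpi_f64" in mpi_xla_bridge.cpu_custom_call_targets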
# cython: language_level=2
# distutils: language = c++

cimport mpi4py.MPI as MPI
cimport mpi4py.libmpi as libmpi

from cpython.pycapsule cimport PyCapsule_New
from libc.stdio cimport printf
from libc.stdint cimport int32_t, int64_t


# Each kernel follows the XLA CPU custom-call convention: the first argument
# points to the output buffer, the second to an array of operand pointers
# (here: [nitems, input array]).
cdef void sum_inplace_mpi_f32(void* out_ptr, void** data_ptr) nogil:
    cdef int32_t nitems = (<int32_t*>(data_ptr[0]))[0]
    cdef float* x = <float*>(data_ptr[1])
    cdef float* out = <float*>(out_ptr)

    libmpi.MPI_Allreduce(x, out, nitems, libmpi.MPI_FLOAT,
                         libmpi.MPI_SUM, libmpi.MPI_COMM_WORLD)


cdef void sum_inplace_mpi_f64(void* out_ptr, void** data_ptr) nogil:
    cdef int32_t nitems = (<int32_t*>(data_ptr[0]))[0]
    cdef double* x = <double*>(data_ptr[1])
    cdef double* out = <double*>(out_ptr)

    out[0] = x[0]
    libmpi.MPI_Allreduce(x, out, nitems, libmpi.MPI_DOUBLE,
                         libmpi.MPI_SUM, libmpi.MPI_COMM_WORLD)


cdef void sum_inplace_mpi_c64(void* out_ptr, void** data_ptr) nogil:
    cdef int32_t nitems = (<int32_t*>(data_ptr[0]))[0]
    cdef float complex* x = <float complex*>(data_ptr[1])
    cdef float complex* out = <float complex*>(out_ptr)

    out[0] = x[0]
    # complex values are reduced as pairs of reals, hence nitems * 2
    libmpi.MPI_Allreduce(x, out, nitems * 2, libmpi.MPI_FLOAT,
                         libmpi.MPI_SUM, libmpi.MPI_COMM_WORLD)


cdef void sum_inplace_mpi_c128(void* out_ptr, void** data_ptr) nogil:
    cdef int32_t nitems = (<int32_t*>(data_ptr[0]))[0]
    cdef double complex* x = <double complex*>(data_ptr[1])
    cdef double complex* out = <double complex*>(out_ptr)

    out[0] = x[0]
    libmpi.MPI_Allreduce(x, out, nitems * 2, libmpi.MPI_DOUBLE,
                         libmpi.MPI_SUM, libmpi.MPI_COMM_WORLD)


cpu_custom_call_targets = {}


cdef register_custom_call_target(fn_name, void* fn):
    cdef const char* name = "xla._CUSTOM_CALL_TARGET"
    cpu_custom_call_targets[fn_name] = PyCapsule_New(fn, name, NULL)


register_custom_call_target(b"sum_inplace_mpi_f32", <void*>(sum_inplace_mpi_f32))
register_custom_call_target(b"sum_inplace_mpi_f64", <void*>(sum_inplace_mpi_f64))
register_custom_call_target(b"sum_inplace_mpi_c64", <void*>(sum_inplace_mpi_c64))
register_custom_call_target(b"sum_inplace_mpi_c128", <void*>(sum_inplace_mpi_c128))
import numpy as _np
import jax

# `sum_inplace` (a functools.singledispatch function), `_n_nodes`, `_MPI_comm`
# and `_MPI` are assumed to be defined elsewhere in the package.


@sum_inplace.register(jax.interpreters.xla.DeviceArray)
def sum_inplace_jax(x):
    # if not isinstance(x, jax.interpreters.xla.DeviceArray):
    #     raise TypeError("Argument to sum_inplace_jax must be a DeviceArray, got {}"
    #                     .format(type(x)))

    if _n_nodes == 1:
        return x

    # This only works on CPUs; it should be made to work on GPUs too.
    # TODO: unsafe_buffer_pointer is not yet considered a definitive interface
    ptr = x.block_until_ready().device_buffer.unsafe_buffer_pointer()

    # The above is faster. The alternative below should work more often, but
    # might copy. Depending on future changes in jaxlib, we might have to
    # switch to it (see google/jax issues #2123 and #1009).
    # _x = jax.xla._force(x.block_until_ready())
    # ptr = _x.device_buffer.unsafe_buffer_pointer()

    # using native numpy because jax's numpy does not have ctypeslib
    data_pointer = _np.ctypeslib.ndpointer(x.dtype, shape=x.shape)

    # wrap the jax buffer into a standard numpy array, which mpi4py can handle
    arr = data_pointer(ptr).contents
    _MPI_comm.Allreduce(_MPI.IN_PLACE, arr.reshape(-1), op=_MPI.SUM)

    return x
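# Hypothetical usage sketch (not in the original gist); `sum_inplace` is the
# singledispatch entry point assumed to be defined elsewhere, and the script
# would be launched under mpirun so that several ranks participate:
#
#   import jax.numpy as jnp
#   x = jnp.ones((2, 3))
#   y = sum_inplace(x)  # every rank now holds the element-wise sum over all ranks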
from jax import core
from jax import abstract_arrays
from jax.lib import xla_client
from jax.interpreters import xla

_ops = xla_client.ops

## The underlying jax primitive
sum_inplace_p = core.Primitive("sum_inplace_mpi")  # Create the primitive


# This function applies the primitive to an AST (a jaxpr under construction)
def sum_inplace_jax_primitive(x):
    return sum_inplace_p.bind(x)


# This executes the primitive when it is not under any transformation
sum_inplace_p.def_impl(sum_inplace_jax)

# def sum_inplace_impl(x):
#     return sum_inplace_jax(x)
# sum_inplace_p.def_impl(sum_inplace_impl)


# This evaluates only the shapes during AST construction
def sum_inplace_abstract_eval(xs):
    return abstract_arrays.ShapedArray(xs.shape, xs.dtype)


sum_inplace_p.def_abstract_eval(sum_inplace_abstract_eval)
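# A minimal tracing check (sketch; the printed jaxpr is only approximate and
# depends on the jax version), showing that binding the primitive goes through
# the abstract eval rule above:
#
#   import jax.numpy as jnp
#   print(jax.make_jaxpr(sum_inplace_jax_primitive)(jnp.zeros((2, 2))))
#   # roughly: { lambda ; a. let b = sum_inplace_mpi a in (b,) }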
# Helper functions
def _constant_s32_scalar(c, x):
    return _ops.Constant(c, _np.int32(x))


def _unpack_builder(c):
    # If `c` is a ComputationBuilder object, extract the underlying XlaBuilder.
    return getattr(c, "_builder", c)


# This function compiles the operation to an XLA CustomCall on the CPU backend
def sum_inplace_xla_encode(c, x):
    c = _unpack_builder(c)
    x_shape = c.GetShape(x)
    dtype = x_shape.element_type()
    dims = x_shape.dimensions()

    # compute the total number of elements in the array
    nitems = dims[0]
    for el in dims[1:]:
        nitems *= el

    # these kernels have been loaded through cython
    if dtype == _np.float32:
        kernel = b"sum_inplace_mpi_f32"
    elif dtype == _np.float64:
        kernel = b"sum_inplace_mpi_f64"
    elif dtype == _np.complex64:
        kernel = b"sum_inplace_mpi_c64"
    elif dtype == _np.complex128:
        kernel = b"sum_inplace_mpi_c128"
    else:
        raise TypeError("Unsupported dtype for sum_inplace_mpi: {}".format(dtype))

    return _ops.CustomCall(
        c,
        kernel,
        operands=(xla_client.ops.Constant(c, _np.int32(nitems)), x),
        shape=xla_client.Shape.array_shape(dtype, dims),
    )


# assign the correct encoder to the primitive
xla.backend_specific_translations["cpu"][sum_inplace_p] = sum_inplace_xla_encode
@sum_inplace.register(jax.interpreters.partial_eval.JaxprTracer)
@sum_inplace.register(jax.interpreters.ad.JVPTracer)
def sum_inplace_jax_jittracer(x):
    if _n_nodes == 1:
        return x
    else:
        return sum_inplace_jax_primitive(x)
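# Hypothetical end-to-end sketch (assumes the Cython extension is built, MPI is
# initialised, and the script runs under mpirun with more than one rank):
#
#   import jax
#   import jax.numpy as jnp
#
#   @jax.jit
#   def global_sum(x):
#       return sum_inplace_jax_primitive(x)
#
#   y = global_sum(jnp.arange(4.0))  # lowers to the sum_inplace_mpi_f32 CustomCall on CPU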
from setuptools import setup
from setuptools.extension import Extension
import os


def mpi_includes():
    import mpi4py

    config = mpi4py.get_config()
    cmd_compile = " ".join([config["mpicc"], "--showme:compile"])
    out_stream = os.popen(cmd_compile)
    compile_flags = out_stream.read().strip()
    # strip the leading "-I" from every include flag reported by mpicc
    include_dirs = [p[2:] for p in compile_flags.split()]
    include_dirs.append(mpi4py.get_include())
    return include_dirs


setup(
    name="mypkg",
    packages=["mypkg", "mypkg.cython"],
    ext_modules=[
        Extension(
            name="mypkg.cython.mpi_xla_bridge",
            sources=["mypkg/cython/mpi_xla_bridge.pyx"],
            include_dirs=mpi_includes(),
        ),
    ],
    setup_requires=["setuptools>=18.0", "cython>=0.21", "mpi4py>=3.0.1"],
)
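# To build the Cython extension in place one would typically run something like
# (assuming Cython, mpi4py and an MPI compiler wrapper are available):
#
#   python setup.py build_ext --inplace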