Skip to content

Instantly share code, notes, and snippets.

@jaro-sevcik
Last active May 9, 2025 13:46
Show Gist options
  • Save jaro-sevcik/993c286f5e6277d29fee404d138fcb12 to your computer and use it in GitHub Desktop.
Save jaro-sevcik/993c286f5e6277d29fee404d138fcb12 to your computer and use it in GitHub Desktop.
import jax, jax.numpy as jnp
from jax.experimental import buffer_callback
import jaxlib
# pip install cupy-cuda12x
import cupy as cp
# Elementwise CUDA kernel compiled by CuPy: z = x * y + f, all float32.
muladd = cp.ElementwiseKernel(
    in_params='float32 x, float32 y, float32 f',
    out_params='float32 z',
    operation='z = x * y + f',
    name='muladd',
)
# Define FFI handler that calls into cupy.
def ffi_muladd(ctx, output, inputs, bias=0.0):
    """Run the CuPy ``muladd`` kernel, writing the result into ``output``.

    Launches ``muladd(inputs[0], inputs[1], bias, output)``, i.e. the kernel
    computes ``inputs[0] * inputs[1] + bias`` elementwise into ``output``.

    Args:
        ctx: Callback context; only ``ctx.stream`` is read here.
            NOTE(review): presumably the execution context JAX hands to
            buffer callbacks -- confirm against the JAX docs.
        output: Buffer the kernel writes its result to.
        inputs: Indexable holding the two kernel operands at positions 0 and 1.
        bias: Scalar added to every product. Defaults to 0.0.
    """
    # Make ctx.stream CuPy's current stream for the duration of the launch so
    # the kernel is enqueued on the caller's stream instead of CuPy's default.
    with cp.cuda.ExternalStream(ctx.stream):
        muladd(inputs[0], inputs[1], bias, output)
@jax.jit
def f(x):
    """Compute ``x * x + 0.5`` elementwise through the ``ffi_muladd`` callback.

    Args:
        x: Input array. The result is declared with the same shape and dtype.
            NOTE(review): the CuPy kernel is declared for float32 operands,
            so x is presumably expected to be float32 -- confirm.

    Returns:
        The value produced by the buffer callback, with ``x.shape`` and
        ``x.dtype``.
    """
    # Describe the output buffer the callback will fill in.
    output_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    # NOTE(review): command_buffer_compatible=True presumably allows the call
    # to be captured in a command buffer (CUDA graph) -- confirm in JAX docs.
    muladd_call = buffer_callback.buffer_callback(
        ffi_muladd,
        output_shape,
        command_buffer_compatible=True,
    )
    # Both operands are x; they are passed as a single tuple so ffi_muladd
    # receives them as its one `inputs` argument.
    return muladd_call((x, x), bias=0.5)
# Demo input. Named `values` rather than `input` so the builtin input() is
# not shadowed at module level.
values = jnp.arange(8.0)
# Call the jitted function repeatedly, then once more; every call prints the
# same result line.
for _ in range(10):
    print(f(values))
print(f(values))
# Result (printed by each call):
# [ 0.5 1.5 4.5 9.5 16.5 25.5 36.5 49.5]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment