@Findus23
Last active February 22, 2024 10:11
distributed 3D-rfft using JAX
#!/bin/bash
#SBATCH --time=01:00:00
#SBATCH --nodes=2 # equal to -N 2
#SBATCH --tasks-per-node=2
#SBATCH --exclusive
#SBATCH --job-name=jax-fft-test
#SBATCH --gpus=4
#SBATCH --output output/slurm-%j.out
nvidia-smi
source $DATA/venv-jax/bin/activate
cd ~/jax-testing/
#export XLA_PYTHON_CLIENT_PREALLOCATE=false
#export XLA_PYTHON_CLIENT_ALLOCATOR=platform
srun --output "output/slurm-%2j-%2t.out" python -u main.py
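
main.py
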
import os
from pathlib import Path
import jax
import numpy as np
import scipy
from jax import jit
from jax.experimental import mesh_utils
from jax.experimental.multihost_utils import sync_global_devices
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
import sharded_rfft_general
from utils import Timing, plot_graph
print("jax version", jax.__version__)
num_gpus = int(os.environ.get("SLURM_GPUS"))
# jax.config.update("jax_enable_x64", True)
def host_subset(array: jax.Array | np.ndarray, size: int):
    host_id = jax.process_index()
    start = host_id * size // num_gpus
    end = (host_id + 1) * size // num_gpus
    return array[:, start:end]


def print_subset(x):
    print(x[0, :4, :4])


def compare(a, b):
    is_equal = np.allclose(a, b, rtol=1.e-2, atol=1.e-4)
    print(is_equal)
    diff = a - np.asarray(b)
    # scale the error by the largest value in the reference array passed in
    max_value = np.max(np.real(a))
    max_diff = np.max(np.abs(diff))
    print("max_value", max_value)
    print("max_diff", max_diff)
    print("max_diff / max_value", max_diff / max_value)
print("distributed initialize")
jax.distributed.initialize()
timing = Timing(print)
print("CUDA_VISIBLE_DEVICES", os.environ.get("CUDA_VISIBLE_DEVICES"))
print("devices:", jax.device_count(), jax.devices())
print("local_devices:", jax.local_device_count(), jax.local_devices())
print("process_index", jax.process_index())
print("total number of GPUs:", num_gpus)
timing.log("random ICs")
size = 512
rng = np.random.default_rng(12345)
x_np_full = rng.random((size, size, size), dtype=np.float32)
x_np = host_subset(x_np_full, size)
print("x_np shape", x_np.shape)
global_shape = (size, size, size)
timing.log("generated")
print(x_np.nbytes / 1024 / 1024 / 1024, "GB")
print(x_np.shape, x_np.dtype)
devices = mesh_utils.create_device_mesh((num_gpus,))
mesh = Mesh(devices, axis_names=('gpus',))
timing.log("start")
with mesh:
    x_single = jax.device_put(x_np)
    xshard = jax.make_array_from_single_device_arrays(
        global_shape,
        NamedSharding(mesh, P(None, "gpus")),
        [x_single])

    rfftn_jit = jit(
        sharded_rfft_general.rfftn,
        donate_argnums=0,  # doesn't help
        in_shardings=(NamedSharding(mesh, P(None, "gpus"))),
        out_shardings=(NamedSharding(mesh, P(None, "gpus")))
    )
    irfftn_jit = jit(
        sharded_rfft_general.irfftn,
        donate_argnums=0,
        in_shardings=(NamedSharding(mesh, P(None, "gpus"))),
        out_shardings=(NamedSharding(mesh, P(None, "gpus")))
    )

    if jax.process_index() == 0:
        with jax.spmd_mode('allow_all'):
            a = Path("compiled.txt")
            a.write_text(rfftn_jit.lower(xshard).compile().as_text())
            z = jax.xla_computation(rfftn_jit)(xshard)
            plot_graph(z)
    sync_global_devices("wait for compiler output")

    with jax.spmd_mode('allow_all'):
        timing.log("warmup")
        rfftn_jit(xshard).block_until_ready()

        timing.log("calculating")
        out_jit: jax.Array = rfftn_jit(xshard).block_until_ready()
        print(out_jit.nbytes / 1024 / 1024 / 1024, "GB")
        print(out_jit.shape, out_jit.dtype)

        timing.log("inverse calculating")
        out_inverse: jax.Array = irfftn_jit(out_jit).block_until_ready()

        timing.log("collecting")
        sync_global_devices("loop")
        local_out_subset = out_jit.addressable_data(0)
        local_inverse_subset = out_inverse.addressable_data(0)
        print(local_out_subset.shape)
        print_subset(local_out_subset)

        # print("JAX output without JIT:")
        # print_subset(out)
        # print("JAX output with JIT:")
        # # print_subset(out_jit)
        # print("out_jit.shape1", out_jit.shape)
        # print(out_jit.dtype)

        timing.log("done")

        out_ref = scipy.fft.rfftn(x_np_full, workers=128)
        timing.log("ref done")
        print("out_ref", out_ref.shape)
        out_ref_subs = host_subset(out_ref, size)
        print("out_ref_subs", out_ref_subs.shape)

        print("JAX output with JIT:")
        print_subset(local_out_subset)
        print("Reference output:")
        print_subset(out_ref_subs)

        print("ref")
        compare(out_ref_subs, local_out_subset)
        print("inverse")
        compare(x_np, local_inverse_subset)
        print_subset(x_np)
        print_subset(local_inverse_subset)
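
sharded_rfft_general.py
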
from typing import Callable
import jax
from jax.experimental.custom_partitioning import custom_partitioning
from jax.sharding import PartitionSpec as P, NamedSharding
def fft_partitioner(fft_func: Callable[[jax.Array], jax.Array], partition_spec: P):
    @custom_partitioning
    def func(x):
        return fft_func(x)

    def supported_sharding(sharding, shape):
        return NamedSharding(sharding.mesh, partition_spec)

    def partition(arg_shapes, arg_shardings, result_shape, result_sharding):
        return fft_func, supported_sharding(arg_shardings[0], arg_shapes[0]), (
            supported_sharding(arg_shardings[0], arg_shapes[0]),)

    def infer_sharding_from_operands(arg_shapes, arg_shardings, shape):
        return supported_sharding(arg_shardings[0], arg_shapes[0])

    func.def_partition(
        infer_sharding_from_operands=infer_sharding_from_operands,
        partition=partition
    )
    return func


def _fft_XY(x):
    return jax.numpy.fft.fftn(x, axes=[0, 1])


def _fft_Z(x):
    return jax.numpy.fft.rfft(x, axis=2)


def _ifft_XY(x):
    return jax.numpy.fft.ifftn(x, axes=[0, 1])


def _ifft_Z(x):
    return jax.numpy.fft.irfft(x, axis=2)


fft_XY = fft_partitioner(_fft_XY, P(None, None, "gpus"))
fft_Z = fft_partitioner(_fft_Z, P(None, "gpus"))
ifft_XY = fft_partitioner(_ifft_XY, P(None, None, "gpus"))
ifft_Z = fft_partitioner(_ifft_Z, P(None, "gpus"))


def rfftn(x):
    x = fft_Z(x)
    x = fft_XY(x)
    return x


def irfftn(x):
    x = ifft_XY(x)
    x = ifft_Z(x)
    return x
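
utils.py
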
import subprocess
import time
def plot_graph(z):
    with open("t.dot", "w") as f:
        f.write(z.as_hlo_dot_graph())
    with open("t.png", "wb") as f:
        subprocess.run(["dot", "t.dot", "-Tpng"], stdout=f)


class Timing:
    def __init__(self, print_func):
        self.start = time.perf_counter()
        self.last = self.start
        self.print_func = print_func

    def log(self, message: str) -> None:
        now = time.perf_counter()
        delta = now - self.start
        self.print_func(f"{delta:.4f} / {now - self.last:.4f}: {message}")
        self.last = now
@Vandermode

Hi @Findus23, thank you very much for the great work!

Could you please tell me which jax version you used to run this code?
When running it with the latest version of jax (0.4.24), I got the following runtime error and have no clue how to resolve it.

custom_partitioner: TypeError: 'Mesh' object is not subscriptable

Thanks in advance!

@Findus23
Author

Hi,

I should have documented this, but I used the latest version back then, so probably 0.4.19.

I have also gotten this error a few times in the past and don't fully understand it. I can't test this code right now, but I will report back on whether it still works like this with the latest version.

@Vandermode

Hi, thanks for the reply!

I have resolved the problem. The bug is triggered because the custom_partitioning signature has changed, so the arg_shapes argument of partition gets filled with the Mesh object, which is obviously incorrect.
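
For reference, a minimal, untested sketch of how fft_partitioner could be adapted to the newer custom_partitioning signature, where partition and infer_sharding_from_operands receive the mesh as the first argument, read the shardings from the ShapeDtypeStruct arguments, and partition also returns the mesh; the exact signature should be checked against the JAX version in use:

from typing import Callable

import jax
from jax.experimental.custom_partitioning import custom_partitioning
from jax.sharding import PartitionSpec as P, NamedSharding


def fft_partitioner(fft_func: Callable[[jax.Array], jax.Array], partition_spec: P):
    @custom_partitioning
    def func(x):
        return fft_func(x)

    def supported_sharding(sharding, shape):
        return NamedSharding(sharding.mesh, partition_spec)

    # Newer API: the mesh comes first, shardings live on the ShapeDtypeStructs,
    # and partition returns (mesh, lower_fn, out_shardings, arg_shardings).
    def partition(mesh, arg_shapes, result_shape):
        arg_shardings = jax.tree_util.tree_map(lambda s: s.sharding, arg_shapes)
        return (mesh, fft_func,
                supported_sharding(arg_shardings[0], arg_shapes[0]),
                (supported_sharding(arg_shardings[0], arg_shapes[0]),))

    def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
        arg_shardings = jax.tree_util.tree_map(lambda s: s.sharding, arg_shapes)
        return supported_sharding(arg_shardings[0], arg_shapes[0])

    func.def_partition(
        infer_sharding_from_operands=infer_sharding_from_operands,
        partition=partition
    )
    return func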

@Findus23
Author

Oh, I thought I updated this snippet after that change, but apparently I didn't.

I will do that in the next week.
