@Findus23
Last active February 22, 2024 10:11
distributed 3D-rfft using JAX
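The gist consists of a SLURM job script, a driver main.py, the sharding module sharded_rfft_general.py, and a small utils.py with timing and plotting helpers. The core idea is that a 3D real FFT factorizes into independent transforms along separate axes, so each stage can run locally on data that is unsharded along the axes it currently transforms. A minimal NumPy sketch of that factorization (illustrative only, not part of the gist):

import numpy as np

x = np.random.default_rng(0).random((8, 8, 8), dtype=np.float32)
step1 = np.fft.rfft(x, axis=2)           # real FFT along z only
step2 = np.fft.fftn(step1, axes=[0, 1])  # complex FFTs along x and y
assert np.allclose(step2, np.fft.rfftn(x))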
#!/bin/bash
#SBATCH --time=01:00:00
#SBATCH --nodes=2 # equal to -N 2
#SBATCH --ntasks-per-node=2
#SBATCH --exclusive
#SBATCH --job-name=jax-fft-test
#SBATCH --gpus=4
#SBATCH --output output/slurm-%j.out
nvidia-smi
source $DATA/venv-jax/bin/activate
cd ~/jax-testing/
#export XLA_PYTHON_CLIENT_PREALLOCATE=false
#export XLA_PYTHON_CLIENT_ALLOCATOR=platform
srun --output "output/slurm-%2j-%2t.out" python -u main.py
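# Untested suggestion (not part of the original script): the sharding logic can also be
# exercised without SLURM or GPUs by faking several devices inside one process, e.g.
#   JAX_PLATFORMS=cpu XLA_FLAGS=--xla_force_host_platform_device_count=4 python -u main.py
# (main.py would then have to skip jax.distributed.initialize() and take the device count
# from jax.device_count() instead of SLURM_GPUS).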

# main.py
import os
from pathlib import Path
import jax
import numpy as np
import scipy
from jax import jit
from jax.experimental import mesh_utils
from jax.experimental.multihost_utils import sync_global_devices
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
import sharded_rfft_general
from utils import Timing, plot_graph
print("jax version", jax.__version__)
num_gpus = int(os.environ.get("SLURM_GPUS"))
# jax.config.update("jax_enable_x64", True)
def host_subset(array: jax.Array | np.ndarray, size: int):
    # slice out the columns (axis 1) owned by this process
    host_id = jax.process_index()
    start = host_id * size // num_gpus
    end = (host_id + 1) * size // num_gpus
    return array[:, start:end]


def print_subset(x):
    print(x[0, :4, :4])


def compare(a, b):
    is_equal = np.allclose(a, b, rtol=1.e-2, atol=1.e-4)
    print(is_equal)
    diff = a - np.asarray(b)
    max_value = np.max(np.real(a))  # was out_ref_subs; use the reference argument instead of a global
    max_diff = np.max(np.abs(diff))
    print("max_value", max_value)
    print("max_diff", max_diff)
    print("max_diff / max_value", max_diff / max_value)
print("distributed initialize")
jax.distributed.initialize()
timing = Timing(print)
print("CUDA_VISIBLE_DEVICES", os.environ.get("CUDA_VISIBLE_DEVICES"))
print("devices:", jax.device_count(), jax.devices())
print("local_devices:", jax.local_device_count(), jax.local_devices())
print("process_index", jax.process_index())
print("total number of GPUs:", num_gpus)
timing.log("random ICs")
size = 512
rng = np.random.default_rng(12345)
x_np_full = rng.random((size, size, size), dtype=np.float32)
x_np = host_subset(x_np_full, size)
print("x_np shape", x_np.shape)
global_shape = (size, size, size)
timing.log("generated")
print(x_np.nbytes / 1024 / 1024 / 1024, "GB")
print(x_np.shape, x_np.dtype)
devices = mesh_utils.create_device_mesh((num_gpus,))
mesh = Mesh(devices, axis_names=('gpus',))
timing.log("start")
with mesh:
    x_single = jax.device_put(x_np)
    xshard = jax.make_array_from_single_device_arrays(
        global_shape,
        NamedSharding(mesh, P(None, "gpus")),
        [x_single])
    rfftn_jit = jit(
        sharded_rfft_general.rfftn,
        donate_argnums=0,  # doesn't help
        in_shardings=(NamedSharding(mesh, P(None, "gpus"))),
        out_shardings=(NamedSharding(mesh, P(None, "gpus")))
    )
    irfftn_jit = jit(
        sharded_rfft_general.irfftn,
        donate_argnums=0,
        in_shardings=(NamedSharding(mesh, P(None, "gpus"))),
        out_shardings=(NamedSharding(mesh, P(None, "gpus")))
    )
    if jax.process_index() == 0:
        with jax.spmd_mode('allow_all'):
            a = Path("compiled.txt")
            a.write_text(rfftn_jit.lower(xshard).compile().as_text())
            z = jax.xla_computation(rfftn_jit)(xshard)
            plot_graph(z)
sync_global_devices("wait for compiler output")
with jax.spmd_mode('allow_all'):
timing.log("warmup")
rfftn_jit(xshard).block_until_ready()
timing.log("calculating")
out_jit: jax.Array = rfftn_jit(xshard).block_until_ready()
print(out_jit.nbytes / 1024 / 1024 / 1024, "GB")
print(out_jit.shape, out_jit.dtype)
timing.log("inverse calculating")
out_inverse: jax.Array = irfftn_jit(out_jit).block_until_ready()
timing.log("collecting")
sync_global_devices("loop")
local_out_subset = out_jit.addressable_data(0)
local_inverse_subset = out_inverse.addressable_data(0)
print(local_out_subset.shape)
print_subset(local_out_subset)
# print("JAX output without JIT:")
# print_subset(out)
# print("JAX output with JIT:")
# # print_subset(out_jit)
# print("out_jit.shape1", out_jit.shape)
# print(out_jit.dtype)
timing.log("done")
out_ref = scipy.fft.rfftn(x_np_full, workers=128)
timing.log("ref done")
print("out_ref", out_ref.shape)
out_ref_subs = host_subset(out_ref, size)
print("out_ref_subs", out_ref_subs.shape)
print("JAX output with JIT:")
print_subset(local_out_subset)
print("Reference output:")
print_subset(out_ref_subs)
print("ref")
compare(out_ref_subs, local_out_subset)
print("inverse")
compare(x_np, local_inverse_subset)
print_subset(x_np)
print_subset(local_inverse_subset)
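        # Hedged addition (not in the original gist): one scale-independent error number
        # for the local shard of the inverse round trip, complementing compare() above.
        rel_err = np.linalg.norm(np.asarray(local_inverse_subset) - x_np) / np.linalg.norm(x_np)
        print("relative L2 round-trip error (local shard):", rel_err)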

# sharded_rfft_general.py
from typing import Callable
import jax
from jax.experimental.custom_partitioning import custom_partitioning
from jax.sharding import PartitionSpec as P, NamedSharding
def fft_partitioner(fft_func: Callable[[jax.Array], jax.Array], partition_spec: P):
    # Wrap fft_func with custom_partitioning: tell the SPMD partitioner that, when the
    # operand is laid out according to partition_spec, fft_func can simply be applied to
    # each shard locally, with the result keeping the same sharding.
    @custom_partitioning
    def func(x):
        return fft_func(x)

    def supported_sharding(sharding, shape):
        return NamedSharding(sharding.mesh, partition_spec)

    def partition(arg_shapes, arg_shardings, result_shape, result_sharding):
        return fft_func, supported_sharding(arg_shardings[0], arg_shapes[0]), (
            supported_sharding(arg_shardings[0], arg_shapes[0]),)

    def infer_sharding_from_operands(arg_shapes, arg_shardings, shape):
        return supported_sharding(arg_shardings[0], arg_shapes[0])

    func.def_partition(
        infer_sharding_from_operands=infer_sharding_from_operands,
        partition=partition
    )
    return func
def _fft_XY(x):
    return jax.numpy.fft.fftn(x, axes=[0, 1])


def _fft_Z(x):
    return jax.numpy.fft.rfft(x, axis=2)


def _ifft_XY(x):
    return jax.numpy.fft.ifftn(x, axes=[0, 1])


def _ifft_Z(x):
    return jax.numpy.fft.irfft(x, axis=2)


fft_XY = fft_partitioner(_fft_XY, P(None, None, "gpus"))
fft_Z = fft_partitioner(_fft_Z, P(None, "gpus"))
ifft_XY = fft_partitioner(_ifft_XY, P(None, None, "gpus"))
ifft_Z = fft_partitioner(_ifft_Z, P(None, "gpus"))


def rfftn(x):
    x = fft_Z(x)
    x = fft_XY(x)
    return x


def irfftn(x):
    x = ifft_XY(x)
    x = ifft_Z(x)
    return x
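
if __name__ == "__main__":
    # Hedged self-test (not part of the original gist, untested): run the transforms on a
    # trivial one-device mesh, mirroring how main.py calls them, and compare against
    # jax.numpy's reference rfftn plus an inverse round trip.
    import numpy as np
    from jax.experimental import mesh_utils
    from jax.sharding import Mesh

    mesh = Mesh(mesh_utils.create_device_mesh((1,)), axis_names=("gpus",))
    sharding = NamedSharding(mesh, P(None, "gpus"))
    x = np.random.default_rng(0).random((8, 8, 8), dtype=np.float32)
    with mesh:
        out = jax.jit(rfftn, in_shardings=sharding, out_shardings=sharding)(x)
        back = jax.jit(irfftn, in_shardings=sharding, out_shardings=sharding)(out)
    print("forward matches jax.numpy.fft.rfftn:", np.allclose(out, jax.numpy.fft.rfftn(x), rtol=1e-3, atol=1e-5))
    print("inverse round trip:", np.allclose(back, x, rtol=1e-3, atol=1e-5))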

# utils.py
import subprocess
import time
def plot_graph(z):
    with open("t.dot", "w") as f:
        f.write(z.as_hlo_dot_graph())
    with open("t.png", "wb") as f:
        subprocess.run(["dot", "t.dot", "-Tpng"], stdout=f)


class Timing:
    def __init__(self, print_func):
        self.start = time.perf_counter()
        self.last = self.start
        self.print_func = print_func

    def log(self, message: str) -> None:
        now = time.perf_counter()
        delta = now - self.start
        self.print_func(f"{delta:.4f} / {now - self.last:.4f}: {message}")
        self.last = now
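
if __name__ == "__main__":
    # Tiny usage illustration (not part of the original gist): each log() call prints
    # "seconds since start / seconds since previous log: message".
    t = Timing(print)
    time.sleep(0.1)
    t.log("slept for 0.1 s")
    time.sleep(0.2)
    t.log("slept for another 0.2 s")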
Comment from @Findus23 (author):
Oh, I thought I updated this snippet after that change, but apparently I didn't.

I will do that in the next week.
