Skip to content

Instantly share code, notes, and snippets.

@ckm3
Last active July 3, 2024 14:32
Show Gist options
  • Save ckm3/ea29d85ad78f5e0fcd17dca9216e491f to your computer and use it in GitHub Desktop.
Save ckm3/ea29d85ad78f5e0fcd17dca9216e491f to your computer and use it in GitHub Desktop.
A simple GPU version of the box least squares (BLS) algorithm with numba
from numba import cuda
import math
import numpy as np
import cupy as cp
from cupyx import optimizing
@cuda.jit(cache=True)
def calculate_part1(t, ivar, period, duration, t0, ivar_in, ivar_out):
    """Accumulate the in-box and out-of-box weight sums for every trial.

    Intended for a 2-D launch: the x axis of the grid indexes trials
    (entries of ``period``, ``duration`` and ``t0``) and the y axis indexes
    samples (entries of ``t`` and ``ivar``). Each thread handles one
    (trial, sample) pair.

    Args:
        t: 1-D array of sample times.
        ivar: 1-D array of per-sample weights, indexed like ``t``.
        period: 1-D array with the period of each trial.
        duration: 1-D array with the box width of each trial.
        t0: 1-D array with the box mid-time of each trial.
        ivar_in: Output, one slot per trial; receives the sum of ``ivar``
            over samples that fall inside the box.
        ivar_out: Output, one slot per trial; receives the sum of ``ivar``
            over samples that fall outside the box.

    The kernel only ever adds to ``ivar_in`` / ``ivar_out``, so the caller
    must zero both arrays before launching.
    """
    i, j = cuda.grid(2)  # i: trial index, j: sample index
    if i >= period.size or j >= t.size:
        return

    period_i = period[i]
    half_period_i = 0.5 * period_i

    # Shift by half a period before folding so the box mid-time t0 lands at
    # phase == half_period; the in-box test is then a symmetric distance.
    phase = (t[j] - t0[i] + half_period_i) % period_i

    # Many samples map onto the same trial slot, hence the atomic adds.
    if abs(phase - half_period_i) < 0.5 * duration[i]:
        cuda.atomic.add(ivar_in, i, ivar[j])
    else:
        cuda.atomic.add(ivar_out, i, ivar[j])
    # No block barrier here: threads share no intermediate state, and a
    # barrier reached by only part of a block is unsafe.
@cuda.jit(cache=True)
def calculate_part2(t, y, ivar, period, duration, t0, y_in, y_out):
    """Accumulate the weighted in-box and out-of-box sums of ``y`` per trial.

    Intended for a 2-D launch: the x axis of the grid indexes trials
    (entries of ``period``, ``duration`` and ``t0``) and the y axis indexes
    samples (entries of ``t``, ``y`` and ``ivar``). Each thread handles one
    (trial, sample) pair.

    Args:
        t: 1-D array of sample times.
        y: 1-D array of measurements, indexed like ``t``.
        ivar: 1-D array of per-sample weights, indexed like ``t``.
        period: 1-D array with the period of each trial.
        duration: 1-D array with the box width of each trial.
        t0: 1-D array with the box mid-time of each trial.
        y_in: Output, one slot per trial; receives the sum of
            ``ivar * y`` over samples that fall inside the box.
        y_out: Output, one slot per trial; receives the sum of
            ``ivar * y`` over samples that fall outside the box.

    The kernel only ever adds to ``y_in`` / ``y_out``, so the caller must
    zero both arrays before launching.
    """
    i, j = cuda.grid(2)  # i: trial index, j: sample index
    if i >= period.size or j >= t.size:
        return

    period_i = period[i]
    half_period_i = 0.5 * period_i

    # Shift by half a period before folding so the box mid-time t0 lands at
    # phase == half_period; the in-box test is then a symmetric distance.
    phase = (t[j] - t0[i] + half_period_i) % period_i

    # Many samples map onto the same trial slot, hence the atomic adds.
    if abs(phase - half_period_i) < 0.5 * duration[i]:
        cuda.atomic.add(y_in, i, (ivar[j] * y[j]))
    else:
        cuda.atomic.add(y_out, i, (ivar[j] * y[j]))
    # No block barrier here: threads share no intermediate state, and a
    # barrier reached by only part of a block is unsafe.
@cuda.jit(cache=True)
def calculate_part3(y_out, y_in, depth, depth_err, ivar_in, ivar_out, snr):
    """Turn the per-trial sums into a depth, its uncertainty and an SNR.

    One thread per trial. For trial ``k``:

    * ``depth[k]`` is the weighted mean outside the box minus the weighted
      mean inside it (``y_out / ivar_out - y_in / ivar_in``);
    * ``depth_err[k]`` is ``sqrt(1 / ivar_in + 1 / ivar_out)``;
    * ``snr[k]`` is ``depth[k] / depth_err[k]``.

    No guard is applied against empty bins, so a trial whose ``ivar_in`` or
    ``ivar_out`` is zero goes through the divisions as-is.
    """
    k = cuda.grid(1)
    if k >= snr.size:
        return

    weight_in = ivar_in[k]
    weight_out = ivar_out[k]

    depth[k] = y_out[k] / weight_out - y_in[k] / weight_in
    depth_err[k] = math.sqrt(1.0 / weight_in + 1.0 / weight_out)
    # Use the values as stored in the output arrays, so the ratio is taken
    # on exactly what the caller will read back.
    snr[k] = depth[k] / depth_err[k]
def cubls(t, y, ivar, period, duration, oversample, dtype=np.float32):
    """Run a brute-force box least squares (BLS) search on the GPU.

    Every (period, duration) pair is combined with a grid of trial box
    mid-times running from 0 up to roughly one period in steps of
    ``duration / oversample``. A signal-to-noise ratio is computed for each
    resulting (period, duration, mid-time) trial.

    Args:
        t: NumPy array of sample times.
        y: NumPy array of measurements, one per entry of ``t``.
        ivar: NumPy array of per-sample weights, one per entry of ``t``
            (the name suggests inverse variances).
        period: 1-D array-like of trial periods.
        duration: 1-D array-like of trial box widths.
        oversample: Number of mid-time steps per box width.
        dtype: Floating-point type used for the samples and for all
            accumulators and results on the device.

    Returns:
        Tuple ``(period_repeated, duration_repeated, phases, snr)`` of 1-D
        CuPy arrays of equal length, one entry per trial: the trial's
        period, its box width, its box mid-time, and its SNR.
    """
    t = cuda.to_device(t.astype(dtype))
    y = cuda.to_device(y.astype(dtype))
    ivar = cuda.to_device(ivar.astype(dtype))

    P, D = np.meshgrid(period, duration, indexing='ij')
    Pf = P.flatten()
    Df = D.flatten()

    # One host-side grid of mid-times per (period, duration) pair.
    phases = [np.arange(0, p + d, d) for p, d in zip(Pf, Df / oversample)]
    sizes = [phase.size for phase in phases]

    with optimizing.optimize():
        # cupy.concatenate only accepts CuPy arrays, so join the host-side
        # grids with NumPy first and move the result over in one transfer.
        phases = cp.asarray(np.concatenate(phases))
        P_c = cp.asarray(Pf)
        D_c = cp.asarray(Df)
        # Expand periods/durations so that index k describes trial k fully.
        period_repeated = cp.repeat(P_c, sizes)
        duration_repeated = cp.repeat(D_c, sizes)

    n_trials = period_repeated.size

    # The kernels build these sums with atomic adds, so they must start at 0.
    ivar_in = cp.zeros(n_trials, dtype=dtype)
    ivar_out = cp.zeros(n_trials, dtype=dtype)
    y_in = cp.zeros(n_trials, dtype=dtype)
    y_out = cp.zeros(n_trials, dtype=dtype)
    # Every element of these is overwritten by calculate_part3.
    depth = cp.empty(n_trials, dtype=dtype)
    depth_err = cp.empty(n_trials, dtype=dtype)
    snr = cp.empty(n_trials, dtype=dtype)

    threads_per_block = (32, 16)
    blocks_per_grid_x = math.ceil(n_trials / threads_per_block[0])
    blocks_per_grid_y = math.ceil(y.size / threads_per_block[1])
    blocks_per_grid = (blocks_per_grid_x, blocks_per_grid_y)

    # A launch with a zero-sized grid is invalid; with no trials or no
    # samples there is nothing to accumulate anyway.
    if blocks_per_grid_x > 0 and blocks_per_grid_y > 0:
        calculate_part1[blocks_per_grid, threads_per_block](t, ivar, period_repeated, duration_repeated, phases, ivar_in, ivar_out)
        calculate_part2[blocks_per_grid, threads_per_block](t, y, ivar, period_repeated, duration_repeated, phases, y_in, y_out)
    if n_trials > 0:
        calculate_part3.forall(snr.size)(y_out, y_in, depth, depth_err, ivar_in, ivar_out, snr)

    return period_repeated, duration_repeated, phases, snr
if __name__ == "__main__":
    # Demo: 5000 randomly spaced samples over ~20 time units with a dip of
    # depth 0.1 lasting 0.2 every 3 time units, plus Gaussian noise.
    rng = np.random.default_rng(42)
    t = rng.uniform(0, 20, 5000)
    t -= np.min(t)
    y = np.ones_like(t) - 0.1*((t%3)<0.2) + 0.01*rng.standard_normal(len(t))
    ivar = np.ones_like(t) * 0.01

    period = np.linspace(1, 10, 1000, dtype=np.float32)
    duration = np.arange(0.01, 0.5, 0.02, dtype=np.float32)
    oversample = 10

    # cubls returns four arrays; the SNR is the last one.
    period_repeated, duration_repeated, phases, snr = cubls(t, y, ivar, period, duration, oversample)

    # cp.asnumpy copies device data to the host for CuPy arrays and for any
    # object exposing __cuda_array_interface__. Transfer the SNR once and
    # locate the best trial once; NaN entries are ignored.
    snr_host = cp.asnumpy(snr)
    best = np.nanargmax(snr_host)

    best_period = cp.asnumpy(period_repeated)[best]
    best_duration = cp.asnumpy(duration_repeated)[best]
    best_t0 = cp.asnumpy(phases)[best]
    print("Best period:", best_period, "Best duration:", best_duration, "Best t0:", best_t0, 'Best SNR:', np.nanmax(snr_host))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment