Skip to content

Instantly share code, notes, and snippets.

@takagi
Created December 21, 2022 01:11
Show Gist options
Save takagi/bfd7b5ef66ad69af617ed1ceb9ec3c0f to your computer and use it in GitHub Desktop.
Eliminate the device-to-host (D2H) synchronization on the `ascending` flag
diff --git a/cupyx/scipy/interpolate/_interpolate.py b/cupyx/scipy/interpolate/_interpolate.py
index bab74671e..ec3c5bcac 100644
--- a/cupyx/scipy/interpolate/_interpolate.py
+++ b/cupyx/scipy/interpolate/_interpolate.py
@@ -22,7 +22,7 @@ INTERVAL_KERNEL = r'''
extern "C" {
__global__ void find_breakpoint_position(
const double* breakpoints, const double* x, long long* out,
- bool extrapolate, int total_x, int total_breakpoints, bool asc) {
+ bool extrapolate, int total_x, int total_breakpoints, const bool* pasc) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
if(idx >= total_x) {
@@ -32,6 +32,7 @@ __global__ void find_breakpoint_position(
double xp = *&x[idx];
double a = *&breakpoints[0];
double b = *&breakpoints[total_breakpoints - 1];
+ bool asc = pasc[0];
if(isnan(xp)) {
out[idx] = -1;
@@ -224,7 +225,7 @@ __global__ void integrate(
const double* a_val, const double* b_val,
const long long* start, const long long* end,
const long long* c_dims, const long long* c_strides,
- bool asc, T* out) {
+ const bool* pasc, T* out) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
const long long c_dim2 = *&c_dims[2];
@@ -233,10 +234,11 @@ __global__ void integrate(
return;
}
- const long long start_interval = *&start[0];
- const long long end_interval = *&end[0];
- const double a = *&a_val[0];
- const double b = *&b_val[0];
+ const bool asc = pasc[0];
+ const long long start_interval = asc ? *&start[0] : *&end[0];
+ const long long end_interval = asc ? *&end[0] : *&start[0];
+ const double a = asc ? *&a_val[0] : *&b_val[0];
+ const double b = asc ? *&b_val[0] : *&a_val[0];
const long long stride_0 = *&c_strides[0];
const long long stride_1 = *&c_strides[1];
@@ -330,7 +332,7 @@ def _ppoly_evaluate(c, x, xp, dx, extrapolate, out):
This argument is modified in-place.
"""
# Determine if the breakpoints are in ascending order or descending one
- ascending = (x[x.shape[0] - 1] >= x[0]).item()
+ ascending = x[-1] >= x[0]
intervals = cupy.empty(xp.shape, dtype=cupy.int64)
interval_kernel = INTERVAL_MODULE.get_function('find_breakpoint_position')
@@ -397,12 +399,10 @@ def _integrate(c, x, a, b, extrapolate, out):
This argument is modified in-place.
"""
# Determine if the breakpoints are in ascending order or descending one
- ascending = (x[x.shape[0] - 1] >= x[0]).item()
+ ascending = x[-1] >= x[0]
a = cupy.asarray([a], dtype=cupy.float64)
b = cupy.asarray([b], dtype=cupy.float64)
- if not ascending:
- a, b = b, a
start_interval = cupy.empty(a.shape, dtype=cupy.int64)
end_interval = cupy.empty(b.shape, dtype=cupy.int64)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.