Skip to content

Instantly share code, notes, and snippets.

@ro99
Last active October 21, 2021 16:36
Show Gist options
  • Save ro99/8346fd0536b5b37dfb3cd79060a77fe7 to your computer and use it in GitHub Desktop.
Save ro99/8346fd0536b5b37dfb3cd79060a77fe7 to your computer and use it in GitHub Desktop.
Tests with rayon (about a 1.3x speedup)
unsafe fn run_with_scratch_space_parallel(
&self,
m: usize,
n: usize,
non_linear: &[FusedSpec],
) -> anyhow::Result<()> {
let mr = K::mr();
let nr = K::nr();
let mut rows: Vec<usize> = (0..n / nr).collect();
let size = rows.len() / 32 + rows.len() % 32;
let ctx: ThreadLocalCtx<Box<dyn ScratchSpace>, _> = ThreadLocalCtx::new(|| {
let mut scratch = self.allocate_scratch_space();
scratch
.downcast_mut::<ScratchSpaceFusedNonLinear<TI>>()
.unwrap()
.prepare::<K>(non_linear);
scratch
});
for ia in 0..m / mr {
rows.par_chunks_mut(size).for_each(|row_chunk|{
let row_chunk = row_chunk.to_owned();
let mut scratch = ctx.get();
let scratch = scratch.downcast_mut::<ScratchSpaceFusedNonLinear<TI>>().unwrap();
for ib in row_chunk {
scratch.for_valid_tile::<K>(&non_linear, ia, ib);
let err = K::kernel(&scratch.uspecs());
debug_assert_eq!(err, 0, "Kernel return error {}", err);
}
});
}
if m % mr != 0 {
rows.par_chunks_mut(size).for_each(|row_chunk|{
let row_chunk = row_chunk.to_owned();
let mut scratch = ctx.get();
let scratch = scratch.downcast_mut::<ScratchSpaceFusedNonLinear<TI>>().unwrap();
for ib in row_chunk {
scratch.for_border_tile::<K>(&non_linear, m / mr, ib);
let err = K::kernel(&scratch.uspecs());
debug_assert_eq!(err, 0, "Kernel return error {}", err);
scratch.postprocess_tile::<K>(&non_linear, m / mr, ib, m % mr, nr);
}
});
if n % nr != 0 {
let mut scratch = ctx.get();
let scratch = scratch.downcast_mut::<ScratchSpaceFusedNonLinear<TI>>().unwrap();
scratch.for_border_tile::<K>(&non_linear, m / mr, n / nr);
let err = K::kernel(&scratch.uspecs());
debug_assert_eq!(err, 0, "Kernel return error {}", err);
scratch.postprocess_tile::<K>(&non_linear, m / mr, n / nr, m % mr, n % nr);
}
}
Ok(())
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment