Skip to content

Instantly share code, notes, and snippets.

@abadams
Created December 23, 2016 01:46
Show Gist options
  • Save abadams/c2e6f67d79e1768af6db5afcabb1caab to your computer and use it in GitHub Desktop.
Save abadams/c2e6f67d79e1768af6db5afcabb1caab to your computer and use it in GitHub Desktop.
Scheduling a chain of horizontal stencils
#include <cstdio>

#include "Halide.h"
#include "benchmark.h"
using namespace Halide;
// Benchmarks one schedule for a small pipeline of horizontal stencils.
//
// Pipeline (every Func is 3-D; the stencils only reach along x):
//   f1  = sum of input1 at x-1, x, x+1
//   f2  = sum of input2 at x-2 .. x+2
//   res = f1(x) + f1(x-1) + f2(x-1) + f2(x)
// Both inputs are wrapped with BoundaryConditions::constant_exterior(..., 0.0f),
// so stencil taps that fall outside the input read 0.
//
// Exactly one schedule is active at a time; the alternatives are kept as
// commented-out code. The number after each "Schedule N:" label is presumably
// the time the author measured with that schedule (the value this program
// prints) -- units depend on benchmark.h, TODO confirm.
//
// argc/argv are unused. Prints the value returned by benchmark() and returns 0.
int main(int argc, char **argv) {
    // xi / yi are only referenced by the split() calls in the commented-out
    // schedules; they are kept so those schedules can be re-enabled as-is.
    Var x, y, z, yi, xi;
    Func f1, f2, res;

    ImageParam input1(type_of<float>(), 3);
    ImageParam input2(type_of<float>(), 3);
    Func in1 = BoundaryConditions::constant_exterior(input1, 0.0f);
    Func in2 = BoundaryConditions::constant_exterior(input2, 0.0f);

    f1(x, y, z) = (in1(x + 1, y, z) + in1(x, y, z) + in1(x - 1, y, z));
    f2(x, y, z) = (in2(x + 2, y, z) + in2(x + 1, y, z) + in2(x, y, z) +
                   in2(x - 1, y, z) + in2(x - 2, y, z));
    res(x, y, z) = f1(x, y, z) + f1(x - 1, y, z) + f2(x - 1, y, z) + f2(x, y, z);

    // Schedule 1: 0.001791
    //f1.store_at(res, y).compute_at(res, yi).vectorize(x, 8);
    //f2.store_at(res, y).compute_at(res, yi).vectorize(x, 8);
    //res.split(y, y, yi, 8).vectorize(x, 8).parallel(y);

    // Schedule 2: 0.004027
    //res.split(y, y, yi, 8).vectorize(x, 8).parallel(y);

    // Schedule 3: 0.001733
    //f1.compute_at(res, yi).vectorize(x, 8);
    //f2.compute_at(res, yi).vectorize(x, 8);
    //res.split(y, y, yi, 8).vectorize(x, 8).parallel(y);

    // Schedule 4: 0.002135
    //f1.compute_at(res, x).vectorize(x, 8);
    //f2.compute_at(res, x).vectorize(x, 8);
    //res.split(y, y, yi, 8).split(x, x, xi, 256).vectorize(xi, 8).parallel(y);

    // Schedule 5: 0.001710
    //in1.compute_at(res, yi).vectorize(in1.args()[0], 8);
    //in2.compute_at(res, yi).vectorize(in2.args()[0], 8);
    //res.split(y, y, yi, 8).vectorize(x, 8).parallel(y);

    // Schedule 6: 0.002415
    //for (Func f : {in1, in2, f1, f2}) {
    //    f.store_at(res, yi).compute_at(res, x).vectorize(f.args()[0], 8);
    //}
    //res.split(y, y, yi, 8).vectorize(x, 8).parallel(y);

    // Schedule 7: 0.001828
    //for (Func f : {in1, in2, f1, f2}) {
    //    f.compute_at(res, yi).vectorize(f.args()[0], 8);
    //}
    //res.split(y, y, yi, 8).vectorize(x, 8).parallel(y);

    // Schedule 8: 0.002070
    //in1.compute_at(f1, y).vectorize(in1.args()[0], 8);
    //in2.compute_at(f2, y).vectorize(in2.args()[0], 8);
    //f1.compute_at(res, y).vectorize(x, 8);
    //f2.compute_at(res, y).vectorize(x, 8);
    //res.vectorize(x, 8).parallel(y, 8);

    // Schedule 9: 0.001757
    //f2.compute_at(res, y).vectorize(x, 8);
    //res.vectorize(x, 8).parallel(y, 8);

    // Schedule 10: 0.002126  (active)
    // The boundary-conditioned input1 and f2 are each computed per iteration of
    // res's y loop; f1 and in2 are given no schedule of their own here.
    // in1 is a wrapper Func whose Vars are not named in this file, hence the
    // args()[0] lookup to vectorize its first dimension.
    in1.compute_at(res, y).vectorize(in1.args()[0], 8);
    f2.compute_at(res, y).vectorize(x, 8);
    res.vectorize(x, 8).parallel(y, 8);

    // Compile up front so the JIT cost is not part of the timed region.
    res.compile_jit();

    const int width = 1024;
    const int height = 1024;
    const int channels = 3;

    Buffer<float> b1(width, height, channels);
    Buffer<float> b2(width, height, channels);
    Buffer<float> out(width, height, channels);

    // Give the inputs well-defined contents. Without this the pipeline reads
    // whatever happens to be in freshly allocated memory, which makes the run
    // non-deterministic and lets stray denormals/NaNs distort the timing.
    // The timings quoted in the schedule comments were recorded before this
    // initialization was added.
    for (int c = 0; c < channels; c++) {
        for (int row = 0; row < height; row++) {
            for (int col = 0; col < width; col++) {
                const float v = static_cast<float>((col + row + c) % 256) / 256.0f;
                b1(col, row, c) = v;
                b2(col, row, c) = 1.0f - v;
            }
        }
    }

    input1.set(b1);
    input2.set(b2);

    // benchmark() comes from benchmark.h; the two leading 10s are presumably
    // the sample and per-sample iteration counts -- TODO confirm against that
    // header.
    double t = benchmark(10, 10, [&]() { res.realize(out); });
    printf("%f\n", t);
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment