Created
September 23, 2021 11:41
-
-
Save itzmeanjan/ca258ec1479e88837e1cd9451c9ff54c to your computer and use it in GitHub Desktop.
Large Scale Parallel Matrix Multiplication, in multiple ways, using SYCL DPC++
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <CL/sycl.hpp>
#include <algorithm>
#include <chrono>
#include <cstdlib>
#include <cstring>
#include <iostream>
using namespace sycl;
constexpr uint N = 1024; | |
constexpr uint B = 32; | |
int64_t multiply_matrix_matrix_v0(queue &q, const float *matrix_a, | |
const float *matrix_b, | |
float *const matrix_c) { | |
buffer<float, 2> b_matrix_a(matrix_a, range<2>{N, N}); | |
buffer<float, 2> b_matrix_b(matrix_b, range<2>{N, N}); | |
buffer<float, 2> b_matrix_c(matrix_c, range<2>{N, N}); | |
std::chrono::_V2::steady_clock::time_point start = | |
std::chrono::steady_clock::now(); | |
auto event = q.submit([&](handler &h) { | |
accessor<float, 2, access::mode::read, access::target::global_buffer> | |
acc_matrix_a(b_matrix_a, h); | |
accessor<float, 2, access::mode::read, access::target::global_buffer> | |
acc_matrix_b(b_matrix_b, h); | |
accessor<float, 2, access::mode::write, access::target::global_buffer> | |
acc_matrix_c(b_matrix_c, h); | |
h.parallel_for(nd_range<2>{{N, N}, {B, 1}}, [=](nd_item<2> it) { | |
const size_t r = it.get_global_id(0); | |
const size_t c = it.get_global_id(1); | |
float sum = 0.f; | |
for (uint k = 0; k < N; k++) { | |
sum += acc_matrix_a[r][k] * acc_matrix_b[k][c]; | |
} | |
acc_matrix_c[r][c] = sum; | |
}); | |
}); | |
event.wait(); | |
std::chrono::_V2::steady_clock::time_point end = | |
std::chrono::steady_clock::now(); | |
return std::chrono::duration_cast<std::chrono::milliseconds>(end - start) | |
.count(); | |
} | |
int64_t multiply_matrix_matrix_v1(queue &q, const float *matrix_a, | |
const float *matrix_b, | |
float *const matrix_c) { | |
buffer<float, 2> b_matrix_a(matrix_a, range<2>{N, N}); | |
buffer<float, 2> b_matrix_b(matrix_b, range<2>{N, N}); | |
buffer<float, 2> b_matrix_c(matrix_c, range<2>{N, N}); | |
std::chrono::_V2::steady_clock::time_point start = | |
std::chrono::steady_clock::now(); | |
auto event = q.submit([&](handler &h) { | |
accessor<float, 2, access::mode::read, access::target::global_buffer> | |
acc_matrix_a(b_matrix_a, h); | |
accessor<float, 2, access::mode::read, access::target::global_buffer> | |
acc_matrix_b(b_matrix_b, h); | |
accessor<float, 2, access::mode::write, access::target::global_buffer> | |
acc_matrix_c(b_matrix_c, h); | |
accessor<float, 1, access::mode::read_write, access::target::local> | |
acc_local(range<1>{B}, h); | |
h.parallel_for(nd_range<2>{{N, N}, {1, B}}, [=](nd_item<2> it) { | |
const size_t r = it.get_global_id(0); | |
const size_t c = it.get_global_id(1); | |
const size_t l = it.get_local_id(1); | |
float sum = 0.f; | |
for (uint k = 0; k < N; k += B) { | |
acc_local[l] = acc_matrix_a[r][k + l]; | |
it.barrier(); | |
for (uint k_ = 0; k_ < B; k_++) { | |
sum += acc_local[k_] * acc_matrix_b[k + k_][c]; | |
} | |
it.barrier(); | |
} | |
acc_matrix_c[r][c] = sum; | |
}); | |
}); | |
event.wait(); | |
std::chrono::_V2::steady_clock::time_point end = | |
std::chrono::steady_clock::now(); | |
return std::chrono::duration_cast<std::chrono::milliseconds>(end - start) | |
.count(); | |
} | |
int64_t multiply_matrix_matrix_v2(queue &q, const float *matrix_a, | |
const float *matrix_b, | |
float *const matrix_c) { | |
buffer<float, 2> b_matrix_a(matrix_a, range<2>{N, N}); | |
buffer<float, 2> b_matrix_b(matrix_b, range<2>{N, N}); | |
buffer<float, 2> b_matrix_c(matrix_c, range<2>{N, N}); | |
std::chrono::_V2::steady_clock::time_point start = | |
std::chrono::steady_clock::now(); | |
auto event = q.submit([&](handler &h) { | |
accessor<float, 2, access::mode::read, access::target::global_buffer> | |
acc_matrix_a(b_matrix_a, h); | |
accessor<float, 2, access::mode::read, access::target::global_buffer> | |
acc_matrix_b(b_matrix_b, h); | |
accessor<float, 2, access::mode::write, access::target::global_buffer> | |
acc_matrix_c(b_matrix_c, h); | |
h.parallel_for(nd_range<2>{{N, N}, {1, B}}, [=](nd_item<2> it) { | |
sycl::ONEAPI::sub_group sg = it.get_sub_group(); | |
const size_t r = it.get_global_id(0); | |
const size_t c = it.get_global_id(1); | |
const size_t l = it.get_local_id(1); | |
float sum = 0.f; | |
for (uint k = 0; k < N; k += B) { | |
float tile = acc_matrix_a[r][k + l]; | |
for (uint k_ = 0; k_ < B; k_++) { | |
sum += | |
sycl::ONEAPI::broadcast(sg, tile, k_) * acc_matrix_b[k + k_][c]; | |
} | |
} | |
acc_matrix_c[r][c] = sum; | |
}); | |
}); | |
event.wait(); | |
std::chrono::_V2::steady_clock::time_point end = | |
std::chrono::steady_clock::now(); | |
return std::chrono::duration_cast<std::chrono::milliseconds>(end - start) | |
.count(); | |
} | |
int main(int argc, char **argv) { | |
device d{default_selector{}}; | |
queue q{d}; | |
std::cout << "running on " << d.get_info<info::device::name>() << std::endl; | |
float *matrix_a = (float *)malloc(sizeof(float) * N * N); | |
float *matrix_b = (float *)malloc(sizeof(float) * N * N); | |
float *matrix_c = (float *)malloc(sizeof(float) * N * N); | |
memset(matrix_a, 1, sizeof(float) * N * N); | |
memset(matrix_b, 2, sizeof(float) * N * N); | |
memset(matrix_c, 0, sizeof(float) * N * N); | |
int64_t tm = multiply_matrix_matrix_v0(q, matrix_a, matrix_b, matrix_c); | |
std::cout << "matmul_v0, in " << tm << " ms" << std::endl; | |
memset(matrix_c, 0, sizeof(float) * N * N); | |
tm = multiply_matrix_matrix_v1(q, matrix_a, matrix_b, matrix_c); | |
std::cout << "matmul_v1, in " << tm << " ms" << std::endl; | |
memset(matrix_c, 0, sizeof(float) * N * N); | |
tm = multiply_matrix_matrix_v2(q, matrix_a, matrix_b, matrix_c); | |
std::cout << "matmul_v2, in " << tm << " ms" << std::endl; | |
free(matrix_a); | |
free(matrix_b); | |
free(matrix_c); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Background
This work accompanies a post I wrote; it contains three implementations of parallel matrix multiplication that can be run on heterogeneous devices.
Usage