Last active
December 23, 2023 23:16
-
-
Save robertknight/20acb0496803ac5e4bfba9cdfb2bedfd to your computer and use it in GitHub Desktop.
Port of https://github.com/danieldk/gemm-benchmark/tree/main for Wasnn
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//! Port of https://github.com/danieldk/gemm-benchmark/ for comparing wasnn's
//! matrix multiplication performance against other popular libraries.
use std::time::Duration; | |
use wasnn::gemm; | |
use wasnn_tensor::prelude::*; | |
use wasnn_tensor::NdTensor; | |
use rayon::iter::IntoParallelIterator; | |
use rayon::iter::ParallelIterator; | |
/// Result of one [`gemm_benchmark`] run: how long it took and how much work it did.
#[derive(Debug, Clone, Copy)]
struct BenchmarkStats {
    /// Wall-clock time measured around the parallel GEMM loop.
    elapsed: Duration,
    /// Nominal number of floating-point operations performed during `elapsed`.
    flops: usize,
}
fn gemm_benchmark(dim: usize, iterations: usize, threads: usize) -> BenchmarkStats { | |
let one = 1.0; | |
let two = one + one; | |
let point_five = one / two; | |
let matrix_a = NdTensor::full([dim, dim], two); | |
let matrix_b = NdTensor::full([dim, dim], point_five); | |
let c_matrices: Vec<_> = std::iter::repeat(NdTensor::full([dim, dim], one)) | |
.take(threads) | |
.collect::<Vec<_>>(); | |
let start = std::time::Instant::now(); | |
c_matrices.into_par_iter().for_each(|mut matrix_c| { | |
for _ in 0..iterations { | |
let row_stride = matrix_c.stride(0); | |
gemm( | |
matrix_c.data_mut().unwrap(), | |
row_stride, | |
matrix_a.view(), | |
matrix_b.view(), | |
1., | |
1., | |
); | |
} | |
}); | |
let elapsed = start.elapsed(); | |
BenchmarkStats { | |
elapsed, | |
flops: (dim.pow(3) * 2 * iterations * threads) + (dim.pow(2) * 2 * iterations * threads), | |
} | |
} | |
fn main() { | |
let threads = 4; | |
let iterations = 2000; | |
let dim = 1024; | |
rayon::ThreadPoolBuilder::new() | |
.num_threads(threads) | |
.build_global() | |
.unwrap(); | |
println!("Threads: {}", threads); | |
println!("Iterations per thread: {}", iterations); | |
println!("Matrix shape: {} x {}", dim, dim); | |
let stats = gemm_benchmark(dim, iterations, threads); | |
println!( | |
"GFLOPS: {:.2}", | |
(stats.flops as f64 / stats.elapsed.as_secs_f64()) / 1000_000_000. | |
); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment