Skip to content

Instantly share code, notes, and snippets.

@Sean1708
Created February 11, 2020 09:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Sean1708/69c5694048e9a9ca7bd84fcbc9ea6f45 to your computer and use it in GitHub Desktop.
Save Sean1708/69c5694048e9a9ca7bd84fcbc9ea6f45 to your computer and use it in GitHub Desktop.
#!/usr/bin/env julia
import Markdown
import Statistics
# Source tarball that `main` hands to `build_openblas` when no BLAS path is given on the command line.
const OPENBLAS_URL = "https://github.com/xianyi/OpenBLAS/archive/v0.2.20.tar.gz"
"""
    Argument(type, name)

One parameter of a generated C function: the C type followed by the identifier.
"""
struct Argument
    type::Symbol
    name::Symbol
end

# Render as a C parameter declaration, e.g. `Matrix left`.
function Base.show(io::IO, arg::Argument)
    print(io, "$(arg.type) $(arg.name)")
end
"""
    Func(name, return_type, arguments, body)

Description of a C function that is emitted into the generated benchmark source.
`body` is raw C text placed between the braces.
"""
struct Func
    name::Symbol
    return_type::Symbol
    arguments::Vector{Argument}
    body::String
end

# Emit the full C definition: signature line, body and closing brace, each
# terminated by a newline.
function Base.show(io::IO, func::Func)
    parameters = join(func.arguments, ", ")
    print(io, func.return_type, " ", func.name, "(", parameters, ") {\n")
    print(io, func.body, "\n")
    print(io, "}\n")
end
# One benchmarkable matrix-multiplication implementation, expressed as C source
# fragments that `program` splices into the generated benchmark.
struct Algo
# The multiplication routine itself; `program` emits it and calls it by its name.
algo::Func
# Helper functions emitted before `algo` in the generated source.
extras::Vector{Func}
# C statements placed at the top of the generated `main`, before the timing loops.
setup::String
end
"""
    standard_multiply(body)

Wrap the C statements in `body` in a `Func` with the common signature
`void matrix_multiply(Matrix left, Matrix right, Matrix result)`.
"""
function standard_multiply(body)
    operands = [Argument(:Matrix, operand) for operand in (:left, :right, :result)]
    return Func(:matrix_multiply, :void, operands, body)
end
# Registry of the available algorithms, keyed by the symbol used to select them.
# Entries built from an empty body are placeholders that have not been written yet.
const ALGOS = Dict(
    # Reference implementation: delegate to the BLAS `dgemm` routine.
    :blas => Algo(
        standard_multiply("
    cblas_dgemm(
        CblasRowMajor,
        CblasNoTrans, CblasNoTrans,
        (blasint)left.rows, (blasint)right.cols, (blasint)left.cols,
        1.0,
        left.data, (blasint)left.stride,
        right.data, (blasint)right.stride,
        0.0,
        result.data, (blasint)result.stride);
"),
        [],
        "",
    ),
    # Accesses `right` in a cache-unfriendly way.
    :naive => Algo(
        standard_multiply("
    for (size_t row = 0; row < left.rows; row++) {
        for (size_t col = 0; col < right.cols; col++) {
            MATRIX_INDEX(result, row, col) = 0.0;
            for (size_t inner = 0; inner < left.cols; inner++) {
                MATRIX_INDEX(result, row, col) += MATRIX_INDEX(left, row, inner) * MATRIX_INDEX(right, inner, col);
            }
        }
    }
"),
        [],
        "",
    ),
    # Accesses matrices in cache-friendly chunks, but requires zeroing and knowing the cache size.
    :tiled => Algo(
        Func(
            :matrix_multiply,
            :void,
            [
                Argument(:Matrix, :left),
                Argument(:Matrix, :right),
                Argument(:Matrix, :result),
                Argument(:size_t, :cache_factor),
            ],
            "
    for (size_t row = 0; row < result.rows; row++) {
        for (size_t col = 0; col < result.cols; col++) {
            MATRIX_INDEX(result, row, col) = 0.0;
        }
    }
    for (size_t ROW = 0; ROW < left.rows; ROW += cache_factor) {
        for (size_t COL = 0; COL < right.cols; COL += cache_factor) {
            for (size_t INNER = 0; INNER < left.cols; INNER += cache_factor) {
                for (size_t row = ROW; row < ROW + cache_factor && row < left.rows; row++) {
                    for (size_t col = COL; col < COL + cache_factor && col < right.cols; col++) {
                        for (size_t inner = INNER; inner < INNER + cache_factor && inner < left.cols; inner++) {
                            MATRIX_INDEX(result, row, col) += MATRIX_INDEX(left, row, inner) * MATRIX_INDEX(right, inner, col);
                        }
                    }
                }
            }
        }
    }
",
        ),
        [],
        # The tile loops advance by `cache_factor`, so a value of zero would never
        # terminate. `sysconf` may report 0 or -1 when the cache geometry is unknown,
        # hence the fallback (64 is what the formula yields for 32 KiB / 64-byte lines).
        "
    long l1_size = sysconf(_SC_LEVEL1_DCACHE_SIZE);
    long l1_line = sysconf(_SC_LEVEL1_DCACHE_LINESIZE);
    size_t cache_factor = 64;
    if (l1_size > 0 && l1_line > 0) {
        cache_factor = (size_t)sqrt(8.0 * l1_size / l1_line);
    }
    if (cache_factor == 0) {
        cache_factor = 1;
    }
",
    ),
    # Cache-oblivious algorithm, but requires zeroing `result`.
    :fastf77 => Algo(
        standard_multiply("
    for (size_t row = 0; row < result.rows; row++) {
        for (size_t col = 0; col < result.cols; col++) {
            MATRIX_INDEX(result, row, col) = 0.0;
        }
    }
    for (size_t row = 0; row < left.rows; row++) {
        for (size_t inner = 0; inner < left.cols; inner++) {
            // TODO: Will the compiler hoist this itself?
            double temp = MATRIX_INDEX(left, row, inner);
            for (size_t col = 0; col < right.cols; col++) {
                MATRIX_INDEX(result, row, col) += MATRIX_INDEX(right, inner, col) * temp;
            }
        }
    }
"),
        [],
        "",
    ),
    # Cache-oblivious recursive algorithm which lends itself to multi-threading.
    :divide_and_conquer => Algo(standard_multiply(""), [], ""),
    :naive_simd => Algo(standard_multiply(""), [], ""),
    :tiled_simd => Algo(standard_multiply(""), [], ""),
    :fastf77_simd => Algo(standard_multiply(""), [], ""),
    # O(n^2.8) algorithm with small constant coefficients.
    :strassen => Algo(standard_multiply(""), [], ""),
    # O(n^2.4) algorithm with large constant coefficients.
    :coppersmith_winograd => Algo(standard_multiply(""), [], ""),
)
"""
    program(algo, runs, iterations, (m, n, p))

Return the C source of a self-contained benchmark for `algo`.

The generated program performs `runs` runs. Each run allocates random `m`-by-`n` and
`n`-by-`p` operands plus an `m`-by-`p` result, calls the algorithm `iterations` times
and prints one line on stdout holding the time of every call in seconds. Times are
taken with the C `clock` function, i.e. processor time.

When `m + n + p < 100`, the first run also writes a single
`isapprox(left * right, result)` expression in Julia matrix syntax to stderr so that
the caller can check the result.

# Arguments
- `algo`: supplies `extras` (helper functions), `algo` (the routine; its `name` and
  `arguments` build the call) and `setup` (statements placed at the start of `main`).
- `runs`, `iterations`: loop counts interpolated into the source.
- `(m, n, p)`: the problem shape.
"""
program(algo, runs, iterations, (m, n, p)) = """
// Feature-test macros must precede every system header to take effect.
#define _XOPEN_SOURCE 700
#include <cblas.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <unistd.h>

typedef struct {
    double* data;
    size_t rows;
    size_t cols;
    size_t stride;
} Matrix;

// Every macro argument is parenthesised so that expression arguments expand safely.
#define MATRIX_INDEX(matrix, row, col) (matrix).data[(row) * (matrix).stride + (col)]

Matrix matrix_zero(size_t rows, size_t cols) {
    Matrix matrix;
    matrix.stride = cols;
    matrix.rows = rows;
    matrix.cols = cols;
    matrix.data = calloc(rows * cols, sizeof (double));
    return matrix;
}

Matrix matrix_random(size_t rows, size_t cols) {
    Matrix matrix = matrix_zero(rows, cols);
    for (size_t row = 0; row < rows; row++) {
        for (size_t col = 0; col < cols; col++) {
            // Add 1 because we're not interested in gotchas related to denormal numbers.
            MATRIX_INDEX(matrix, row, col) = 1.0 + drand48();
        }
    }
    return matrix;
}

Matrix matrix_view(Matrix parent, size_t row, size_t col, size_t rows, size_t cols) {
    Matrix view;
    view.stride = parent.stride;
    view.rows = rows;
    view.cols = cols;
    view.data = &MATRIX_INDEX(parent, row, col);
    return view;
}

void matrix_fprint(FILE* output, Matrix m) {
    fprintf(output, "[");
    for (size_t row = 0; row < m.rows; row++) {
        for (size_t col = 0; col < m.cols; col++) {
            fprintf(output, "%.8lf ", MATRIX_INDEX(m, row, col));
        }
        fprintf(output, "; ");
    }
    fprintf(output, "]");
}

$(join(algo.extras, '\n'))
$(algo.algo)

int main(void) {
    $(algo.setup)
    for (size_t run = 0; run < $runs; run++) {
        Matrix left = matrix_random($m, $n);
        Matrix right = matrix_random($n, $p);
        Matrix result = matrix_random($m, $p);
        for (size_t iteration = 0; iteration < $iterations; iteration++) {
            clock_t start = clock();
            $(algo.algo.name)($(join([arg.name for arg in algo.algo.arguments], ", ")));
            clock_t stop = clock();
            printf("%lf ", (double)(stop - start) / CLOCKS_PER_SEC);
        }
        printf("\\n");
        if ($m + $n + $p < 100 && run == 0) {
            fprintf(stderr, "isapprox(");
            matrix_fprint(stderr, left);
            fprintf(stderr, " * ");
            matrix_fprint(stderr, right);
            fprintf(stderr, ", ");
            matrix_fprint(stderr, result);
            fprintf(stderr, ")");
        }
        free(left.data);
        free(right.data);
        free(result.data);
    }
}
"""
# Passed to `program` as `runs`: how many times the generated benchmark allocates
# fresh random matrices.
const RUNS = 3
# Passed to `program` as `iterations`: how many timed multiplications each run performs.
const ITERATIONS = 10
"""
    run(blas, algo, shape)

Generate, compile and execute the C benchmark for a single algorithm and shape.

The source produced by `program` is written to a temporary directory, compiled with
`gcc` against the OpenBLAS installation rooted at `blas` (its `include` directory and
`lib/libopenblas.a`), and executed there.

# Returns
The minimum over all runs of the median per-call time in seconds.

# Throws
- `AssertionError` carrying the offending expression when a verification expression
  emitted by the benchmark evaluates to `false`.
- Whatever `Base.run` raises when compilation or execution fails.
"""
function run(blas, algo, shape)
    mktempdir() do dir
        cd(dir) do
            open("source.c", write = true) do handle
                print(handle, program(algo, RUNS, ITERATIONS, shape))
            end
            Base.run(`gcc -o prog -O3 -I$blas/include source.c $blas/lib/libopenblas.a -lm -lpthread`)
            # stderr is sent to a file rather than a pipe: a pipe would only be drained
            # after stdout has been read to the end, so the child could block forever
            # once its verification dump outgrows the OS pipe buffer.
            timings = read(pipeline(`./prog`, stderr = "test_cases.txt"), String)
            # One line per run, one whitespace-separated time per iteration.
            output = [parse.(Float64, split(line)) for line in split(chomp(timings), '\n')]
            for case in readlines("test_cases.txt")
                # NOTE(review): this evaluates text printed by the benchmark as Julia
                # code; acceptable only because the program is generated by this script.
                # An explicit throw is used because `@assert` may be compiled out.
                include_string(Main, case) || throw(AssertionError(case))
            end
            return minimum(Statistics.median.(output))
        end
    end
end
"""
    run(blas, algos, shapes::Vector)

Benchmark every algorithm named in `algos` (keys of `ALGOS`) for every shape.

Returns a nested dictionary indexed as `timings[name][shape]`, holding the value
returned by the single-shape `run` method.
"""
function run(blas, algos, shapes::Vector)
    timings = Dict()
    for name in algos
        per_shape = Dict()
        timings[name] = per_shape
        for dims in shapes
            per_shape[dims] = run(blas, ALGOS[name], dims)
        end
    end
    return timings
end
"""
    display_results(results, order=sort(collect(keys(results))))

Print the benchmark timings to `stdout` as a Markdown table.

`results` is indexed as `results[algo][shape]`. The table has one row per entry of
`order` and one column per shape, in sorted order; the shapes are taken from the
first entry of `results`. Timings are rounded to three significant digits.

Nothing is printed when `results` is empty. Always returns `nothing`.
"""
function display_results(results, order=sort(collect(keys(results))))
    # Without any entry there is no set of shapes to build the header from.
    isempty(results) && return nothing
    shapes = sort(collect(keys(first(values(results)))))
    table = [Any[""; string.(shapes)]]
    for algo in order
        timings = [round(results[algo][shape], sigdigits=3) for shape in shapes]
        push!(table, [algo; timings])
    end
    # Left-align the algorithm names and right-align the numbers.
    alignment = [:l; repeat([:r], length(shapes))]
    Markdown.term(stdout, Markdown.Table(table, alignment), Markdown.cols(stdout))
    println()
    return nothing
end
"""
    build_openblas(dir, url)

Download the OpenBLAS source tarball at `url` into `dir`, build it with `make` and
install it under `dir/blas`.

The tarball is saved as `openblas.tar.gz` and unpacked into `openblas-src` with its
single top-level directory stripped, so the build does not depend on the version
number embedded in the archive.

# Returns
The absolute installation prefix (`dir/blas`).
"""
function build_openblas(dir, url)
    # Resolve against the caller's working directory *before* changing into `dir`;
    # otherwise a relative `dir` would be applied twice.
    root = abspath(dir)
    prefix = joinpath(root, "blas")
    archive = "openblas.tar.gz"
    source = "openblas-src"
    cd(root) do
        download(url, archive)
        mkpath(source)
        Base.run(`tar -xzf $archive -C $source --strip-components=1`)
        cd(source) do
            Base.run(`make`)
            Base.run(`make PREFIX=$prefix install`)
        end
    end
    return prefix
end
"""
    main(args)

Script entry point: benchmark a fixed set of algorithms over a fixed set of square
shapes and print the timing table.

`args[1]`, when present, is the path of an existing OpenBLAS installation. Without
arguments the user is warned, asked to press Enter, and OpenBLAS is then downloaded
and built inside a temporary directory.
"""
function main(args)
    algos = [:naive, :tiled, :fastf77, :blas]
    # Each power-of-two size is paired with the next size up.
    sizes = (4, 5, 32, 33, 256, 257, 512, 513, 1024, 1025)
    shapes = [(size, size, size) for size in sizes]
    mktempdir() do dir
        if isempty(args)
            notice = (
                "WARNING: You have not passed a path to the BLAS library!",
                " This script will download and build OpenBLAS in a temporary directory!",
                " If you plan to run this script more than once please hit Ctrl-C and run",
                """ julia -e 'include("$PROGRAM_FILE"); build_openblas(".", OPENBLAS_URL)'""",
                " then run",
                " julia $PROGRAM_FILE ./blas",
                " each time.",
                " Otherwise hit Enter to continue.",
            )
            for line in notice
                println(line)
            end
            # Block until the user confirms.
            readline(stdin)
            blas = build_openblas(dir, OPENBLAS_URL)
        else
            blas = abspath(args[1])
        end
        display_results(run(blas, algos, shapes), algos)
    end
end
# Run the benchmark only when this file is the program being executed, not when it is
# `include`d (for example to call `build_openblas` on its own).
if abspath(PROGRAM_FILE) == @__FILE__
main(ARGS)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment