Last active
October 16, 2017 13:08
-
-
Save sethbrin/2ec2ba9a374e7eaf003888c14b0b9219 to your computer and use it in GitHub Desktop.
Test the performance of ARM NEON assembly
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <assert.h> | |
#include <stdio.h> | |
#include <algorithm> | |
#include <cstring> | |
#include <functional> | |
#include <vector> | |
#include <chrono> | |
#include <algorithm> | |
#include <cmath> | |
#include <stdlib.h> | |
#include <iostream> | |
#include <arm_neon.h> | |
/// Stopwatch for measuring a single interval.
///
/// Usage: start() -> stop() -> get_time_in_us(); call reset() before timing
/// another interval. The call order is checked with assert().
class Timer {
private:
    // steady_clock is monotonic, so a measured interval cannot go negative or
    // jump when the wall clock is adjusted. high_resolution_clock may be an
    // alias of system_clock (it is in libstdc++), which is unsuitable here.
    using clock = std::chrono::steady_clock;
    using time_point = clock::time_point;

public:
    Timer()
    {
        reset();
    }

    /// Clears the started/stopped state so start() may be called again.
    void reset()
    {
        m_started = false;
        m_stopped = false;
    }

    /// Records the start point. Must not be called twice without reset().
    void start()
    {
        assert(!m_started);
        assert(!m_stopped);
        m_started = true;
        m_start_point = clock::now();
    }

    /// Records the stop point. Requires a preceding start().
    void stop()
    {
        assert(m_started);
        assert(!m_stopped);
        m_stopped = true;
        m_stop_point = clock::now();
    }

    /// Returns the measured interval in whole microseconds (truncated).
    /// Requires a completed start()/stop() pair.
    size_t get_time_in_us() const
    {
        assert(m_started && m_stopped);
        return static_cast<size_t>(
                std::chrono::duration_cast<std::chrono::microseconds>(
                        m_stop_point - m_start_point)
                        .count());
    }

private:
    bool m_started = false;
    bool m_stopped = false;
    time_point m_start_point{};
    time_point m_stop_point{};
};
using namespace std;

// Define DEBUG to fill buffers with their element index instead of random
// values, which makes outputs easy to inspect by hand.
//#define DEBUG

// When true, init() prints every value it writes.
const bool verbose = false;

/// Fills data[0, N) with test input.
///
/// Without DEBUG, each element is rand() / RAND_MAX scaled by 3, i.e. a value
/// in [0, 3] taken from the C rand() stream. With DEBUG, element k is set to k.
void init(float* data, size_t N) {
    float* const stop = data + N;
    for (float* cursor = data; cursor != stop; ++cursor) {
        const size_t idx = static_cast<size_t>(cursor - data);
#ifdef DEBUG
        *cursor = static_cast<float>(idx);
#else
        *cursor = static_cast<float>(rand()) / RAND_MAX * 3;
#endif
        if (verbose) {
            std::cout << "src "
                      << "index{" << idx << ","
                      << "} val:" << *cursor << std::endl;
        }
    }
}
// rep(i, n): loop header running i over 0, 1, ..., n-1. The counter is
// declared with the type of n (cv-qualifiers dropped by auto), so comparing it
// against n never mixes signed and unsigned types.
#define rep(i, n) for (auto i = decltype(n){0}; i < (n); ++i)
/// Kernel variant using post-indexed loads (`ld1 {...}, [ptr], #32`).
///
/// For every batch n, dst is re-walked from the start and accumulated into,
/// so after the call dst[0, OH*OW) holds the sum over all N batches. Each
/// output row processes OW/4 groups of 4 outputs. A group loads 8 consecutive
/// floats from r0 and 8 from r1 (one input row, IW floats, below r0), then
/// accumulates into the 4 outputs with four fused multiply-adds:
///     out[l] += r0[l]*k0; out[l] += r0[4+l]*k1;
///     out[l] += r1[l]*k2; out[l] += r1[4+l]*k3;      (l = 0..3)
/// All kernel weights are 3.0f. Per row, r0/r1 advance by 8*(OW/4) floats.
///
/// Requirements / notes:
///   * AArch64 only (inline assembly).
///   * If OW % 4 != 0, the trailing OW % 4 columns of each row are not written.
///   * If OW < 4, the function returns immediately. The assembly loop is
///     do-while shaped, so a zero trip count would otherwise run ~2^64
///     iterations, reading and writing far out of bounds.
///   * IH is only used for the per-batch input stride IW * IH.
void run_new(const float* src, float* dst,
             size_t IH, size_t IW, size_t OH, size_t OW,
             size_t N) {
    if ((OW >> 2) == 0) {
        return;  // no complete 4-wide group: nothing to compute
    }
    // Loop-invariant kernel weights k0..k3, all 3.0f.
    const float32x4_t k0123 = vdupq_n_f32(3.f);
    for (size_t n = 0; n < N; ++n) {
        const float* src_ptr = src + IW * IH * n;
        float* outptr = dst;
        const float* r0 = src_ptr;
        const float* r1 = src_ptr + IW;
        for (size_t h = 0; h < OH; ++h) {
            // Trip count (number of 4-wide groups). Re-initialised per row
            // because the operand is declared read-write ("+r").
            size_t width = OW >> 2;
            asm volatile(
                    "dup v21.4s, %4.s[0] \n"  // broadcast k0..k3
                    "dup v22.4s, %4.s[1] \n"
                    "dup v23.4s, %4.s[2] \n"
                    "dup v24.4s, %4.s[3] \n"
                    "mov x3, xzr \n"  // iteration counter
                    "0: \n"
                    "ldr q0, [%1] \n"  // current 4 outputs
                    "ld1 {v1.4s, v2.4s}, [%2], #32 \n"  // r0[0..7], r0 += 8
                    "add x3, x3, #0x1 \n"
                    "cmp %0, x3 \n"
                    "ld1 {v3.4s, v4.4s}, [%3], #32 \n"  // r1[0..7], r1 += 8
                    "fmla v0.4s, v1.4s, v21.4s \n"  // src[i] * k[i]
                    "fmla v0.4s, v2.4s, v22.4s \n"
                    "fmla v0.4s, v3.4s, v23.4s \n"
                    "fmla v0.4s, v4.4s, v24.4s \n"
                    "str q0, [%1], #16 \n"  // store, outptr += 4
                    "bne 0b \n"
                    : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1)
                    : "w"(k0123)
                    : "cc", "memory", "x3", "v0", "v1", "v2", "v3", "v4",
                      "v21", "v22", "v23", "v24");
        }
    }
}
/// Baseline kernel. It has the same arithmetic and memory footprint as
/// run_new, but forms each load address as base + running byte offset (x4)
/// in the scratch register x19. r0/r1 are advanced by the accumulated offset
/// only after the loop.
///
/// For every batch n, dst is re-walked from the start and accumulated into,
/// so after the call dst[0, OH*OW) holds the sum over all N batches. Each
/// output row processes OW/4 groups of 4 outputs. A group loads 8 consecutive
/// floats from r0 and 8 from r1 (one input row, IW floats, below r0), then
/// accumulates them into the 4 outputs with four fused multiply-adds against
/// kernel weights that are all 3.0f.
///
/// Requirements / notes:
///   * AArch64 only (inline assembly). x19 is callee-saved under AAPCS64, so
///     listing it as a clobber makes the compiler save/restore it.
///   * If OW % 4 != 0, the trailing OW % 4 columns of each row are not written.
///   * If OW < 4, the function returns immediately. The assembly loop is
///     do-while shaped, so a zero trip count would otherwise run ~2^64
///     iterations, reading and writing far out of bounds.
///   * IH is only used for the per-batch input stride IW * IH.
void run_origin(const float* src, float* dst,
                size_t IH, size_t IW, size_t OH, size_t OW,
                size_t N) {
    if ((OW >> 2) == 0) {
        return;  // no complete 4-wide group: nothing to compute
    }
    // Loop-invariant kernel weights k0..k3, all 3.0f.
    const float32x4_t k0123 = vdupq_n_f32(3.f);
    for (size_t n = 0; n < N; ++n) {
        const float* src_ptr = src + IW * IH * n;
        float* outptr = dst;
        const float* r0 = src_ptr;
        const float* r1 = src_ptr + IW;
        for (size_t h = 0; h < OH; ++h) {
            // Trip count (number of 4-wide groups). Re-initialised per row
            // because the operand is declared read-write ("+r").
            size_t width = OW >> 2;
            asm volatile(
                    "dup v21.4s, %4.s[0] \n"  // broadcast k0..k3
                    "dup v22.4s, %4.s[1] \n"
                    "dup v23.4s, %4.s[2] \n"
                    "dup v24.4s, %4.s[3] \n"
                    "mov x3, xzr \n"  // iteration counter
                    "mov x4, xzr \n"  // byte offset into r0/r1
                    "0: \n"
                    "add x19, %2, x4 \n"
                    "ldr q0, [%1] \n"  // load dst 0, 1, 2, 3
                    "ld1 {v1.4s, v2.4s}, [x19]\n"  // v1 = r0[0..3], v2 = r0[4..7]
                    "add x3, x3, #0x1 \n"
                    "cmp %0, x3 \n"
                    "add x19, %3, x4 \n"
                    "ld1 {v3.4s, v4.4s}, [x19]\n"  // v3 = r1[0..3], v4 = r1[4..7]
                    "fmla v0.4s, v1.4s, v21.4s \n"  // src[i] * k[i]
                    "fmla v0.4s, v2.4s, v22.4s \n"
                    "fmla v0.4s, v3.4s, v23.4s \n"
                    "fmla v0.4s, v4.4s, v24.4s \n"
                    "add x4, x4, #0x20 \n"  // offset += 8 floats
                    "str q0, [%1], #16 \n"  // store, outptr += 4
                    "bne 0b \n"
                    "add %2, %2, x4 \n"  // advance r0 past the row
                    "add %3, %3, x4 \n"  // advance r1 past the row
                    : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1)
                    : "w"(k0123)
                    : "cc", "memory", "x3", "x4", "x19", "v0", "v1", "v2",
                      "v3", "v4", "v21", "v22", "v23", "v24");
        }
    }
}
/// Cross-checks run_new against run_origin on one fixed configuration:
/// N = 12 planes of IH x IW = 224 x 224 input, OH x OW = 112 x 112 output.
///
/// Both kernels read the same input and accumulate into zero-initialised
/// buffers. Every output index whose two results differ by more than 1e-3 is
/// printed; nothing is printed when they agree.
void check() {
    const size_t IW = 224, IH = 224;
    const size_t OW = 112, OH = 112;
    const size_t N = 12;
    const size_t in_count = IH * IW * N;
    const size_t out_count = OH * OW;

    // src and dst carry 100 spare elements beyond what is used.
    std::vector<float> src(in_count + 100, 0.f);
    std::vector<float> dst(out_count + 100, 0.f);
    std::vector<float> dst_expect(out_count, 0.f);

    // Fill two extra elements past the nominal input ("padding").
    init(src.data(), in_count + 2);

    run_origin(src.data(), dst.data(), IH, IW, OH, OW, N);
    run_new(src.data(), dst_expect.data(), IH, IW, OH, OW, N);

    for (size_t idx = 0; idx != out_count; ++idx) {
        const float got = dst[idx];
        const float want = dst_expect[idx];
        if (std::abs(got - want) > 1e-3) {
            printf("%zu dst:%f dst_expect: %f\n", idx, got, want);
        }
    }
}
void bench() { | |
float* src = new float[3000000]; | |
float* dst = new float[3000000]; | |
const int RUN = 1000; | |
Timer timer; | |
using Fun = std::function<void(const float*, float*, size_t, | |
size_t, size_t, size_t, size_t)>; | |
auto run = [&](size_t N, size_t IH, size_t IW, size_t OH, size_t OW) { | |
size_t IN = IH * IW * N + 2; | |
size_t ON = OH * OW; | |
init(src, IN); | |
printf("N: %zu IH: %zu IW: %zu OH: %zu OW: %zu\n", N, IH, IW, OH, OW); | |
auto get_mflops = [&](Fun fun) { | |
// warm up | |
rep(i, 10) | |
fun(src, dst, IH, IW, OH, OW, N); | |
size_t time = 0; | |
rep(i, RUN) { | |
memset(dst, 0.f, ON * sizeof(float)); | |
timer.start(); | |
fun(src, dst, IH, IW, OH, OW, N); | |
timer.stop(); | |
time += timer.get_time_in_us(); | |
timer.reset(); | |
} | |
// return static_cast<double>(time)/ RUN; | |
return (N * OH * OW * 2 * 2) / (static_cast<double>(time) / RUN); | |
}; | |
double origin = 0, cur = 0; | |
origin = get_mflops(run_origin); | |
cur = get_mflops(run_new); | |
printf("perf origin: %.5lf mflops --- asm: %.5lf mflops --- speedup: " | |
"%.5lf\n", | |
origin, cur, cur / origin); | |
}; | |
run(12, 224, 224, 224/2, 112); | |
delete[] src; | |
delete[] dst; | |
} | |
/// Entry point: first compare the two kernels' outputs, then time them.
/// Falling off the end of main returns 0.
int main() {
    check();
    bench();
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment