Created
February 11, 2018 08:48
-
-
Save zvookin/d9b5fb11326f5f1075bfbe59015375a3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "Halide.h"
#include <cstdio>
#include <cstdlib>
#include "halide_benchmark.h"
#include "pthread.h"
using namespace Halide; | |
using namespace Halide::Tools; | |
struct test_func { | |
Param<int32_t> p; | |
Func f; | |
Var x; | |
test_func() { | |
Expr big = 0; | |
for (int i = 0; i < 75; i++) { | |
big += p; | |
} | |
Func inner; | |
inner(x) = x + big; | |
f(x) = inner(x - 1) + inner(x) + inner(x + 1); | |
inner.compute_at(f, x); | |
} | |
}; | |
pthread_mutex_t compiler_mutex; | |
void *separate_func_per_thread_executor(void *arg) { | |
test_func f; | |
pthread_mutex_lock(&compiler_mutex); | |
f.f.compile_jit(); | |
pthread_mutex_unlock(&compiler_mutex); | |
int32_t index = *(int32_t *)arg; | |
f.p.set(index); | |
for (int i = 0; i < 10; i++) { | |
Buffer<int32_t> result = f.f.realize(10); | |
for (int j = 0; j < 10; j++) { | |
assert(result(j) == 3 * (j + (75 * index))); | |
} | |
} | |
return nullptr; | |
} | |
void separate_func_per_thread() { | |
struct thread_info { | |
pthread_t thread; | |
int32_t arg; | |
} threads[16]; | |
for (int i = 0; i < (int)(sizeof(threads)/sizeof(threads[0])); i++) { | |
threads[i].arg = i; | |
pthread_create(&threads[i].thread, nullptr, separate_func_per_thread_executor, &threads[i].arg); | |
} | |
for (int i = 0; i < (int)(sizeof(threads)/sizeof(threads[0])); i++) { | |
pthread_join(threads[i].thread, nullptr); | |
} | |
} | |
struct thread_arg { | |
test_func *test; | |
int32_t index; | |
}; | |
void *same_func_per_thread_executor(void *arg) { | |
struct thread_arg *thread_arg = (struct thread_arg *)arg; | |
for (int i = 0; i < 10; i++) { | |
Buffer<int32_t> result = thread_arg->test->f.realize(10, get_target_from_environment(), { { thread_arg->test->p, thread_arg->index } }); | |
for (int j = 0; j < 10; j++) { | |
assert(result(j) == 3 * (j + (75 * thread_arg->index))); | |
} | |
} | |
return nullptr; | |
} | |
void same_func_per_thread() { | |
struct thread_info { | |
pthread_t thread; | |
thread_arg arg; | |
} threads[16]; | |
test_func f; | |
pthread_mutex_lock(&compiler_mutex); | |
f.f.compile_jit(); | |
pthread_mutex_unlock(&compiler_mutex); | |
for (int i = 0; i < (int)(sizeof(threads)/sizeof(threads[0])); i++) { | |
threads[i].arg.test = &f; | |
threads[i].arg.index = i; | |
pthread_create(&threads[i].thread, nullptr, same_func_per_thread_executor, &threads[i].arg); | |
} | |
for (int i = 0; i < (int)(sizeof(threads)/sizeof(threads[0])); i++) { | |
pthread_join(threads[i].thread, nullptr); | |
} | |
} | |
int main(int argc, char **argv) { | |
pthread_mutex_init(&compiler_mutex, nullptr); | |
double separate_time = benchmark(separate_func_per_thread); | |
printf("Separate compilations time: %fs.\n", separate_time); | |
double same_time = benchmark(same_func_per_thread); | |
printf("One compilation time: %fs.\n", same_time); | |
printf("Success!\n"); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment