-
-
Save ACUVE/413919d971951ca47b0d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <functional>
#include <iostream>
#include <limits>
#include <mutex>
#include <thread>
#include <utility>
#include <vector>

#include <omp.h>
/// Measures how long a single call `func( args... )` takes.
///
/// Uses std::chrono::steady_clock, which is guaranteed to be monotonic, so the
/// measured interval can never be negative or distorted by wall-clock
/// adjustments (high_resolution_clock may be an alias of the adjustable
/// system_clock).
///
/// @param func  callable invoked exactly once
/// @param args  arguments perfectly forwarded to `func`
/// @return      elapsed time of the call, in nanoseconds
/// Any value returned by `func` is discarded; exceptions from `func` propagate.
template< typename FUNC, typename... Args >
inline std::chrono::nanoseconds time( FUNC &&func, Args &&... args )
{
    using steady = std::chrono::steady_clock;
    auto const start = steady::now();
    std::forward< FUNC >( func )( std::forward< Args >( args )... );
    auto const end = steady::now();
    return std::chrono::duration_cast< std::chrono::nanoseconds >( end - start );
}
/// Minimal fork/join thread pool.
///
/// Do( f ) runs f( index, count ) once on every worker thread and returns only
/// after every worker has finished and is idle again. The pool starts one
/// worker per hardware thread reported by std::thread::hardware_concurrency()
/// (at least one).
///
/// Synchronisation: all shared state (func, counter, waiting_counter) is read
/// and written only while `mutex` is held. This rules out lost wake-ups and the
/// race between Do() resetting the idle count and a worker re-incrementing it.
/// The task itself runs with the mutex released, so workers really execute in
/// parallel.
///
/// NOTE(review): Do() resets the idle count, so it must not be called
/// concurrently from several threads. Call it from one controlling thread at
/// a time. An exception escaping a task terminates the program (it escapes a
/// std::thread).
class ThreadPool
{
public:
    /// Task signature: `num` is the worker index in [0, thread_num),
    /// `thread_num` is the number of workers in the pool.
    using func_type = void (*)( unsigned int num, unsigned int thread_num );

private:
    std::mutex mutex;
    std::condition_variable cond;          // workers wait here for a new generation
    std::condition_variable waiting_cond;  // ctor / Do() wait here until all workers are idle
    func_type func;                        // task of the current generation; nullptr = shut down
    unsigned int waiting_counter;          // number of workers idle in the current generation
    unsigned int counter;                  // generation number, bumped by Do() and the destructor
    unsigned int const num_thread;
    std::vector< std::thread > thread;

    /// hardware_concurrency() returns 0 when the value is not computable; fall
    /// back to a single worker so Do() still executes the task.
    static unsigned int DefaultThreadCount()
    {
        unsigned int const hw = std::thread::hardware_concurrency();
        return hw != 0 ? hw : 1;
    }

    /// Blocks until every worker is idle. `lock` must own `mutex`; it is
    /// released while waiting and re-acquired before returning.
    void WaitAllWaiting( std::unique_lock< std::mutex > &lock )
    {
        waiting_cond.wait( lock, [ this ]{ return waiting_counter == num_thread; } );
    }

public:
    ThreadPool()
        : func( nullptr )
        , waiting_counter( 0 )
        , counter( 0 )
        , num_thread( DefaultThreadCount() )
    {
        thread.reserve( num_thread );
        for( unsigned int i = 0; i < num_thread; ++i )
            thread.emplace_back( &ThreadPool::Thread, this, i );
        // Return only once every worker has reached its wait, so the first
        // Do() cannot race with thread start-up.
        std::unique_lock< std::mutex > lock( mutex );
        WaitAllWaiting( lock );
    }

    ThreadPool( ThreadPool const & ) = delete;
    ThreadPool &operator=( ThreadPool const & ) = delete;

    ~ThreadPool()
    {
        {
            std::lock_guard< std::mutex > lock( mutex );
            func = nullptr;  // a generation with no task tells workers to exit
            ++counter;
        }
        cond.notify_all();
        for( auto &th : thread )
            th.join();
    }

    /// Worker loop for worker `num`; started by the constructor.
    void Thread( unsigned int const num )
    {
        std::unique_lock< std::mutex > lock( mutex );
        unsigned int seen = counter;
        while( true )
        {
            ++waiting_counter;
            waiting_cond.notify_all();
            cond.wait( lock, [ this, seen ]{ return counter != seen; } );
            seen = counter;
            func_type const f = func;
            if( f == nullptr )
                break;
            // Run the task without the mutex so all workers execute concurrently.
            lock.unlock();
            f( num, num_thread );
            lock.lock();
        }
    }

    /// Runs `f( index, num_thread )` on every worker and waits for all of them
    /// to finish. A null `f` is ignored.
    void Do( func_type f )
    {
        if( f == nullptr )
            return;
        std::unique_lock< std::mutex > lock( mutex );
        func = f;
        // Reset the idle count *before* publishing the new generation, both
        // under the mutex, so no worker's "done" increment can be lost.
        waiting_counter = 0;
        ++counter;
        cond.notify_all();
        WaitAllWaiting( lock );
    }
};
// Process-wide pool used by main(). As a namespace-scope object it is
// constructed during dynamic initialisation (normally before main runs), which
// starts its worker threads, and destroyed after main returns.
ThreadPool pool{};
// No-op task with the ThreadPool::func_type signature. main() dispatches it so
// that the measured time is purely the pool's dispatch/join overhead. The
// parameter names are commented out because the body ignores them.
void aaa( unsigned int const /*num*/, unsigned int const /*thread_num*/ )
{
    // intentionally empty
}
/// Prints 0 .. val-1, one number per line, with the iterations distributed
/// across OpenMP threads. Nothing is printed when val <= 0.
///
/// With OpenMP enabled the lines may appear in any order. Without OpenMP
/// support (e.g. no -fopenmp) the pragmas are ignored and the loop runs
/// sequentially in ascending order.
void openmp( int val )
{
    // The directive must be spelled exactly: an unknown `omp` pragma is ignored
    // (at most a -Wunknown-pragmas warning) and the loop would silently run
    // serially.
    #pragma omp parallel for
    for( int i = 0; i < val; ++i )
    {
        // Serialise the write so concurrent iterations cannot interleave the
        // digits of one number with another iteration's line break.
        #pragma omp critical
        std::cout << i << std::endl;
    }
}
int main( int argc, char **argv ) | |
{ | |
for( unsigned int i = 0; i < 100; ++i ) | |
std::cout << time( std::bind( &ThreadPool::Do, &pool, aaa ) ).count() << std::endl; | |
std::cout << time( openmp, argc ).count() << std::endl; | |
std::cout << time( openmp, argc ).count() << std::endl; | |
std::cout << time( openmp, argc ).count() << std::endl; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment