Skip to content

Instantly share code, notes, and snippets.

@siddontang
Created December 13, 2017 05:40
Show Gist options
  • Save siddontang/fa81a59e7234e9960f8514785af69e1e to your computer and use it in GitHub Desktop.
Save siddontang/fa81a59e7234e9960f8514785af69e1e to your computer and use it in GitHub Desktop.
#![feature(fnbox)]
use std::usize;
use std::time::{Duration, Instant};
use std::sync::{Arc, Condvar, Mutex};
use std::thread::{Builder, JoinHandle};
use std::marker::PhantomData;
use std::boxed::FnBox;
use std::collections::VecDeque;
use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};
use std::fmt::Write;
use std::sync::mpsc::{channel, Sender};
use std::thread;
use std::env;
/// Default number of tasks a worker runs between two `Context::on_tick` calls.
pub const DEFAULT_TASKS_PER_TICK: usize = 10_000;
/// Initial capacity of a freshly created (or reset) task queue.
const DEFAULT_QUEUE_CAPACITY: usize = 1_000;
/// Worker-thread count used when the builder is not told otherwise.
const DEFAULT_THREAD_COUNT: usize = 1;
/// Seconds an idle worker waits for work before it calls `on_tick` and then
/// sleeps without a timeout.
const NAP_SECS: u64 = 1;
/// When a drained queue's buffer has grown beyond this capacity, it is
/// replaced by a new `DEFAULT_QUEUE_CAPACITY` buffer to release the memory.
const QUEUE_MAX_CAPACITY: usize = DEFAULT_QUEUE_CAPACITY * 8;
/// Per-worker state whose hooks the pool invokes around task execution.
///
/// Every hook has an empty default body, so an implementor only overrides the
/// hooks it needs. Each worker thread owns its own value of this type.
pub trait Context: Send {
    /// Invoked by a worker immediately before it runs a task.
    fn on_task_started(&mut self) {}

    /// Invoked by a worker immediately after a task returns.
    fn on_task_finished(&mut self) {}

    /// Invoked after a batch of tasks and when the worker becomes idle.
    fn on_tick(&mut self) {}
}
/// A stateless `Context` whose hooks all keep the trait's no-op defaults.
#[derive(Default)]
pub struct DefaultContext;

impl Context for DefaultContext {}
/// Builds the context value handed to each newly spawned worker thread.
pub trait ContextFactory<Ctx>
where
    Ctx: Context,
{
    /// Returns a fresh context; called once per worker thread.
    fn create(&self) -> Ctx;
}
/// Factory that creates every worker's context via `Default::default()`.
pub struct DefaultContextFactory;

impl<C> ContextFactory<C> for DefaultContextFactory
where
    C: Context + Default,
{
    fn create(&self) -> C {
        Default::default()
    }
}
/// A unit of queued work: a boxed, run-once closure that receives the
/// executing worker's context.
pub struct Task<C> {
    task: Box<FnBox(&mut C) + Send>,
}

impl<C: Context> Task<C> {
    /// Boxes `job` so it can wait in the queue until a worker runs it.
    fn new<F>(job: F) -> Task<C>
    where
        F: FnOnce(&mut C) + Send + 'static,
    {
        let boxed: Box<FnBox(&mut C) + Send> = Box::new(job);
        Task { task: boxed }
    }
}
/// First-in, first-out task queue.
///
/// If a burst grew the buffer past `QUEUE_MAX_CAPACITY`, the queue swaps in a
/// fresh `DEFAULT_QUEUE_CAPACITY` buffer as soon as it drains, so that memory
/// is given back.
pub struct FifoQueue<C> {
    queue: VecDeque<Task<C>>,
}

impl<C: Context> FifoQueue<C> {
    /// Creates an empty queue preallocated for `DEFAULT_QUEUE_CAPACITY` tasks.
    fn new() -> FifoQueue<C> {
        let storage = VecDeque::with_capacity(DEFAULT_QUEUE_CAPACITY);
        FifoQueue { queue: storage }
    }

    /// Appends `task` at the back of the queue.
    fn push(&mut self, task: Task<C>) {
        self.queue.push_back(task)
    }

    /// Removes and returns the oldest task, or `None` when the queue is empty.
    fn pop(&mut self) -> Option<Task<C>> {
        let front = self.queue.pop_front();
        let drained = self.queue.is_empty();
        if drained && self.queue.capacity() > QUEUE_MAX_CAPACITY {
            // Drop the oversized buffer; the replacement starts small again.
            self.queue = VecDeque::with_capacity(DEFAULT_QUEUE_CAPACITY);
        }
        front
    }
}
/// Configures and constructs a `ThreadPool`.
///
/// `C` is the per-worker context type; `F` is the factory that produces one
/// `C` for each worker thread.
pub struct ThreadPoolBuilder<C, F> {
    /// Name given to every worker thread.
    name: String,
    /// Number of worker threads to spawn.
    thread_count: usize,
    /// Tasks a worker runs between two `on_tick` calls.
    tasks_per_tick: usize,
    /// Stack size for worker threads; `None` leaves std's default in place.
    stack_size: Option<usize>,
    /// Produces each worker's context.
    factory: F,
    /// Marks the otherwise unused `C` parameter.
    _ctx: PhantomData<C>,
}
impl<C> ThreadPoolBuilder<C, DefaultContextFactory>
where
    C: Context + Default + 'static,
{
    /// Starts a builder whose workers each receive `C::default()` as context.
    pub fn with_default_factory(name: String) -> ThreadPoolBuilder<C, DefaultContextFactory> {
        Self::new(name, DefaultContextFactory)
    }
}
impl<C, F> ThreadPoolBuilder<C, F>
where
    C: Context + 'static,
    F: ContextFactory<C>,
{
    /// Starts a builder with the default thread count, tasks-per-tick and
    /// stack size, using `factory` to create worker contexts.
    pub fn new(name: String, factory: F) -> ThreadPoolBuilder<C, F> {
        ThreadPoolBuilder {
            name,
            factory,
            thread_count: DEFAULT_THREAD_COUNT,
            tasks_per_tick: DEFAULT_TASKS_PER_TICK,
            stack_size: None,
            _ctx: PhantomData,
        }
    }

    /// Sets how many worker threads the pool spawns.
    pub fn thread_count(self, count: usize) -> ThreadPoolBuilder<C, F> {
        ThreadPoolBuilder {
            thread_count: count,
            ..self
        }
    }

    /// Sets how many tasks a worker runs between two `on_tick` calls.
    pub fn tasks_per_tick(self, count: usize) -> ThreadPoolBuilder<C, F> {
        ThreadPoolBuilder {
            tasks_per_tick: count,
            ..self
        }
    }

    /// Sets the stack size of every worker thread.
    pub fn stack_size(self, size: usize) -> ThreadPoolBuilder<C, F> {
        ThreadPoolBuilder {
            stack_size: Some(size),
            ..self
        }
    }

    /// Consumes the builder and spawns the configured pool.
    pub fn build(self) -> ThreadPool<C> {
        let ThreadPoolBuilder {
            name,
            thread_count,
            tasks_per_tick,
            stack_size,
            factory,
            ..
        } = self;
        ThreadPool::new(name, thread_count, tasks_per_tick, stack_size, factory)
    }
}
/// Scheduler state shared by the pool handle and its workers; it lives inside
/// the pool's `Mutex` and is only touched through that lock.
struct ScheduleState<Ctx> {
    /// Tasks waiting for a worker.
    queue: FifoQueue<Ctx>,
    /// Set by `ThreadPool::stop`. Once true, `execute` drops new tasks and
    /// workers exit instead of taking more work.
    stopped: bool,
}
/// `ThreadPool` runs submitted tasks in parallel on a fixed set of worker
/// threads.
///
/// Tasks are pushed onto a single shared first-in, first-out queue. Whenever
/// a worker is free, it takes the oldest queued task and runs it with that
/// worker's own `Context`.
pub struct ThreadPool<Ctx> {
    /// Scheduler state plus the condition variable that idle workers wait on.
    state: Arc<(Mutex<ScheduleState<Ctx>>, Condvar)>,
    /// Join handles of the worker threads; drained by `stop`.
    threads: Vec<JoinHandle<()>>,
    /// Number of tasks that have been submitted but have not yet finished.
    task_count: Arc<AtomicUsize>,
}
impl<Ctx> ThreadPool<Ctx>
where
    Ctx: Context + 'static,
{
    /// Creates the pool and immediately spawns its workers.
    ///
    /// Spawns `num_threads` threads. All of them are named `name`, and when
    /// `stack_size` is `Some` they are created with that stack size. Each
    /// thread gets its own context from `f.create()` and runs a `Worker` that
    /// takes tasks from the shared queue, ticking every `tasks_per_tick` tasks.
    ///
    /// # Panics
    ///
    /// Panics if `num_threads` is zero or if a worker thread fails to spawn.
    fn new<C: ContextFactory<Ctx>>(
        name: String,
        num_threads: usize,
        tasks_per_tick: usize,
        stack_size: Option<usize>,
        f: C,
    ) -> ThreadPool<Ctx> {
        assert!(num_threads >= 1);
        let state = ScheduleState {
            queue: FifoQueue::new(),
            stopped: false,
        };
        let state = Arc::new((Mutex::new(state), Condvar::new()));
        let mut threads = Vec::with_capacity(num_threads);
        let task_count = Arc::new(AtomicUsize::new(0));
        // Threadpool threads
        for _ in 0..num_threads {
            let state = state.clone();
            let task_num = task_count.clone();
            let ctx = f.create();
            let mut tb = Builder::new().name(name.clone());
            if let Some(stack_size) = stack_size {
                tb = tb.stack_size(stack_size);
            }
            let thread = tb.spawn(move || {
                let mut worker = Worker::new(state, task_num, tasks_per_tick, ctx);
                worker.run();
            }).unwrap();
            threads.push(thread);
        }
        ThreadPool {
            state: state,
            threads: threads,
            task_count: task_count,
        }
    }

    /// Queues `job` to be run by some worker with that worker's context.
    ///
    /// If the pool has already been stopped, `job` is silently dropped.
    pub fn execute<F>(&self, job: F)
    where
        F: FnOnce(&mut Ctx) + Send + 'static,
        Ctx: Context,
    {
        let task = Task::new(job);
        let &(ref lock, ref cvar) = &*self.state;
        let mut state = lock.lock().unwrap();
        if state.stopped {
            return;
        }
        // Count the task while still holding the lock and before it becomes
        // visible in the queue. A worker can only pop it after this lock is
        // released, so the worker's matching `fetch_sub` can never run before
        // this increment. Incrementing after unlocking, as before, let the
        // counter transiently wrap below zero.
        self.task_count.fetch_add(1, AtomicOrdering::SeqCst);
        state.queue.push(task);
        cvar.notify_one();
    }

    /// Returns how many submitted tasks are still queued or running.
    #[inline]
    pub fn get_task_count(&self) -> usize {
        self.task_count.load(AtomicOrdering::SeqCst)
    }

    /// Marks the pool as stopped, wakes every worker and joins them.
    ///
    /// Workers check the stop flag before taking a task, so tasks still queued
    /// are not run. Any thread whose join fails (for example because it
    /// panicked) is reported in the returned `Err` string.
    pub fn stop(&mut self) -> Result<(), String> {
        let &(ref lock, ref cvar) = &*self.state;
        {
            let mut state = lock.lock().unwrap();
            state.stopped = true;
            cvar.notify_all();
        }
        let mut err_msg = String::new();
        for t in self.threads.drain(..) {
            if let Err(e) = t.join() {
                write!(&mut err_msg, "Failed to join thread with err: {:?};", e).unwrap();
            }
        }
        if !err_msg.is_empty() {
            return Err(err_msg);
        }
        Ok(())
    }
}
/// Per-thread worker: repeatedly takes tasks from the shared queue and runs
/// them with its own context. Each pool thread owns exactly one.
struct Worker<C> {
    /// Scheduler state and condition variable shared with the pool.
    state: Arc<(Mutex<ScheduleState<C>>, Condvar)>,
    /// Pool-wide count of unfinished tasks; decremented after each task.
    task_count: Arc<AtomicUsize>,
    /// Number of tasks to take between two `on_tick` calls.
    tasks_per_tick: usize,
    /// Tasks taken since the last tick.
    task_counter: usize,
    /// This worker's private context.
    ctx: C,
}
impl<C> Worker<C>
where
    C: Context,
{
    /// Creates a worker that has not taken any task yet.
    fn new(
        state: Arc<(Mutex<ScheduleState<C>>, Condvar)>,
        task_count: Arc<AtomicUsize>,
        tasks_per_tick: usize,
        ctx: C,
    ) -> Worker<C> {
        Worker {
            state: state,
            task_count: task_count,
            tasks_per_tick: tasks_per_tick,
            task_counter: 0,
            ctx: ctx,
        }
    }

    /// Blocks until a task is available, or returns `None` once the pool's
    /// `stopped` flag is set (the flag is checked before every pop).
    ///
    /// While the queue is empty, the worker first naps for up to `NAP_SECS`.
    /// If there is still nothing to do, it resets its batch counter, calls
    /// `on_tick` and then waits without a timeout. It ticks again before each
    /// further indefinite wait.
    fn next_task(&mut self) -> Option<Task<C>> {
        let &(ref lock, ref cvar) = &*self.state;
        let mut state = lock.lock().unwrap();
        let mut napped = false;
        let mut tick_due = true;
        loop {
            if state.stopped {
                return None;
            }
            if let Some(t) = state.queue.pop() {
                self.task_counter += 1;
                return Some(t);
            }
            if !napped {
                // A short nap first: work arriving during it is taken without a tick.
                napped = true;
                state = cvar
                    .wait_timeout(state, Duration::from_secs(NAP_SECS))
                    .unwrap()
                    .0;
            } else if tick_due {
                // Run the user's `on_tick` with the scheduler lock released, so a
                // slow hook cannot stall `execute` or the other workers. The loop
                // then re-checks the stop flag and the queue, because a
                // `notify_one` sent while the lock was released was not waited on
                // by this worker.
                tick_due = false;
                self.task_counter = 0;
                drop(state);
                self.ctx.on_tick();
                state = lock.lock().unwrap();
            } else {
                state = cvar.wait(state).unwrap();
                tick_due = true;
            }
        }
    }

    /// Runs tasks until `next_task` reports that the pool has stopped.
    ///
    /// Each task is bracketed by `on_task_started` and `on_task_finished`.
    /// `on_tick` fires whenever `tasks_per_tick` tasks have been taken since the
    /// last tick; with `tasks_per_tick == 0` only the idle path ticks.
    fn run(&mut self) {
        loop {
            let task = match self.next_task() {
                None => return,
                Some(t) => t,
            };
            self.ctx.on_task_started();
            (task.task).call_box((&mut self.ctx,));
            self.ctx.on_task_finished();
            self.task_count.fetch_sub(1, AtomicOrdering::SeqCst);
            if self.task_counter == self.tasks_per_tick {
                self.task_counter = 0;
                self.ctx.on_tick();
            }
        }
    }
}
/// Benchmark driver: floods a pool with tiny tasks for ten seconds and prints
/// the observed task throughput about once per second.
///
/// Usage: `<binary> [THREAD_COUNT]`. Without an argument, 8 worker threads
/// are used. An argument that is not a positive integer makes the program exit
/// with status 2 and an error message.
fn main() {
    let args: Vec<String> = env::args().collect();
    let thread_count: usize = if args.len() == 2 {
        match args[1].parse::<usize>() {
            Ok(n) if n >= 1 => n,
            _ => {
                eprintln!(
                    "invalid thread count {:?}: expected a positive integer",
                    args[1]
                );
                std::process::exit(2);
            }
        }
    } else {
        8
    };
    println!("Using {} threads", thread_count);
    let mut task_pool = ThreadPoolBuilder::with_default_factory("test".to_owned())
        .thread_count(thread_count)
        .build();
    let (tx, rx) = channel();
    // Reporter thread: each finished task sends one message; print how many
    // arrived since the previous report once at least a second has passed.
    thread::spawn(move || {
        let mut total_counts = 0;
        let mut t = Instant::now();
        let interval = Duration::from_secs(1);
        while rx.recv().is_ok() {
            total_counts += 1;
            if t.elapsed() >= interval {
                t = Instant::now();
                println!("{} QPS", total_counts);
                total_counts = 0;
            }
        }
    });
    let start = Instant::now();
    let run_time = Duration::from_secs(10);
    loop {
        let tx1 = tx.clone();
        task_pool.execute(move |_: &mut DefaultContext| {
            tx1.send(0).unwrap();
        });
        if start.elapsed() >= run_time {
            break;
        }
    }
    // Join the workers instead of abandoning them at process exit. Tasks still
    // queued at this point are discarded by `stop`.
    if let Err(e) = task_pool.stop() {
        eprintln!("{}", e);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment