Skip to content

Instantly share code, notes, and snippets.

@kkonevets
Last active February 17, 2022 21:49
Show Gist options
  • Save kkonevets/90ebde3f4a545288f3d8c8cc7bf400fb to your computer and use it in GitHub Desktop.
Save kkonevets/90ebde3f4a545288f3d8c8cc7bf400fb to your computer and use it in GitHub Desktop.
use std::error::Error;
use std::sync::mpsc;
use std::sync::Arc;
use std::thread;
/// Upper bound on the number of worker threads spawned by [`apply_par`].
const NTHREADS: usize = 4;

/// Applies `f` to every element of `v` in parallel, preserving input order.
///
/// The input is split into at most `NTHREADS` contiguous chunks. Each chunk is
/// mapped on its own thread and the per-chunk results are concatenated in
/// chunk order, so `result[i] == f(&v[i])` for every index `i`.
///
/// If `v` is empty or its length is less than `threshold`, no splitting occurs
/// and no threads are created: the mapping runs on the calling thread.
///
/// The `Default` bound on `R` is not needed by this implementation; it is kept
/// so the signature stays compatible with existing callers.
///
/// # Errors
///
/// Returns an error if `f` panics on a worker thread. Every worker is joined
/// before the error is returned, so no thread outlives the call.
///
/// # Panics
///
/// On the sequential path (empty input or `v.len() < threshold`) a panic in
/// `f` propagates directly to the caller.
fn apply_par<F, T, R>(v: Vec<T>, f: F, threshold: usize) -> Result<Vec<R>, Box<dyn Error>>
where
    T: Sync + Send + 'static,
    R: Send + Default + 'static,
    F: Fn(&T) -> R,
    F: Send + Copy + 'static,
{
    let len = v.len();
    if len == 0 || len < threshold {
        return Ok(v.iter().map(f).collect());
    }

    // Ceiling division: produces at most NTHREADS chunks and, because
    // `len >= 1` on this path, a chunk size of at least 1.
    let chunk_size = (len + NTHREADS - 1) / NTHREADS;

    // The vector is shared read-only; each worker maps its own index range.
    let shared = Arc::new(v);
    let mut workers = Vec::with_capacity(NTHREADS);
    for start in (0..len).step_by(chunk_size) {
        let data = Arc::clone(&shared);
        let end = usize::min(start + chunk_size, len);
        workers.push(thread::spawn(move || -> Vec<R> {
            data[start..end].iter().map(f).collect()
        }));
    }

    // Joining in spawn order keeps the output aligned with the input.
    // Keep joining after a failure so no worker is left detached.
    let mut result = Vec::with_capacity(len);
    let mut worker_panicked = false;
    for worker in workers {
        match worker.join() {
            Ok(part) => result.extend(part),
            Err(_) => worker_panicked = true,
        }
    }

    if worker_panicked {
        return Err("a worker thread panicked while applying `f`".into());
    }
    Ok(result)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Mapping used by every test: renders an integer as its decimal string.
    fn stringify(n: &i32) -> String {
        n.to_string()
    }

    /// Runs `apply_par` over `input` with `stringify` and the given
    /// `threshold`, then checks the output against `expected` in order.
    fn assert_mapped(input: Vec<i32>, threshold: usize, expected: &[&str]) {
        let actual = apply_par(input, stringify, threshold).unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn three_items_with_threshold_one() {
        assert_mapped(vec![1, 2, 3], 1, &["1", "2", "3"]);
    }

    #[test]
    fn three_items_with_threshold_zero() {
        assert_mapped(vec![1, 2, 3], 0, &["1", "2", "3"]);
    }

    #[test]
    fn empty_input_yields_empty_output() {
        let mapped = apply_par(Vec::new(), stringify, 1).unwrap();
        assert!(mapped.is_empty());
    }

    #[test]
    fn threshold_larger_than_input() {
        assert_mapped(vec![1, 2, 3], 10, &["1", "2", "3"]);
    }

    #[test]
    fn four_items_with_threshold_two() {
        assert_mapped(vec![1, 2, 3, 4], 2, &["1", "2", "3", "4"]);
    }

    #[test]
    fn five_items_including_duplicates() {
        assert_mapped(vec![4, 2, 2, 1, 5], 2, &["4", "2", "2", "1", "5"]);
    }

    #[test]
    fn nine_items_with_threshold_two() {
        assert_mapped(
            vec![1, 2, 3, 4, 5, 6, 7, 8, 9],
            2,
            &["1", "2", "3", "4", "5", "6", "7", "8", "9"],
        );
    }
}
// Empty entry point: running the program does nothing.
// NOTE(review): presumably present only so the file builds as a standalone
// binary; the logic in this file is exercised through its `#[test]` functions.
fn main() {}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment