Skip to content

Instantly share code, notes, and snippets.

@yoh2
Last active July 25, 2018 15:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yoh2/43723639fb32d5c004056b9fa6045d5c to your computer and use it in GitHub Desktop.
Save yoh2/43723639fb32d5c004056b9fa6045d5c to your computer and use it in GitHub Desktop.
並列化失敗例
use std::ops::{Index, IndexMut};
use std::thread;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
/// A minimal dense matrix of `f64` values kept row by row in one flat buffer.
///
/// Only the column count is stored explicitly; the row count is implied by
/// the buffer length.
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix {
    /// Number of elements in each row.
    nr_cols: usize,
    /// Row-major element storage; `v.len() == nr_cols * (implicit nr_rows)`.
    v: Vec<f64>,
}

impl Matrix {
    /// Creates an `nr_rows` x `nr_cols` matrix with every element set to `0.0`.
    ///
    /// # Panics
    ///
    /// Panics if `nr_rows * nr_cols` does not fit in a `usize`, instead of
    /// silently wrapping to an undersized buffer in release builds.
    pub fn new(nr_rows: usize, nr_cols: usize) -> Self {
        let len = nr_rows
            .checked_mul(nr_cols)
            .expect("matrix dimensions overflow usize");
        Self {
            nr_cols,
            v: vec![0.0; len],
        }
    }

    /// Returns the number of columns.
    pub fn nr_cols(&self) -> usize {
        self.nr_cols
    }

    /// Returns the number of rows, derived from the buffer length.
    ///
    /// A matrix with zero columns holds no elements, so its row count cannot
    /// be recovered from the buffer; it is reported as `0` rather than
    /// panicking on a division by zero.
    pub fn nr_rows(&self) -> usize {
        self.v.len().checked_div(self.nr_cols).unwrap_or(0)
    }

    /// Returns the half-open `(start, end)` range of buffer positions that
    /// row `row_index` occupies.
    fn row_range(&self, row_index: usize) -> (usize, usize) {
        let first = row_index * self.nr_cols;
        (first, first + self.nr_cols)
    }
}
/// Row access: `matrix[r]` borrows row `r` as a slice of `f64`.
impl Index<usize> for Matrix {
    type Output = [f64];

    /// Returns the elements of row `index`.
    ///
    /// Panics if the row's range falls outside the backing buffer.
    fn index(&self, index: usize) -> &[f64] {
        let (start, end) = self.row_range(index);
        let elements: &[f64] = &self.v;
        &elements[start..end]
    }
}
/// Mutable row access: `matrix[r]` borrows row `r` as a mutable slice.
impl IndexMut<usize> for Matrix {
    /// Returns the elements of row `index` for in-place modification.
    ///
    /// Panics if the row's range falls outside the backing buffer.
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        let (start, end) = self.row_range(index);
        let elements: &mut [f64] = &mut self.v;
        &mut elements[start..end]
    }
}
/// Worker-thread count used by `add_matrix`.
const NR_THREADS: usize = 4; // arbitrary choice
pub fn add_matrix<'a>(m1: &'a Matrix, m2: &'a Matrix) -> Matrix {
assert_eq!(m1.nr_rows(), m2.nr_rows());
assert_eq!(m1.nr_cols(), m2.nr_cols());
let nr_cols = m1.nr_cols();
let nr_rows = m1.nr_rows();
let mut result = Matrix::new(nr_rows, nr_cols);
// スレッド計算用
struct Context<'b> {
next_row: AtomicUsize,
// 以下、これじゃダメっぽい。m1 も m2 も巨大なものを想定していてスレッド毎に複製したくない。
m1: &'b Matrix,
m2: &'b Matrix,
result: &'b mut Matrix,
}
// こんな感じなのが作れるといいんだけどダメ
let mut context = Arc::new(Context {
next_row: AtomicUsize::new(0),
... /* ここで m1, m2, result 設定したい */
});
// で、こんな風に複数スレッド立ち上げたい
let join_handles: Vec<_> = (0..NR_THREADS)
.map(|_| context.clone())
.map(|c| thread::spawn(move || {
// スレッドは一行単位でループ
loop {
let row_index = c.next_row.fetch_add(1, Ordering::SeqCst);
if row_index >= nr_rows {
break;
}
let m1_row = c.m1[row_index];
let m2_row = c.m2[row_index];
let result_row = c.m3[row_index];
for col_index in 0..nr_cols {
result_row[col_index] = m1_row[col_index] + m2_row[col_index];
}
}
}))
.collect();
join_handles.into_iter().for_each(|h| h.join().unwrap());
result
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment