Skip to content

Instantly share code, notes, and snippets.

@ericcornelissen
Last active December 18, 2023 23:15
A Rust implementation of an algorithm that converts a linearly stored m-by-m matrix A into a linear store of the q-by-q submatrices of A.
// SPDX-License-Identifier: MIT-0
/// Converts a linearly stored (row-major) m-by-m matrix A into a linear store of
/// the q-by-q submatrices (blocks) of A.
///
/// Blocks are emitted left to right, top to bottom, and the entries of each
/// block are emitted row by row. For example, given the 4-by-4 matrix:
///
/// ```text
/// in  = [ a b c d
///         e f g h
///         i j k l
///         m n o p ]
/// ```
///
/// and q=2, the output will be:
///
/// ```text
/// out = [ a b e f
///         c d g h
///         i j m n
///         k l o p ]
/// ```
///
/// considering that the 2-by-2 blocks of the original matrix are:
///
/// ```text
/// in  = [ a b   c d
///         e f   g h
///
///         i j   k l
///         m n   o p ]
/// ```
///
/// # Panics
///
/// - If m^2 does not fit in a `usize`.
/// - If `matrix.len()` is not m^2, i.e. the input isn't an m-by-m matrix.
/// - If q is not strictly less than m (q equal to m is rejected as well).
/// - If q is 0.
/// - If q doesn't divide m.
pub fn linear_sub_matrices<T>(matrix: &[T], m: usize, q: usize) -> Vec<T>
where
    T: Clone,
{
    let matrix_size = m.checked_mul(m).expect("m^2 must fit in usize");

    // --- Prerequisites ---
    if matrix.len() != matrix_size {
        panic!("matrix must be m-by-m")
    }
    if q >= m {
        panic!("q must be less than m");
    }
    if q == 0 {
        panic!("q must be greater than 0");
    }
    if m % q != 0 {
        panic!("q must divide m");
    }

    // Derived sizes are computed only after validation: with 0 < q < m both
    // products are smaller than m^2, which is known to fit in a usize. Doing
    // this earlier would overflow (or silently wrap) for an oversized q before
    // the documented panic could be raised.
    let block_size = q * q; // number of entries in one q-by-q block
    let row_size = m * q; // number of entries in one band of q matrix rows

    // --- Algorithm ---
    // `base_index` always points at the first entry of the block row that is
    // currently being copied; `i` counts the entries written so far.
    let mut base_index = 0;
    let mut out = Vec::with_capacity(matrix_size);
    for i in 0..matrix_size {
        // copy the current entry to the output
        let current_index = base_index + i % q;
        out.push(matrix[current_index].clone());

        // update the base index (if necessary)
        let written = i + 1;
        if written % row_size == 0 {
            // the band of q matrix rows is complete: the last block row ends
            // where the next band begins, so stepping over it is enough
            base_index += q;
        } else if written % block_size == 0 {
            // the block is complete: go back up to the block's first line ...
            base_index -= (q - 1) * m;
            // ... and move over to the start of the next block
            base_index += q;
        } else if written % q == 0 {
            // the block row is complete: go down to the next line of the block
            base_index += m;
        }
    }
    out
}
// The algorithm for a 4-by-4 matrix of 2-by-2 blocks goes as follows. Let i=0
// and bi=0. We continuously increment `i` by 1. We use ci to denote the current
// index, which is always computed as bi+(i%q). In every iteration, the value at
// ci is pushed onto the output matrix. bi will be shown as `.` and ci as `*`.
//
// ```no-test
// i=0, bi=0, ci=0+0=0
//
// | 00 01 02 03 |
// .*
// in = | 04 05 06 07 |
// | 08 09 10 11 |
// | 12 13 14 15 |
// ```
//
// ```no-test
// i=1, bi=0, ci=0+1=1
//
// | 00 01 02 03 |
// . *
// in = | 04 05 06 07 |
// | 08 09 10 11 |
// | 12 13 14 15 |
// ```
//
// because (i+1)%q=0, bi=bi+m
//
// ```no-test
// i=2, bi=0+4=4, ci=4+0=4
//
// | 00 01 02 03 |
// in = | 04 05 06 07 |
// .*
// | 08 09 10 11 |
// | 12 13 14 15 |
// ```
//
// ```no-test
// i=3, bi=4, ci=4+1=5
//
// | 00 01 02 03 |
// in = | 04 05 06 07 |
// . *
// | 08 09 10 11 |
// | 12 13 14 15 |
// ```
//
// because (i+1)%(q^2)=0, bi=bi-((q-1)*m)+q
//
// ```no-test
// i=4, bi=4-4+2=2, ci=2+0=2
//
// | 00 01 02 03 |
// .*
// in = | 04 05 06 07 |
// | 08 09 10 11 |
// | 12 13 14 15 |
// ```
//
// ```no-test
// i=5, bi=2, ci=2+1=3
//
// | 00 01 02 03 |
// . *
// in = | 04 05 06 07 |
// | 08 09 10 11 |
// | 12 13 14 15 |
// ```
//
// because (i+1)%q=0, bi=bi+m
//
// ```no-test
// i=6, bi=2+4=6, ci=6+0=6
//
// | 00 01 02 03 |
// in = | 04 05 06 07 |
// .*
// | 08 09 10 11 |
// | 12 13 14 15 |
// ```
//
// ```no-test
// i=7, bi=6, ci=6+1=7
//
// | 00 01 02 03 |
// in = | 04 05 06 07 |
// . *
// | 08 09 10 11 |
// | 12 13 14 15 |
// ```
//
// because (i+1)%(m*q)=0, bi=bi+q
//
// ```no-test
// i=8, bi=6+2=8, ci=8+0=8
//
// | 00 01 02 03 |
// in = | 04 05 06 07 |
// | 08 09 10 11 |
// .*
// | 12 13 14 15 |
// ```
//
// ```no-test
// i=9, bi=8, ci=8+1=9
//
// | 00 01 02 03 |
// in = | 04 05 06 07 |
// | 08 09 10 11 |
// . *
// | 12 13 14 15 |
// ```
//
// because (i+1)%q=0, bi=bi+m
//
// ```no-test
// i=10, bi=8+4=12, ci=12+0=12
//
// | 00 01 02 03 |
// in = | 04 05 06 07 |
// | 08 09 10 11 |
// | 12 13 14 15 |
// .*
// ```
//
// ```no-test
// i=11, bi=12, ci=12+1=13
//
// | 00 01 02 03 |
// in = | 04 05 06 07 |
// | 08 09 10 11 |
// | 12 13 14 15 |
// . *
// ```
//
// because (i+1)%(q^2)=0, bi=bi-((q-1)*m)+q
//
// ```no-test
// i=12, bi=12-4+2=10, ci=10+0=10
//
// | 00 01 02 03 |
// in = | 04 05 06 07 |
// | 08 09 10 11 |
// .*
// | 12 13 14 15 |
// ```
//
// ```no-test
// i=13, bi=10, ci=10+1=11
//
// | 00 01 02 03 |
// in = | 04 05 06 07 |
// | 08 09 10 11 |
// . *
// | 12 13 14 15 |
// ```
//
// because (i+1)%q=0, bi=bi+m
//
// ```no-test
// i=14, bi=10+4=14, ci=14+0=14
//
// | 00 01 02 03 |
// in = | 04 05 06 07 |
// | 08 09 10 11 |
// | 12 13 14 15 |
// .*
// ```
//
// ```no-test
// i=15, bi=14, ci=14+1=15
//
// | 00 01 02 03 |
// in = | 04 05 06 07 |
// | 08 09 10 11 |
// | 12 13 14 15 |
// . *
// ```
//
// because (i+1)>=(m^2) the algorithm terminates.
#[cfg(test)]
// Zero-padded literals keep the matrix columns visually aligned.
#[allow(clippy::zero_prefixed_literal)]
mod tests {
    use super::linear_sub_matrices;

    #[test]
    fn matrix_4x4_block_2x2() {
        check(
            4,
            2,
            &[
                01, 02, 03, 04, //
                05, 06, 07, 08, //
                09, 10, 11, 12, //
                13, 14, 15, 16, //
            ],
            &[
                01, 02, 05, 06, //
                03, 04, 07, 08, //
                09, 10, 13, 14, //
                11, 12, 15, 16, //
            ],
        );
    }

    #[test]
    fn matrix_6x6_block_2x2() {
        check(
            6,
            2,
            &[
                01, 02, 03, 04, 05, 06, //
                07, 08, 09, 10, 11, 12, //
                13, 14, 15, 16, 17, 18, //
                19, 20, 21, 22, 23, 24, //
                25, 26, 27, 28, 29, 30, //
                31, 32, 33, 34, 35, 36, //
            ],
            &[
                01, 02, 07, 08, 03, 04, //
                09, 10, 05, 06, 11, 12, //
                13, 14, 19, 20, 15, 16, //
                21, 22, 17, 18, 23, 24, //
                25, 26, 31, 32, 27, 28, //
                33, 34, 29, 30, 35, 36, //
            ],
        );
    }

    #[test]
    fn matrix_6x6_block_3x3() {
        check(
            6,
            3,
            &[
                01, 02, 03, 04, 05, 06, //
                07, 08, 09, 10, 11, 12, //
                13, 14, 15, 16, 17, 18, //
                19, 20, 21, 22, 23, 24, //
                25, 26, 27, 28, 29, 30, //
                31, 32, 33, 34, 35, 36, //
            ],
            &[
                01, 02, 03, 07, 08, 09, //
                13, 14, 15, 04, 05, 06, //
                10, 11, 12, 16, 17, 18, //
                19, 20, 21, 25, 26, 27, //
                31, 32, 33, 22, 23, 24, //
                28, 29, 30, 34, 35, 36, //
            ],
        );
    }

    /// With 1-by-1 blocks every entry is its own block, so the order is kept.
    #[test]
    fn matrix_2x2_block_1x1() {
        check(
            2,
            1,
            &[
                01, 02, //
                03, 04, //
            ],
            &[
                01, 02, //
                03, 04, //
            ],
        );
    }

    #[test]
    #[should_panic(expected = "matrix must be m-by-m")]
    fn panics_when_matrix_is_not_m_by_m() {
        let _ = linear_sub_matrices(&[01, 02, 03], 2, 1);
    }

    #[test]
    #[should_panic(expected = "q must be less than m")]
    fn panics_when_q_equals_m() {
        let _ = linear_sub_matrices(&[01, 02, 03, 04], 2, 2);
    }

    #[test]
    #[should_panic(expected = "q must be greater than 0")]
    fn panics_when_q_is_zero() {
        let _ = linear_sub_matrices(&[01, 02, 03, 04], 2, 0);
    }

    #[test]
    #[should_panic(expected = "q must divide m")]
    fn panics_when_q_does_not_divide_m() {
        let _ = linear_sub_matrices(&[01, 02, 03, 04, 05, 06, 07, 08, 09], 3, 2);
    }

    /// Runs `linear_sub_matrices` on `matrix` (an m-by-m matrix stored
    /// linearly) with block size `q` and asserts the result equals `want`.
    fn check(m: usize, q: usize, matrix: &[usize], want: &[usize]) {
        let got = linear_sub_matrices(matrix, m, q);
        assert_eq!(got, want);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment