Skip to content

Instantly share code, notes, and snippets.

@rust-play
Created October 18, 2019 21:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rust-play/a489d6374fd82aff6e9681a736868718 to your computer and use it in GitHub Desktop.
Save rust-play/a489d6374fd82aff6e9681a736868718 to your computer and use it in GitHub Desktop.
Code shared from the Rust Playground
use std::ops::Mul;
/// A dense, row-major matrix with a fixed number of columns.
///
/// Element `(row, column)` is stored at `data[row * width + column]`; the number of
/// rows is derived as `data.len() / width`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Matrix<T> {
    /// Row-major element storage.
    data: Box<[T]>,
    /// Number of columns.
    width: usize,
}
impl<T: Copy + Default> Matrix<T> {
pub fn new(data: impl IntoIterator<Item = T>, width: usize) -> Self {
let data: Box<[T]> = data.into_iter().collect();
assert!(data.len() % width == 0);
Matrix { data, width }
}
pub fn height(&self) -> usize {
self.data.len() / self.width
}
pub fn rows(&self) -> MatrixRows<T> {
MatrixRows {
matrix: &self,
current: 0,
}
}
pub fn windows(&self, width: usize, height: usize) -> MatrixWindows<T> {
MatrixWindows {
matrix: &self,
current_column: 0,
current_row: 0,
width,
height,
}
}
pub fn get_row(&self, index: usize) -> Option<Box<[T]>> {
let data = &self.data;
if index < self.height() {
Some(
(index * self.width..(index + 1) * self.width)
.map(|i| data[i])
.collect(),
)
} else {
None
}
}
pub fn get_window(
&self,
row_index: usize,
column_index: usize,
width: usize,
height: usize,
) -> Option<Matrix<T>> {
let data = &self.data;
if row_index + height <= self.height() && column_index + width <= self.width {
let mut entries = Vec::with_capacity(width * height);
for row in row_index..row_index + height {
let current_row =
&data[row * self.width + column_index..row * self.width + column_index + width];
entries.extend(current_row.iter());
}
assert_eq!(entries.len(), width * height);
Some(Matrix::new(entries, width))
} else {
None
}
}
pub fn padded(&self, padding: usize) -> Matrix<T> {
let mut data = Vec::new();
let width = &self.width + 2 * padding;
data.extend(vec![T::default(); width * padding]);
for row in self.rows() {
data.extend(vec![T::default(); padding]);
data.extend(row.iter());
data.extend(vec![T::default(); padding]);
}
data.extend(vec![T::default(); width * padding]);
Matrix::new(data, width)
}
pub fn convolution(&self, &conv_map: &Matrix<T>) -> Matrix<T> {
assert_eq!(conv_map.height(), conv_map.width);
let data = Vec::with_capacity(self.height() * self.width);
let matrix = self.padded((conv_map.height() - 1) / 2);
for cell in matrix.windows(conv_map.width, conv_map.height()) {
let conv_mat = (cell * &conv_map).data.sum();
}
Matrix::new(vec![], 0)
}
}
/// Element-wise (Hadamard) product of two borrowed matrices with identical
/// dimensions.
///
/// This is not matrix multiplication: entry `(r, c)` of the result is
/// `self[(r, c)] * rhs[(r, c)]`.
///
/// # Panics
///
/// Panics if `self` and `rhs` differ in width or height.
impl<'a, 'b, T: Mul<Output = T> + Copy + Default> Mul<&'b Matrix<T>> for &'a Matrix<T> {
    type Output = Matrix<T>;
    fn mul(self, rhs: &Matrix<T>) -> Self::Output {
        // `zip` would silently truncate to the shorter operand and produce a matrix
        // whose shape matches neither input, so reject mismatched shapes up front.
        assert!(
            self.width == rhs.width && self.data.len() == rhs.data.len(),
            "element-wise product needs equal dimensions, got {}x{} and {}x{}",
            self.height(),
            self.width,
            rhs.height(),
            rhs.width
        );
        Matrix::new(
            self.data
                .iter()
                .zip(rhs.data.iter())
                .map(|(&u, &v)| u * v),
            self.width,
        )
    }
}
/// Element-wise product with an owned left-hand side.
///
/// Simply borrows `self` and forwards to the `&Matrix<T> * &Matrix<T>`
/// implementation, so the same dimension rules apply.
impl<'b, T: Mul<Output = T> + Copy + Default> Mul<&'b Matrix<T>> for Matrix<T> {
    type Output = Matrix<T>;
    fn mul(self, rhs: &Matrix<T>) -> Self::Output {
        let lhs: &Matrix<T> = &self;
        lhs * rhs
    }
}
/// Iterator over the rows of a [`Matrix`], created by `Matrix::rows`.
///
/// Each item is an owned copy of one row.
#[derive(Clone)]
pub struct MatrixRows<'a, T> {
    /// The matrix whose rows are being yielded.
    matrix: &'a Matrix<T>,
    /// Index of the next row to yield.
    current: usize,
}
impl<'a, T: Copy + Default> Iterator for MatrixRows<'a, T> {
    type Item = Box<[T]>;

    /// Yields a copy of the row at the cursor and moves the cursor down by one.
    ///
    /// Returns `None` once the cursor is past the last row.
    fn next(&mut self) -> Option<Self::Item> {
        let index = self.current;
        self.current = index + 1;
        self.matrix.get_row(index)
    }
}
/// Iterator over fixed-size sub-matrices ("windows") of a [`Matrix`], created by
/// `Matrix::windows`.
///
/// The cursor (`current_row`, `current_column`) is the top-left corner of the next
/// window to yield.
pub struct MatrixWindows<'a, T> {
    /// The matrix being scanned.
    matrix: &'a Matrix<T>,
    /// Row of the next window's top-left corner.
    current_row: usize,
    /// Column of the next window's top-left corner.
    current_column: usize,
    /// Number of columns in each window.
    width: usize,
    /// Number of rows in each window.
    height: usize,
}
impl<'a, T: Copy + Default> Iterator for MatrixWindows<'a, T> {
    type Item = Matrix<T>;

    /// Returns the window at the cursor, then advances the cursor one column to the
    /// right.
    ///
    /// The cursor wraps to column 0 of the next row once a window would run past the
    /// right edge. Returns `None` when the window at the cursor does not fit inside
    /// the matrix, which includes every call after the last row of windows.
    fn next(&mut self) -> Option<Self::Item> {
        let window = self.matrix.get_window(
            self.current_row,
            self.current_column,
            self.width,
            self.height,
        );
        self.current_column += 1;
        // Written as `column + width > matrix width` rather than
        // `column > matrix width - width`: the subtraction underflows (and panics in
        // debug builds) when the window is wider than the matrix.
        if self.current_column.saturating_add(self.width) > self.matrix.width {
            self.current_column = 0;
            self.current_row += 1;
        }
        window
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment