Skip to content

Instantly share code, notes, and snippets.

@pythonesque
Forked from omaskery/matrix.rs
Created December 13, 2014 23:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pythonesque/1edb71e7edcafdd016d4 to your computer and use it in GitHub Desktop.
use std::default::Default;
use std::num::Zero;
use std::slice::MutItems;
/// A dense `rows` x `columns` matrix.
///
/// Elements live in one flat vector laid out row by row: the cell at
/// `(r, c)` is stored at offset `columns * r + c`.
pub struct Matrix<T> {
    // Number of rows.
    rows: uint,
    // Number of columns (also the stride between consecutive rows).
    columns: uint,
    // Flat row-major element storage.
    values: Vec<T>,
}
/// Cursor over the `(row, column)` coordinates of a `rows` x `columns` grid.
///
/// `r` and `c` hold the coordinate that will be yielded next.
struct BaseMatrixIter {
    r: uint,
    c: uint,
    rows: uint,
    columns: uint,
}
/// Borrowing iterator returned by `Matrix::iter`.
///
/// `r` and `c` identify the next cell of `m` to hand out.
struct MatrixIter<'a, T: 'a> {
    r: uint,
    c: uint,
    m: &'a Matrix<T>,
}
/// Mutable iterator returned by `Matrix::iter_mut`.
///
/// It walks the flat storage with a slice iterator (`m`) and tracks the
/// matching `(r, c)` coordinate itself. `columns` tells it when to wrap
/// to the next row.
struct MatrixMutIter<'a, T: 'a> {
    r: uint,
    c: uint,
    columns: uint,
    m: MutItems<'a, T>,
}
/// Constructors available when the element type has a default value.
impl<T: Default> Matrix<T> {
    /// Creates an `r` x `c` matrix with every cell set to `Default::default()`.
    pub fn new(r: uint, c: uint) -> Matrix<T> {
        Matrix::<T>::from_fn(r, c, |_row, _col| {
            let cell: T = Default::default();
            cell
        })
    }
}
impl<T> Matrix<T> {
    /// Builds a `rows` x `columns` matrix by calling `f(row, column)` once
    /// per coordinate produced by `coord_iter`.
    ///
    /// Results are stored in the order they are produced. Because `index`
    /// reads the storage as row-major, `coord_iter` must yield coordinates
    /// row by row.
    pub fn from_fn(rows: uint, columns: uint, f: |uint,uint| -> T) -> Matrix<T> {
        let values: Vec<T> = Matrix::<T>::coord_iter(rows, columns)
            .map(|(row, col)| f(row, col))
            .collect();
        Matrix { values: values, rows: rows, columns: columns }
    }

    /// Returns a fresh coordinate iterator for a `rows` x `columns` grid,
    /// positioned at `(0, 0)`.
    fn coord_iter(rows: uint, columns: uint) -> BaseMatrixIter {
        BaseMatrixIter { r: 0, c: 0, rows: rows, columns: columns }
    }
}
impl<T> Matrix<T> {
    /// Overwrites the element at row `r`, column `c` with `v`.
    ///
    /// Panics if `(r, c)` lies outside the matrix.
    pub fn set(&mut self, r: uint, c: uint, v: T) {
        let index = self.index(r, c);
        self.values[index] = v;
    }

    /// Borrows the element at row `r`, column `c`.
    ///
    /// Panics if `(r, c)` lies outside the matrix.
    pub fn get<'a>(&'a self, r: uint, c: uint) -> &'a T {
        let index = self.index(r, c);
        &self.values[index]
    }

    /// Mutably borrows the element at row `r`, column `c`.
    ///
    /// Panics if `(r, c)` lies outside the matrix.
    pub fn access<'a>(&'a mut self, r: uint, c: uint) -> &'a mut T {
        let index = self.index(r, c);
        &mut self.values[index]
    }

    /// Maps a `(row, column)` coordinate to its offset in the row-major
    /// `values` vector.
    ///
    /// Panics if `r >= rows` or `c >= columns`. Without the explicit column
    /// check, an out-of-range `c` would silently wrap into the next row and
    /// return the wrong element instead of failing.
    fn index(&self, r: uint, c: uint) -> uint {
        assert!(r < self.rows && c < self.columns,
                "matrix index ({}, {}) out of bounds for {}x{} matrix",
                r, c, self.rows, self.columns);
        self.columns * r + c
    }

    /// Combines two equally sized matrices element by element with `f`.
    ///
    /// Panics if the dimensions of `self` and `other` differ.
    fn binop<'a>(&'a self, other: &'a Matrix<T>, f: |&'a T,&'a T| -> T) -> Matrix<T> {
        if self.rows != other.rows || self.columns != other.columns {
            panic!("mismatched matrix sizes for binop ({}x{}) versus ({}x{})",
                   self.rows, self.columns, other.rows, other.columns
            );
        }
        Matrix::from_fn(self.rows, self.columns,
            |r, c| {
                f(self.get(r, c), other.get(r, c))
            }
        )
    }

    /// Returns an iterator over `(row, column, &element)` for every cell,
    /// starting at `(0, 0)`.
    fn iter<'a>(&'a self) -> MatrixIter<'a, T> {
        MatrixIter {
            m: self,
            r: 0,
            c: 0,
        }
    }

    /// Returns an iterator over `(row, column, &mut element)` for every cell,
    /// walking the row-major storage from the start.
    fn iter_mut<'a>(&'a mut self) -> MatrixMutIter<'a, T> {
        MatrixMutIter {
            columns: self.columns,
            m: self.values.iter_mut(),
            r: 0,
            c: 0,
        }
    }
}
/// Yields every `(row, column)` coordinate of the grid in row-major order:
/// `(0, 0), (0, 1), ..., (0, columns-1), (1, 0), ...`.
///
/// A grid with zero rows or zero columns yields nothing.
impl Iterator<(uint, uint)> for BaseMatrixIter {
    fn next(&mut self) -> Option<(uint, uint)> {
        // Stop when the rows run out. Also stop when there are no columns,
        // so an N x 0 grid does not produce bogus `(r, 0)` coordinates.
        if self.r >= self.rows || self.c >= self.columns {
            return None;
        }
        let coord = (self.r, self.c);
        self.c += 1;
        if self.c >= self.columns {
            // End of the current row: wrap to the first column of the next row.
            self.c = 0;
            self.r += 1;
        }
        Some(coord)
    }

    /// Reports the exact number of coordinates left, so `collect` can
    /// allocate the backing vector in one go.
    fn size_hint(&self) -> (uint, Option<uint>) {
        let remaining = if self.r < self.rows && self.c < self.columns {
            (self.rows - self.r) * self.columns - self.c
        } else {
            0
        };
        (remaining, Some(remaining))
    }
}
/// Yields `(row, column, &element)` for every cell in row-major order.
///
/// Only references are handed out, so any element type works; no `Copy`
/// bound is needed.
impl<'a, T: 'a> Iterator<(uint, uint, &'a T)> for MatrixIter<'a, T> {
    fn next(&mut self) -> Option<(uint, uint, &'a T)> {
        // Checking the column as well as the row makes a matrix with zero
        // columns yield nothing, instead of indexing its empty storage.
        if self.r >= self.m.rows || self.c >= self.m.columns {
            return None;
        }
        let item = (self.r, self.c, self.m.get(self.r, self.c));
        self.c += 1;
        if self.c >= self.m.columns {
            self.c = 0;
            self.r += 1;
        }
        Some(item)
    }
}
/// Yields `(row, column, &mut element)` for every cell, walking the flat
/// row-major storage and deriving each coordinate as it goes.
///
/// Only references are handed out, so any element type works; no `Copy`
/// bound is needed.
impl<'a, T: 'a> Iterator<(uint, uint, &'a mut T)> for MatrixMutIter<'a, T> {
    fn next(&mut self) -> Option<(uint, uint, &'a mut T)> {
        let value = match self.m.next() {
            Some(value) => value,
            None => return None,
        };
        let result = (self.r, self.c, value);
        self.c += 1;
        if self.c >= self.columns {
            // Wrap to the start of the next row.
            self.c = 0;
            self.r += 1;
        }
        Some(result)
    }
}
/// Element-wise addition of two matrices.
///
/// Both operands must have the same dimensions; otherwise `binop` panics.
impl<T: Add<T, T>> Add<Matrix<T>, Matrix<T>> for Matrix<T> {
    fn add(&self, other: &Matrix<T>) -> Matrix<T> {
        self.binop(other, |lhs, rhs| *lhs + *rhs)
    }
}
/// Element-wise subtraction (`self - other`) of two matrices.
///
/// Both operands must have the same dimensions; otherwise `binop` panics.
impl<T: Sub<T, T>> Sub<Matrix<T>, Matrix<T>> for Matrix<T> {
    fn sub(&self, other: &Matrix<T>) -> Matrix<T> {
        self.binop(other, |minuend, subtrahend| *minuend - *subtrahend)
    }
}
/// Matrix product.
///
/// An `m x n` matrix times an `n x p` matrix gives an `m x p` matrix. Its
/// `(r, c)` entry is the dot product of row `r` of `self` with column `c`
/// of `other`.
///
/// Panics if `self.columns != other.rows`.
impl<T> Mul<Matrix<T>, Matrix<T>> for Matrix<T> where T: Mul<T, T> + Zero {
    fn mul(&self, other: &Matrix<T>) -> Matrix<T> {
        if self.columns != other.rows {
            panic!("mismatched matrix sizes for multiply ({}x{}) versus ({}x{})",
                   self.rows, self.columns, other.rows, other.columns
            );
        }
        Matrix::from_fn(self.rows, other.columns,
            |r, c| {
                let mut sum: T = Zero::zero();
                // `k` walks along row `r` of `self` and down column `c` of
                // `other`. The previous code read both factors from `self`
                // with shifted indices and never looked at `other`.
                for k in range(0, self.columns) {
                    sum = sum + (*self.get(r, k) * *other.get(k, c));
                }
                sum
            }
        )
    }
}
/// Scalar multiplication.
///
/// Every element is multiplied by `other`, producing a new matrix with the
/// same dimensions.
impl<T: Mul<T, T>> Mul<T, Matrix<T>> for Matrix<T> {
    fn mul(&self, other: &T) -> Matrix<T> {
        let (rows, columns) = (self.rows, self.columns);
        Matrix::from_fn(rows, columns, |row, col| {
            let element = self.get(row, col);
            *element * *other
        })
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment