Skip to content

Instantly share code, notes, and snippets.

@wayslog
Created November 22, 2016 07:47
Show Gist options
  • Save wayslog/a9a869baf04c206104dd81e5f5b93e6b to your computer and use it in GitHub Desktop.
Save wayslog/a9a869baf04c206104dd81e5f5b93e6b to your computer and use it in GitHub Desktop.
use std::{ops, fmt};
/// A dense matrix of `row` rows and `col` columns, stored as a flat vector in
/// row-major order.
///
/// `Clone` is derived because the arithmetic operators (`+`, `-`, `*`) take
/// their operands by value; cloning lets callers keep a matrix after using it
/// in an expression.
#[derive(PartialEq, Debug, Clone)]
pub struct Matrix<T> {
    /// Elements in row-major order. Normally there are `row * col` of them,
    /// but `new_empty` creates a matrix that stores no elements at all.
    data: Vec<T>,
    /// Number of rows.
    row: usize,
    /// Number of columns.
    col: usize,
}
impl<T: Copy> Matrix<T> {
    /// Builds a matrix reporting `row` rows and `col` columns whose storage is
    /// a copy of `values`, read in row-major order.
    ///
    /// `values.len()` is not checked against `row * col`.
    pub fn new(row: usize, col: usize, values: &[T]) -> Matrix<T> {
        let data: Vec<T> = values.iter().copied().collect();
        Self { data, row, col }
    }

    /// Builds a matrix reporting `row` rows and `col` columns that stores no
    /// elements at all.
    pub fn new_empty(row: usize, col: usize) -> Matrix<T> {
        Self {
            row,
            col,
            data: Vec::new(),
        }
    }

    /// Borrows the underlying element storage.
    pub fn data(&self) -> &Vec<T> {
        let Matrix { data, .. } = self;
        data
    }

    /// Mutably borrows the underlying element storage.
    ///
    /// `row` and `col` are not updated if the caller changes its length.
    pub fn mut_data(&mut self) -> &mut Vec<T> {
        let Matrix { data, .. } = self;
        data
    }

    /// Returns the dimensions as `(rows, columns)`.
    pub fn size(&self) -> (usize, usize) {
        let Matrix { row, col, .. } = *self;
        (row, col)
    }
}
impl<T: ops::Add<Output = T> + Copy> ops::Add for Matrix<T> {
    type Output = Self;

    /// Returns the element-wise sum of `self` and `rhs`.
    ///
    /// Stored elements are combined pairwise in storage order. Two matrices
    /// with no stored elements (e.g. `0 x 0`) sum to a matrix with no stored
    /// elements. If the storages differ in length, the result stops at the
    /// shorter one.
    ///
    /// # Panics
    ///
    /// Panics if `self.row != rhs.row || self.col != rhs.col`.
    fn add(self, rhs: Self) -> Self::Output {
        assert!(
            self.row == rhs.row && self.col == rhs.col,
            "matrix dimensions must match for addition: {}x{} vs {}x{}",
            self.row,
            self.col,
            rhs.row,
            rhs.col
        );
        // Zip the full storage of both operands. An index range of
        // `0..row * col - 1` would skip the final element and underflow when
        // `row * col == 0`.
        let data = self
            .data
            .iter()
            .zip(&rhs.data)
            .map(|(&a, &b)| a + b)
            .collect();
        Matrix {
            data,
            row: self.row,
            col: self.col,
        }
    }
}
impl<T: ops::Sub<Output = T> + Copy> ops::Sub for Matrix<T> {
type Output = Self;
/// Returns the subtraction of `rhs` from `self`.
/// If `self.row != rhs.row || self.col != rhs.col`, panic.
fn sub(self, rhs: Self) -> Self::Output {
if self.row != rhs.row || self.col != rhs.col {
panic!("")
}
let a: usize = self.row * self.col - 1;
let mut vec = Vec::new();
for b in 0..a {
vec.push(self.data[b] - rhs.data[b]);
}
Matrix {
data: vec,
row: self.row,
col: self.col,
}
}
}
impl<T: fmt::Display> fmt::Display for Matrix<T> {
    /// Formats the matrix as follows:
    /// * Writes each row on a separate line; every row, including the last,
    ///   is terminated by `\n`, and no other empty lines are written.
    /// * On each row, writes each element followed by a single space,
    ///   except no space following the last element of the row.
    ///
    /// Only the first `row * col` stored elements are printed. A matrix that
    /// stores fewer (e.g. one built with `new_empty`) prints only what it holds
    /// instead of panicking. Errors from the formatter are propagated with `?`
    /// rather than unwrapped.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let count = self.row * self.col;
        for (idx, value) in self.data.iter().take(count).enumerate() {
            // The loop body only runs when `count > 0`, which implies
            // `self.col > 0`, so the modulus never divides by zero.
            let separator = if (idx + 1) % self.col == 0 { "\n" } else { " " };
            write!(f, "{}{}", value, separator)?;
        }
        Ok(())
    }
}
impl<T: ops::Add<Output = T> + ops::Mul<Output = T> + Copy> ops::Mul for Matrix<T> {
    type Output = Self;

    /// Returns the matrix product `self * rhs` (a `self.row` x `rhs.col`
    /// matrix).
    ///
    /// No zero value of `T` is needed: each output entry starts as its first
    /// partial product and later partial products are added onto it.
    ///
    /// # Panics
    ///
    /// Panics if `self.col != rhs.row`.
    fn mul(self, rhs: Self) -> Self::Output {
        if self.col != rhs.row {
            panic!("self.col != rhs.row");
        }
        let (rows, inner, cols) = (self.row, self.col, rhs.col);
        let mut product = Vec::<T>::with_capacity(rows * cols);
        for i in 0..rows {
            // Position in `product` where output row `i` begins.
            let base = product.len();
            for j in 0..inner {
                for k in 0..cols {
                    let term = self.data[i * inner + j] * rhs.data[j * cols + k];
                    // On the first pass over `j` the slot does not exist yet and
                    // is appended; later passes accumulate into it.
                    match product.get_mut(base + k) {
                        Some(acc) => *acc = *acc + term,
                        None => product.push(term),
                    }
                }
            }
        }
        Matrix {
            data: product,
            col: cols,
            row: rows,
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment