Skip to content

Instantly share code, notes, and snippets.

@SebastiaanYN
Created September 5, 2020 15:28
Show Gist options
  • Save SebastiaanYN/3b4e6979d4b1256c2d86facd8465fe1e to your computer and use it in GitHub Desktop.
Save SebastiaanYN/3b4e6979d4b1256c2d86facd8465fe1e to your computer and use it in GitHub Desktop.
#![feature(min_const_generics)]
use std::ops;
/// A fixed-size matrix whose dimensions are part of its type.
///
/// Storage is `H` inner arrays of `W` elements each, accessed as
/// `self.0[y][x]` with `y < H` and `x < W`. Because `W` and `H` are const
/// generics, combining matrices of mismatched sizes is a compile-time error.
///
/// The derived traits only apply when `T` implements them (e.g. a
/// `Matrix<i32, _, _>` is `Copy + Eq + Hash`, a `Matrix<f64, _, _>` is
/// `Copy + PartialEq`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct Matrix<T, const W: usize, const H: usize>([[T; W]; H]);

impl<T, const W: usize, const H: usize> Default for Matrix<T, W, H>
where
    T: Copy + Default,
{
    /// Returns a matrix with every entry set to `T::default()`.
    fn default() -> Self {
        // `T::default()` is evaluated once and copied into every cell;
        // the repeat expression needs `T: Copy`.
        Matrix([[T::default(); W]; H])
    }
}

impl<T, const W: usize, const H: usize> Matrix<T, W, H>
where
    T: Copy + Default,
{
    /// Creates a matrix with every entry set to `T::default()`.
    ///
    /// Equivalent to [`Default::default`]; kept as the conventional
    /// constructor name used by the operator impls.
    fn new() -> Self {
        Self::default()
    }
}
/// Element-wise sum of two borrowed matrices of identical dimensions.
///
/// Neither operand is consumed. Entry `[y][x]` of the result is
/// `self.0[y][x] + other.0[y][x]`.
impl<'a, 'b, T, const W: usize, const H: usize> ops::Add<&'b Matrix<T, W, H>>
    for &'a Matrix<T, W, H>
where
    T: ops::Add<Output = T> + Copy + Default,
{
    type Output = Matrix<T, W, H>;

    fn add(self, other: &'b Matrix<T, W, H>) -> Self::Output {
        // Start from an all-`T::default()` matrix; every cell is overwritten below.
        let mut sum: Matrix<T, W, H> = Matrix::new();

        // Every (col, row) pair, visited column by column: col outer, row inner.
        let cells = (0..W).flat_map(|col| (0..H).map(move |row| (col, row)));
        for (col, row) in cells {
            sum.0[row][col] = self.0[row][col] + other.0[row][col];
        }

        sum
    }
}
/// Element-wise sum of two owned matrices.
///
/// Forwards to the `&Matrix + &Matrix` impl, so the arithmetic lives in one
/// place; both operands are dropped after the call.
impl<T, const W: usize, const H: usize> ops::Add for Matrix<T, W, H>
where
    T: ops::Add<Output = T> + Copy + Default,
{
    type Output = Self;

    fn add(self, other: Self) -> Self {
        ops::Add::add(&self, &other)
    }
}
/// Matrix product of two borrowed matrices that share the dimension `N`.
///
/// For every `i < M` and `j < P` the result holds
/// `out.0[j][i] = T::default() + Σ_{k < N} self.0[k][i] * other.0[j][k]`,
/// with the terms added in increasing `k`.
///
/// Indexing convention: if the *outer* index of `.0` is read as the column
/// and the *inner* index as the row (column-major), this is the ordinary
/// product `self · other` of an `M × N` matrix by an `N × P` matrix. If it
/// is read row-major instead (outer index = row, i.e. `W` = width and
/// `H` = height), the same arithmetic computes `other · self`.
impl<'a, 'b, T, const M: usize, const N: usize, const P: usize> ops::Mul<&'b Matrix<T, N, P>>
    for &'a Matrix<T, M, N>
where
    T: ops::Mul<Output = T> + ops::AddAssign + Copy + Default,
{
    type Output = Matrix<T, M, P>;

    fn mul(self, other: &'b Matrix<T, N, P>) -> Self::Output {
        // Every cell starts as `T::default()` and acts as the accumulator seed.
        let mut product: Matrix<T, M, P> = Matrix::new();

        for i in 0..M {
            for j in 0..P {
                // Sum the dot product in a local copy, then store it once.
                let mut acc = product.0[j][i];
                for k in 0..N {
                    acc += self.0[k][i] * other.0[j][k];
                }
                product.0[j][i] = acc;
            }
        }

        product
    }
}
/// Matrix product of two owned matrices.
///
/// Forwards to the `&Matrix * &Matrix` impl, so it uses exactly the same
/// indexing convention and summation order; both operands are dropped after
/// the call.
impl<T, const M: usize, const N: usize, const P: usize> ops::Mul<Matrix<T, N, P>>
    for Matrix<T, M, N>
where
    T: ops::Mul<Output = T> + ops::AddAssign + Copy + Default,
{
    type Output = Matrix<T, M, P>;

    fn mul(self, other: Matrix<T, N, P>) -> Self::Output {
        ops::Mul::mul(&self, &other)
    }
}
/// Demo: prints an element-wise sum and a matrix product of small integer
/// matrices, once through the by-reference operators and once through the
/// owning ones.
fn main() {
    // Two inner arrays of three elements each: Matrix<i32, 3, 2>.
    let x = Matrix([[1, 2, 3], [4, 5, 6]]);
    let y = Matrix([[6, 5, 4], [3, 2, 1]]);

    let borrowed_sum = &x + &y;
    println!("{:?}", borrowed_sum);
    let owned_sum = x + y;
    println!("{:?}", owned_sum);

    // x: Matrix<i32, 3, 4> and y: Matrix<i32, 4, 3>, so the shared N is 4
    // and the product is a Matrix<i32, 3, 3>.
    let x = Matrix([[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]]);
    let y = Matrix([[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]);

    let borrowed_product = &x * &y;
    println!("{:?}", borrowed_product);
    let owned_product = x * y;
    println!("{:?}", owned_product);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment