Skip to content

Instantly share code, notes, and snippets.

@kghose
Created January 30, 2020 03:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kghose/99f8ba6205c6942a64313a0890363176 to your computer and use it in GitHub Desktop.
Save kghose/99f8ba6205c6942a64313a0890363176 to your computer and use it in GitHub Desktop.
A certain awkwardness with C++ inheritance
/*
*/
#include <iostream>
#include <stdexcept>
#include <vector>
// Size constant. NOTE(review): no use of N is visible in this file — confirm it is still needed.
const size_t N = 1000;
/// Abstract dense matrix of doubles held in a single flat buffer.
///
/// The storage layout is decided by subclasses: each one implements
/// _index(), which maps a (row, column) pair to an offset into `data`.
class Matrix {
public:
    /// Creates an n-by-m matrix; every element starts at 0.0.
    Matrix(size_t n, size_t m)
        : n(n)
        , m(m)
        , data(n * m) // uses the constructor parameters, which shadow the members
    {
    }

    // The class has a virtual function, so it may be destroyed through a
    // base pointer; a virtual destructor makes that well-defined.
    virtual ~Matrix() = default;

    // Declaring the destructor would otherwise suppress the implicit move
    // operations, so all copy/move operations are defaulted explicitly.
    Matrix(const Matrix&) = default;
    Matrix(Matrix&&) = default;
    Matrix& operator=(const Matrix&) = default;
    Matrix& operator=(Matrix&&) = default;

    /// Number of rows (first constructor argument).
    size_t rows() const { return n; }
    /// Number of columns (second constructor argument).
    size_t cols() const { return m; }

    /// Mutable access to element (i, j). No bounds checking is performed.
    double& get(size_t i, size_t j) { return data[_index(i, j)]; }
    /// Read-only access to element (i, j), returned by value. No bounds checking is performed.
    double get(size_t i, size_t j) const { return data[_index(i, j)]; }

protected:
    /// Maps element (i, j) to its offset in `data`; defines the storage layout.
    virtual size_t _index(size_t i, size_t j) const = 0;

    size_t n, m;              // row count, column count
    std::vector<double> data; // n * m elements, order determined by _index()
};
class MatrixRowMajor : public Matrix {
public:
MatrixRowMajor(size_t n, size_t m)
: Matrix(n, m)
{
}
private:
size_t _index(size_t i, size_t j) const { return i * m + j; };
};
class MatrixColMajor : public Matrix {
public:
MatrixColMajor(size_t n, size_t m)
: Matrix(n, m)
{
}
private:
size_t _index(size_t i, size_t j) const { return i + j * n; };
};
void mul(const Matrix& lhs, const Matrix& rhs, Matrix& ans)
{
for (size_t i = 0; i < lhs.rows(); i++) {
for (size_t j = 0; j < rhs.cols(); j++) {
double v = 0;
for (size_t k = 0; k < lhs.cols(); k++) {
v += lhs.get(i, k) * rhs.get(k, j);
}
ans.get(i, j) = v;
}
}
}
/// Returns the product lhs * rhs as a new row-major matrix of shape
/// (lhs.rows() x rhs.cols()); the arithmetic is delegated to mul().
MatrixRowMajor operator*(const Matrix& lhs, const Matrix& rhs)
{
    MatrixRowMajor product{lhs.rows(), rhs.cols()};
    mul(lhs, rhs, product);
    return product;
}
/// Writes mat to out, one matrix row per text line, with a single space
/// after every element. Returns out so calls can be chained.
std::ostream& operator<<(std::ostream& out, const Matrix& mat)
{
    // The first argument of get() is the row index, so the outer loop must
    // run over rows() and the inner one over cols(); the reverse reads out
    // of range for any non-square matrix.
    for (size_t i = 0; i < mat.rows(); i++) {
        for (size_t j = 0; j < mat.cols(); j++) {
            out << mat.get(i, j) << " ";
        }
        out << '\n'; // newline without forcing a flush per row
    }
    return out;
}
// Matrix operator*(const Matrix& lhs, const Matrix& rhs)
// {
// Matrix ans(lhs.rows(), rhs.cols());
// for (size_t i = 0; i < lhs.rows(); i++) {
// for (size_t j = 0; j < rhs.cols(); j++) {
// double v = 0;
// for (size_t k = 0; k < lhs.cols(); k++) {
// v += lhs.get(i, k) * rhs.get(k, j);
// }
// ans.get(i, j) = v;
// }
// }
// return ans;
// }
int main(int argc, char* argv[])
{
MatrixRowMajor m1(2, 4);
m1.get(0, 1) = 2;
MatrixColMajor m2(4, 2);
m2.get(1, 0) = 2;
auto m3 = m1 * m2;
std::cout << m3;
}
from abc import ABC, abstractmethod
class Matrix(ABC):
    """Abstract dense matrix stored in a flat list.

    Subclasses choose the storage layout by implementing ``_index``, which
    maps a ``(row, column)`` key to a position in ``self.data``.
    Elements are read and written with ``matrix[i, j]``.
    """

    def __init__(self, n, m):
        """Create an ``n`` by ``m`` matrix with every element set to 0."""
        self.n, self.m = n, m
        self.data = [0] * (n * m)

    def __mul__(self, rhs):
        """Return the matrix product ``self * rhs``.

        The result has the same concrete type as ``self`` and shape
        ``(self.n, rhs.m)``.

        Raises:
            ValueError: if ``self.m`` differs from ``rhs.n``.
        """
        # Without this check a mismatch either raises an incidental
        # IndexError or silently reads the wrong elements of rhs.data.
        if self.m != rhs.n:
            raise ValueError(
                "cannot multiply a {}x{} matrix by a {}x{} matrix".format(
                    self.n, self.m, rhs.n, rhs.m
                )
            )
        ans = type(self)(self.n, rhs.m)
        for i in range(self.n):
            for j in range(rhs.m):
                x = 0
                for k in range(self.m):
                    x += self[i, k] * rhs[k, j]
                ans[i, j] = x
        return ans

    @abstractmethod
    def _index(self, key):
        """Map a ``(row, column)`` key to an offset into ``self.data``."""
        pass

    def __getitem__(self, key):
        """Return the element addressed by the ``(row, column)`` key."""
        return self.data[self._index(key)]

    def __setitem__(self, key, value):
        """Store ``value`` at the element addressed by the ``(row, column)`` key."""
        self.data[self._index(key)] = value

    def __str__(self):
        """Render one line per row, elements separated by ``", "``."""
        return "\n".join(
            ", ".join(str(self[i, j]) for j in range(self.m))
            for i in range(self.n)
        )
class MatrixRowMajor(Matrix):
    """Matrix laid out row by row: element (i, j) sits at offset i * m + j."""

    def __init__(self, n, m):
        super().__init__(n, m)

    def _index(self, key):
        row = key[0]
        col = key[1]
        return row * self.m + col
class MatrixColMajor(Matrix):
    """Matrix laid out column by column: element (i, j) sits at offset i + j * n."""

    def __init__(self, n, m):
        super().__init__(n, m)

    def _index(self, key):
        row = key[0]
        col = key[1]
        return row + col * self.n
def main():
    """Demo: print both layouts of a 2x4 matrix, then a mixed-layout product."""
    # This correctly raises an exception
    # m = Matrix(2, 4)

    # Same logical element set in each layout; the flat data shows where it lands.
    for layout in (MatrixRowMajor, MatrixColMajor):
        sample = layout(2, 4)
        sample[0, 2] = 2
        print(sample.data)
        print(sample)

    left = MatrixRowMajor(2, 4)
    left[0, 1] = 2
    print(left)

    right = MatrixColMajor(4, 2)
    right[1, 0] = 2
    print(right)

    product = left * right
    print(product)
# Runs the demo unconditionally: there is no `if __name__ == "__main__"` guard,
# so importing this file also executes main().
main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment