Created January 30, 2020 03:47
-
-
Save kghose/99f8ba6205c6942a64313a0890363176 to your computer and use it in GitHub Desktop.
A certain awkwardness with C++ inheritance
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
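/*
 * A certain awkwardness with C++ inheritance.
 *
 * Matrix is an abstract base class: its storage layout is chosen by the
 * pure-virtual _index, implemented by MatrixRowMajor and MatrixColMajor.
 * Because an abstract class cannot be instantiated or returned by value, a
 * free operator* cannot simply return a Matrix. Instead the product is always
 * returned as a concrete MatrixRowMajor, whatever the layouts of the operands.
 */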
#include <iostream>
#include <stdexcept>
#include <vector>
// Size constant; not referenced by any code visible in this file
// (presumably left over from a larger experiment — TODO confirm or remove).
constexpr size_t N = 1000;
/// Abstract dense matrix of doubles stored in a flat std::vector.
///
/// The mapping from (row, column) to a position in `data` is delegated to
/// subclasses through the pure-virtual `_index`, so the same interface can be
/// backed by different storage layouts. Elements start out as 0.0.
/// No bounds checking is performed by get().
class Matrix {
public:
    /// Creates an n-row by m-column matrix with every element set to 0.0.
    Matrix(size_t n, size_t m)
        : n(n)
        , m(m)
        , data(n * m) // value-initialized: all 0.0
    {
    }

    // Polymorphic base: the virtual destructor makes destroying a derived
    // matrix through a Matrix* (or std::unique_ptr<Matrix>) well-defined.
    virtual ~Matrix() = default;

    // Declaring a destructor suppresses the implicit move operations, so
    // default all of them explicitly to keep copies and moves of derived
    // matrices exactly as cheap as before (Rule of Five).
    Matrix(const Matrix&) = default;
    Matrix(Matrix&&) = default;
    Matrix& operator=(const Matrix&) = default;
    Matrix& operator=(Matrix&&) = default;

    /// Number of rows.
    size_t rows() const { return n; }
    /// Number of columns.
    size_t cols() const { return m; }

    /// Mutable access to element (i, j); requires i < rows() and j < cols().
    double& get(size_t i, size_t j) { return data[_index(i, j)]; }
    /// Read-only access to element (i, j), returned by value.
    double get(size_t i, size_t j) const { return data[_index(i, j)]; }

protected:
    /// Maps (row i, column j) to an offset into `data`; defines the layout.
    virtual size_t _index(size_t i, size_t j) const = 0;

    size_t n, m;              // number of rows, number of columns
    std::vector<double> data; // n * m elements in subclass-defined order
};
class MatrixRowMajor : public Matrix { | |
public: | |
MatrixRowMajor(size_t n, size_t m) | |
: Matrix(n, m) | |
{ | |
} | |
private: | |
size_t _index(size_t i, size_t j) const { return i * m + j; }; | |
}; | |
class MatrixColMajor : public Matrix { | |
public: | |
MatrixColMajor(size_t n, size_t m) | |
: Matrix(n, m) | |
{ | |
} | |
private: | |
size_t _index(size_t i, size_t j) const { return i + j * n; }; | |
}; | |
void mul(const Matrix& lhs, const Matrix& rhs, Matrix& ans) | |
{ | |
for (size_t i = 0; i < lhs.rows(); i++) { | |
for (size_t j = 0; j < rhs.cols(); j++) { | |
double v = 0; | |
for (size_t k = 0; k < lhs.cols(); k++) { | |
v += lhs.get(i, k) * rhs.get(k, j); | |
} | |
ans.get(i, j) = v; | |
} | |
} | |
} | |
// Multiplies two matrices of any layout. Matrix is abstract (pure-virtual
// _index) and so cannot be returned by value; the product therefore always
// comes back as a row-major matrix, whatever the layouts of lhs and rhs.
MatrixRowMajor operator*(const Matrix& lhs, const Matrix& rhs)
{
    MatrixRowMajor product{ lhs.rows(), rhs.cols() };
    mul(lhs, rhs, product);
    return product;
}
std::ostream& operator<<(std::ostream& out, const Matrix& mat) | |
{ | |
for (size_t i = 0; i < mat.cols(); i++) { | |
for (size_t j = 0; j < mat.rows(); j++) { | |
out << mat.get(i, j) << " "; | |
} | |
out << std::endl; | |
} | |
return out; | |
} | |
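// The natural version below does not compile: Matrix is abstract (it has the
// pure-virtual _index), so `Matrix ans(...)` cannot be instantiated and a
// Matrix cannot be returned by value. operator* above works around this by
// always returning a MatrixRowMajor.
// Matrix operator*(const Matrix& lhs, const Matrix& rhs)
// {
// Matrix ans(lhs.rows(), rhs.cols());
// for (size_t i = 0; i < lhs.rows(); i++) {
// for (size_t j = 0; j < rhs.cols(); j++) {
// double v = 0;
// for (size_t k = 0; k < lhs.cols(); k++) {
// v += lhs.get(i, k) * rhs.get(k, j);
// }
// ans.get(i, j) = v;
// }
// }
// return ans;
// }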
int main(int argc, char* argv[]) | |
{ | |
MatrixRowMajor m1(2, 4); | |
m1.get(0, 1) = 2; | |
MatrixColMajor m2(4, 2); | |
m2.get(1, 0) = 2; | |
auto m3 = m1 * m2; | |
std::cout << m3; | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from abc import ABC, abstractmethod | |
class Matrix(ABC):
    """Abstract dense n x m matrix stored in a flat Python list.

    Subclasses choose the storage layout by implementing ``_index``, which
    maps an ``(i, j)`` key to an offset into ``self.data``. Elements are read
    and written with ``matrix[i, j]`` and start out as ``0``.
    """

    def __init__(self, n, m):
        self.n, self.m = n, m  # number of rows, number of columns
        self.data = [0] * (n * m)

    def __mul__(self, rhs):
        """Return the matrix product ``self * rhs`` as a new matrix.

        The result is built with ``type(self)``, so it has the same concrete
        class (and storage layout) as the left operand.

        Returns ``NotImplemented`` for non-``Matrix`` operands so Python can
        raise the usual ``TypeError``.

        Raises:
            ValueError: if the inner dimensions differ (``self.m != rhs.n``).
        """
        if not isinstance(rhs, Matrix):
            return NotImplemented
        if self.m != rhs.n:
            # Without this check a mismatch silently yields wrong values or
            # an IndexError deep inside the loop.
            raise ValueError(
                "cannot multiply {}x{} matrix by {}x{} matrix".format(
                    self.n, self.m, rhs.n, rhs.m
                )
            )
        ans = type(self)(self.n, rhs.m)
        for i in range(self.n):
            for j in range(rhs.m):
                ans[i, j] = sum(self[i, k] * rhs[k, j] for k in range(self.m))
        return ans

    @abstractmethod
    def _index(self, key):
        """Map an ``(i, j)`` key to an offset into ``self.data``."""

    def __getitem__(self, key):
        """Return the element at ``key == (i, j)``."""
        return self.data[self._index(key)]

    def __setitem__(self, key, value):
        """Store ``value`` at ``key == (i, j)``."""
        self.data[self._index(key)] = value

    def __str__(self):
        """One row per line, elements separated by ", "."""
        return "\n".join(
            ", ".join(str(self[i, j]) for j in range(self.m))
            for i in range(self.n)
        )
class MatrixRowMajor(Matrix):
    """Matrix stored row by row: element (i, j) lives at ``i * m + j``."""

    def __init__(self, n, m):
        # No layout-specific state; the base class does all the set-up.
        super().__init__(n, m)

    def _index(self, key):
        row, col = key[0], key[1]
        return col + row * self.m
class MatrixColMajor(Matrix):
    """Matrix stored column by column: element (i, j) lives at ``i + j * n``."""

    def __init__(self, n, m):
        # No layout-specific state; the base class does all the set-up.
        super().__init__(n, m)

    def _index(self, key):
        row, col = key[0], key[1]
        return col * self.n + row
def main():
    """Demo of both storage layouts and of mixed-layout multiplication."""
    # Instantiating the abstract base correctly raises an exception:
    # m = Matrix(2, 4)

    # The same logical element (0, 2) lands at a different offset in
    # ``data`` for each layout.
    for layout in (MatrixRowMajor, MatrixColMajor):
        m = layout(2, 4)
        m[0, 2] = 2
        print(m.data)
        print(m)

    m1 = MatrixRowMajor(2, 4)
    m1[0, 1] = 2
    print(m1)

    m2 = MatrixColMajor(4, 2)
    m2[1, 0] = 2
    print(m2)

    print(m1 * m2)
# Run the demo only when executed as a script, not as a side effect of import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.