Skip to content

Instantly share code, notes, and snippets.

@douglasstarnes
Created May 30, 2018 20:32
Show Gist options
  • Save douglasstarnes/0d99bdb0c354abf55067ed6efd52418e to your computer and use it in GitHub Desktop.
Save douglasstarnes/0d99bdb0c354abf55067ed6efd52418e to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 192,
"metadata": {},
"outputs": [],
"source": [
"class Matrix(object):\n",
" def __init__(self, size, data=None):\n",
" self.shape = size\n",
" self._rows, self._cols = self.shape\n",
" self._data = [] if data is None else data\n",
" \n",
" \n",
" def identity(self):\n",
" if not self._rows == self._cols:\n",
" raise ValueError('Identity matrix must be square')\n",
" self._data = []\n",
" for idx in range(self._rows):\n",
" row = [0] * self._cols\n",
" row[idx] = 1\n",
" self._data.extend(row)\n",
" \n",
" def create_identity(self):\n",
" i = Matrix((self._cols, self._cols))\n",
" i.identity()\n",
" return i\n",
" \n",
" def raw(self):\n",
" return [self._data[x:x+self._cols] for x in range(0, len(self._data), self._cols)]\n",
" \n",
" def column(self, idx):\n",
" return [row[idx] for row in self.raw()]\n",
" \n",
" def zero(self):\n",
" row, col = self.shape\n",
" self._data = [0] * (row * col)\n",
" \n",
" def one(self):\n",
" row, col = self.shape\n",
" self._data = [1] * (row * col)\n",
" \n",
" def dot(self, rhs):\n",
" rows_rhs, cols_rhs = rhs.shape\n",
"\n",
" if not self._cols == rows_rhs:\n",
" message = 'got ({0}, {1})'.format(self._cols, rows_rhs)\n",
" raise ValueError('columns on left must equal rows on right, ' + message)\n",
"\n",
" p = Matrix((self._rows, cols_rhs))\n",
" p.zero()\n",
"\n",
" for row in range(self._rows):\n",
" for col in range(cols_rhs):\n",
" p[row,col] = self._dot_vector(self[row], rhs.column(col))\n",
"\n",
" return p\n",
" \n",
" def transpose(self):\n",
" self._data = [value for column in zip(*self.raw()) for value in column]\n",
" row, col = self.shape\n",
" self.shape = (col, row)\n",
" self._rows, self._cols = self.shape\n",
" \n",
" def _longest_str(self, str_seq):\n",
" longest = 0\n",
" for s in str_seq:\n",
" longest = max(len(s), longest)\n",
" return longest \n",
" \n",
" def _pad(self, s, total):\n",
" return s + ' ' * abs(len(s) - total)\n",
" \n",
" def _pad_row(self, row, total):\n",
" return ' '.join([self._pad(value, total) for value in row])\n",
" \n",
" def _header(self, w, total):\n",
" l = (total * w) + (w - 1)\n",
" h = ' /' + ' ' * (l - 2) + '\\\\' + '\\n'\n",
" h += '/' + ' ' * l + '\\\\' + '\\n'\n",
" return h\n",
"\n",
" def _footer(self, w, total):\n",
" l = (total * w) + (w - 1)\n",
" f = '\\n' + '\\\\' + ' ' * l + '/'\n",
" f += '\\n' + ' \\\\' + ' ' * (l - 2) + '/'\n",
" return f\n",
" \n",
" def _dot_vector(self, l, r):\n",
" return sum([t[0] * t[1] for t in list(zip(l, r))])\n",
" \n",
" def __str__(self):\n",
" data = [str(item) for item in self._data]\n",
"\n",
" longest = self._longest_str(data)\n",
"\n",
" rows = [data[x:x+self._cols] for x in range(0, len(data), self._cols)]\n",
"\n",
" repr = self._header(self._cols, longest)\n",
" repr += '\\n'.join(['|' + self._pad_row(row, longest) + '|' for row in rows])\n",
" repr += self._footer(self._cols, longest)\n",
"\n",
" return repr\n",
" \n",
" def __getitem__(self, key):\n",
" if isinstance(key, int):\n",
" lower = key * self._cols\n",
" return self._data[lower:lower + self._cols]\n",
" else:\n",
" row, col = key\n",
" return self._data[row * self._cols + col]\n",
" \n",
" def __setitem__(self, key, value):\n",
" if isinstance(key, int):\n",
" raise NotImplementedError()\n",
" row, col = key\n",
" self._data[row * self._cols + col] = value\n",
" \n",
" "
]
},
{
"cell_type": "code",
"execution_count": 201,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" / \\\n",
"/ \\\n",
"|0 1 2|\n",
"|3 4 5|\n",
"\\ /\n",
" \\ /\n"
]
}
],
"source": [
"M = Matrix((2,3), data=list(range(6)))\n",
"print(M)"
]
},
{
"cell_type": "code",
"execution_count": 203,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" / \\\n",
"/ \\\n",
"|0 1 2 3 |\n",
"|4 5 6 7 |\n",
"|8 9 10 11|\n",
"\\ /\n",
" \\ /\n"
]
}
],
"source": [
"N = Matrix((3, 4), data=list(range(12)))\n",
"print(N)"
]
},
{
"cell_type": "code",
"execution_count": 204,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" / \\\n",
"/ \\\n",
"|20 23 26 29|\n",
"|56 68 80 92|\n",
"\\ /\n",
" \\ /\n"
]
}
],
"source": [
"MN = M.dot(N)\n",
"print(MN)"
]
},
{
"cell_type": "code",
"execution_count": 205,
"metadata": {},
"outputs": [
{
"ename": "ValueError",
"evalue": "columns on left must equal rows on right, got (3, 2)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-205-2e0147828f4b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mMM\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mM\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mM\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mMM\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-192-28be6e69d3dc>\u001b[0m in \u001b[0;36mdot\u001b[0;34m(self, rhs)\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cols\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mrows_rhs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[0mmessage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'got ({0}, {1})'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cols\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrows_rhs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 37\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'columns on left must equal rows on right, '\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mmessage\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 38\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0mp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mMatrix\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_rows\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcols_rhs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mValueError\u001b[0m: columns on left must equal rows on right, got (3, 2)"
]
}
],
"source": [
"MM = M.dot(M)\n",
"print(MM)"
]
},
{
"cell_type": "code",
"execution_count": 206,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" / \\\n",
"/ \\\n",
"|0 1 2|\n",
"|3 4 5|\n",
"\\ /\n",
" \\ /\n"
]
}
],
"source": [
"MI = Matrix((M.shape[1], M.shape[1]))\n",
"MI.identity()\n",
"MMI = M.dot(MI)\n",
"print(MMI)"
]
},
{
"cell_type": "code",
"execution_count": 208,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" / \\\n",
"/ \\\n",
"|0 1 2 3 |\n",
"|4 5 6 7 |\n",
"|8 9 10 11|\n",
"\\ /\n",
" \\ /\n"
]
}
],
"source": [
"NI = N.create_identity()\n",
"NNI = N.dot(NI)\n",
"print(NNI)"
]
},
{
"cell_type": "code",
"execution_count": 209,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" / \\\n",
"/ \\\n",
"|0 0|\n",
"|0 0|\n",
"|0 0|\n",
"\\ /\n",
" \\ /\n"
]
}
],
"source": [
"P = Matrix((3, 2))\n",
"P.zero()\n",
"print(P)"
]
},
{
"cell_type": "code",
"execution_count": 210,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" / \\\n",
"/ \\\n",
"|0 0 0|\n",
"|0 0 0|\n",
"\\ /\n",
" \\ /\n"
]
}
],
"source": [
"P.transpose()\n",
"print(P)"
]
},
{
"cell_type": "code",
"execution_count": 211,
"metadata": {},
"outputs": [
{
"ename": "ValueError",
"evalue": "Identity matrix must be square",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-211-e773846bb0f6>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mP\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0midentity\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-192-28be6e69d3dc>\u001b[0m in \u001b[0;36midentity\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0midentity\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_rows\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cols\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Identity matrix must be square'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 11\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_rows\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mValueError\u001b[0m: Identity matrix must be square"
]
}
],
"source": [
"P.identity()"
]
},
{
"cell_type": "code",
"execution_count": 212,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]"
]
},
"execution_count": 212,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"N.raw()"
]
},
{
"cell_type": "code",
"execution_count": 213,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[4, 5, 6, 7]"
]
},
"execution_count": 213,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"N[1]"
]
},
{
"cell_type": "code",
"execution_count": 214,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"6"
]
},
"execution_count": 214,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"N[1, 2]"
]
},
{
"cell_type": "code",
"execution_count": 215,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" / \\\n",
"/ \\\n",
"|0 1 2 3 |\n",
"|4 5 10 7 |\n",
"|8 9 10 11|\n",
"\\ /\n",
" \\ /\n"
]
}
],
"source": [
"N[1, 2] = 10\n",
"print(N)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment