Created
January 26, 2022 00:11
-
-
Save jstac/94ed45b492c6d767ed844a8b4a9d3f18 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "d3dcad31", | |
"metadata": {}, | |
"source": [ | |
"# Verification of Python OT Lecture" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 38, | |
"id": "3870b03a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# !pip install --upgrade POT # Install Python Optimal Transport if necessary" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 39, | |
"id": "7bbbf84b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import ot\n", | |
"from scipy.optimize import linprog" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "407e1d06", | |
"metadata": {}, | |
"source": [ | |
"Set up the primitives as in the OT lecture." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"id": "08777710", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"m = 3\n", | |
"n = 5\n", | |
"\n", | |
"p = np.array([50, 100, 150])\n", | |
"q = np.array([25, 115, 60, 30, 70])\n", | |
"\n", | |
"C = np.array([[10, 15, 20, 20, 40],\n", | |
" [20, 40, 15, 30, 30],\n", | |
" [30, 35, 40, 55, 25]])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "44b0654f", | |
"metadata": {}, | |
"source": [ | |
"Vectorize $C$, being sure to use column-major order (`order='F'` in NumPy)." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 40, | |
"id": "9fe1216d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"C_vec = C.reshape((m*n, 1), order='F')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "9cf0faad", | |
"metadata": {}, | |
"source": [ | |
"Let's check it." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"id": "f6825b18", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[10],\n", | |
" [20],\n", | |
" [30],\n", | |
" [15],\n", | |
" [40],\n", | |
" [35],\n", | |
" [20],\n", | |
" [15],\n", | |
" [40],\n", | |
" [20],\n", | |
" [30],\n", | |
" [55],\n", | |
" [40],\n", | |
" [30],\n", | |
" [25]])" | |
] | |
}, | |
"execution_count": 28, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"C_vec" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "346d8448", | |
"metadata": {}, | |
"source": [ | |
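"Now set up the linear program. With $X$ vectorized in column-major order, the equality constraints $A x = b$ stack the row-sum conditions (row sums of $X$ equal $p$) on top of the column-sum conditions (column sums of $X$ equal $q$)." | |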
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "dbf2a2f0", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Construct matrix A by Kronecker product\n", | |
"A1 = np.kron(np.ones((1, n)), np.identity(m))\n", | |
"A2 = np.kron(np.identity(n), np.ones((1, m)))\n", | |
"A = np.vstack([A1, A2])\n", | |
"\n", | |
"# Construct vector b\n", | |
"b = np.hstack([p, q])\n", | |
"\n", | |
"# Solve the primal problem\n", | |
"res = linprog(C_vec, A_eq=A, b_eq=b, method='Revised simplex')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "4953245c", | |
"metadata": {}, | |
"source": [ | |
"We get 7225 for the minimized function value, as in the OT lecture:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 41, | |
"id": "b303f136", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"7225.0" | |
] | |
}, | |
"execution_count": 41, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"res.fun" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "ef970aed", | |
"metadata": {}, | |
"source": [ | |
"The solution comes back to us in vectorized form." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 42, | |
"id": "8ac9e7fd", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([15., 10., 0., 35., 0., 80., 0., 60., 0., 0., 30., 0., 0.,\n", | |
" 0., 70.])" | |
] | |
}, | |
"execution_count": 42, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"res.x" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "143d0860", | |
"metadata": {}, | |
"source": [ | |
"We return it to matrix form, being careful to use column-major order again." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 43, | |
"id": "bd9548c2", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"X = res.x.reshape((m,n), order='F')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 37, | |
"id": "a5e28496", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[15., 35., 0., 0., 0.],\n", | |
" [10., 0., 60., 30., 0.],\n", | |
" [ 0., 80., 0., 0., 70.]])" | |
] | |
}, | |
"execution_count": 37, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"X" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "c5a3ecf8", | |
"metadata": {}, | |
"source": [ | |
"This is the correct solution. We can verify this using the Python OT package:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 44, | |
"id": "603fbe01", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[15, 35, 0, 0, 0],\n", | |
" [10, 0, 60, 30, 0],\n", | |
" [ 0, 80, 0, 0, 70]])" | |
] | |
}, | |
"execution_count": 44, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"X = ot.emd(p, q, C)\n", | |
"X" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 45, | |
"id": "600a2804", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"7225" | |
] | |
}, | |
"execution_count": 45, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"total_cost = np.sum(X * C)\n", | |
"total_cost" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "82c973a8", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment