@jstac
Created January 26, 2022 00:11
{
"cells": [
{
"cell_type": "markdown",
"id": "d3dcad31",
"metadata": {},
"source": [
"# Verification of Python OT Lecture"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "3870b03a",
"metadata": {},
"outputs": [],
"source": [
"# !pip install --upgrade POT # Install Python Optimal Transport if necessary"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "7bbbf84b",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import ot\n",
"from scipy.optimize import linprog"
]
},
{
"cell_type": "markdown",
"id": "407e1d06",
"metadata": {},
"source": [
"Set up the primitives as in the OT lecture."
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "08777710",
"metadata": {},
"outputs": [],
"source": [
"m = 3\n",
"n = 5\n",
"\n",
"p = np.array([50, 100, 150])\n",
"q = np.array([25, 115, 60, 30, 70])\n",
"\n",
"C = np.array([[10, 15, 20, 20, 40],\n",
"              [20, 40, 15, 30, 30],\n",
"              [30, 35, 40, 55, 25]])"
]
},
{
"cell_type": "markdown",
"id": "44b0654f",
"metadata": {},
"source": [
"Vectorize $C$ in column-major order (`order='F'` in NumPy), so that its columns are stacked on top of each other."
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "9fe1216d",
"metadata": {},
"outputs": [],
"source": [
"C_vec = C.reshape((m*n, 1), order='F')"
]
},
{
"cell_type": "markdown",
"id": "9cf0faad",
"metadata": {},
"source": [
"Let's check it."
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "f6825b18",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[10],\n",
"       [20],\n",
"       [30],\n",
"       [15],\n",
"       [40],\n",
"       [35],\n",
"       [20],\n",
"       [15],\n",
"       [40],\n",
"       [20],\n",
"       [30],\n",
"       [55],\n",
"       [40],\n",
"       [30],\n",
"       [25]])"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"C_vec"
]
},
{
"cell_type": "markdown",
"id": "346d8448",
"metadata": {},
"source": [
"Now set up the linear program, building the constraint matrix to match the column-major vectorization."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dbf2a2f0",
"metadata": {},
"outputs": [],
"source": [
"# Construct matrix A by Kronecker product\n",
"A1 = np.kron(np.ones((1, n)), np.identity(m))    # rows of X sum to p\n",
"A2 = np.kron(np.identity(n), np.ones((1, m)))    # columns of X sum to q\n",
"A = np.vstack([A1, A2])\n",
"\n",
"# Construct vector b\n",
"b = np.hstack([p, q])\n",
"\n",
"# Solve the primal problem\n",
"# (SciPy >= 1.11 removed 'revised simplex'; use method='highs' there)\n",
"res = linprog(C_vec, A_eq=A, b_eq=b, method='revised simplex')"
]
},
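{
"cell_type": "markdown",
"id": "kron-check-text",
"metadata": {},
"source": [
"To see why these Kronecker products encode the marginal constraints, note that under column-major stacking `A1` applied to the vectorized matrix returns its row sums and `A2` returns its column sums. The cell below is a minimal sketch of that check; `X_test` is a hypothetical matrix used only for illustration and is not part of the lecture."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "kron-check-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"m, n = 3, 5\n",
"\n",
"# Hypothetical test matrix, used only to illustrate the indexing\n",
"X_test = np.arange(m * n, dtype=float).reshape((m, n))\n",
"x_vec = X_test.reshape(m * n, order='F')    # stack the columns\n",
"\n",
"A1 = np.kron(np.ones((1, n)), np.identity(m))\n",
"A2 = np.kron(np.identity(n), np.ones((1, m)))\n",
"\n",
"# A1 picks out row sums, A2 picks out column sums\n",
"row_sums_ok = np.allclose(A1 @ x_vec, X_test.sum(axis=1))\n",
"col_sums_ok = np.allclose(A2 @ x_vec, X_test.sum(axis=0))\n",
"row_sums_ok, col_sums_ok"
]
},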
{
"cell_type": "markdown",
"id": "4953245c",
"metadata": {},
"source": [
"We get 7225 for the minimized function value, as in the OT lecture:"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "b303f136",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"7225.0"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"res.fun"
]
},
{
"cell_type": "markdown",
"id": "ef970aed",
"metadata": {},
"source": [
"The solution comes back to us in vectorized form."
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "8ac9e7fd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([15., 10.,  0., 35.,  0., 80.,  0., 60.,  0.,  0., 30.,  0.,  0.,\n",
"        0., 70.])"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"res.x"
]
},
{
"cell_type": "markdown",
"id": "143d0860",
"metadata": {},
"source": [
"We return it to matrix form, again being careful to use column-major order."
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "bd9548c2",
"metadata": {},
"outputs": [],
"source": [
"X = res.x.reshape((m,n), order='F')"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "a5e28496",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[15., 35.,  0.,  0.,  0.],\n",
"       [10.,  0., 60., 30.,  0.],\n",
"       [ 0., 80.,  0.,  0., 70.]])"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X"
]
},
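{
"cell_type": "markdown",
"id": "plan-check-text",
"metadata": {},
"source": [
"As a quick sanity check, we can confirm that this plan is feasible, meaning its rows sum to `p` and its columns sum to `q`, and that it attains the cost reported above. This is a minimal sketch: the array `X_lp` is copied by hand from the output above, and the names `X_lp`, `rows_ok` and `cols_ok` are illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "plan-check-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"p = np.array([50, 100, 150])\n",
"q = np.array([25, 115, 60, 30, 70])\n",
"C = np.array([[10, 15, 20, 20, 40],\n",
"              [20, 40, 15, 30, 30],\n",
"              [30, 35, 40, 55, 25]])\n",
"\n",
"# Plan copied from the output above (illustrative name)\n",
"X_lp = np.array([[15., 35.,  0.,  0.,  0.],\n",
"                 [10.,  0., 60., 30.,  0.],\n",
"                 [ 0., 80.,  0.,  0., 70.]])\n",
"\n",
"rows_ok = np.allclose(X_lp.sum(axis=1), p)    # supply constraints\n",
"cols_ok = np.allclose(X_lp.sum(axis=0), q)    # demand constraints\n",
"cost = np.sum(X_lp * C)                       # total transport cost\n",
"rows_ok, cols_ok, cost"
]
},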
{
"cell_type": "markdown",
"id": "c5a3ecf8",
"metadata": {},
"source": [
"This is the correct solution. We can verify it with the Python Optimal Transport (POT) package:"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "603fbe01",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[15, 35,  0,  0,  0],\n",
"       [10,  0, 60, 30,  0],\n",
"       [ 0, 80,  0,  0, 70]])"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X = ot.emd(p, q, C)\n",
"X"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "600a2804",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"7225"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"total_cost = np.sum(X * C)\n",
"total_cost"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}