Created
May 12, 2022 09:46
-
-
Save Gilles86/fe94395b53844fc35cdd7031bbd4b57d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "147f4020", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"from tqdm.notebook import tqdm" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "b7e9732c", | |
"metadata": {}, | |
"source": [ | |
"So we split up the problem in two subproblems.\n", | |
"\n", | |
"Problem 1: we need to find 3 sets of 20 pairs that roughly correspond to the given joint probability matrix. Moreover, over all 60 pairs, their frequencies should _exactly_ correspond to the joint probability matrix\n", | |
"\n", | |
"Problem 2: we need to make sure that there are any repetitions in either of the dimensions" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "a3ef90c5", | |
"metadata": {}, | |
"source": [ | |
"# Problem 1" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "4e827ce8", | |
"metadata": {}, | |
"source": [ | |
"We need 3 sub co-occurence matrices that add up to the following summed matrix" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "be4513a7", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[10, 7, 3],\n", | |
" [ 3, 10, 7],\n", | |
" [ 7, 3, 10]])" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"jp = np.array([[0.5, 0.35, 0.15],\n", | |
" [0.15, 0.5, 0.35],\n", | |
" [0.35, 0.15, 0.5]])\n", | |
"jp = jp/3.\n", | |
"n = jp*60\n", | |
"n = np.round(n).astype(int)\n", | |
"n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "7688399c", | |
"metadata": {}, | |
"source": [ | |
"This function samples matrices that have only 20 pairs but the co-occurences are _roughly_ equal to the join probabilities" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "9d26d507", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def sample_joint_n():\n", | |
" \n", | |
" high = [4, 3, 3]\n", | |
" medium = [2,2 ,3]\n", | |
" low = [1,1, 1]\n", | |
" \n", | |
" np.random.shuffle(high)\n", | |
" np.random.shuffle(low)\n", | |
" \n", | |
" n = [[high[0], medium[0], low[0]],\n", | |
" [low[1], high[1], medium[1]],\n", | |
" [medium[2], low[2], high[2]],]\n", | |
" \n", | |
" return np.array(n)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"id": "071e6044", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(array([[4, 2, 1],\n", | |
" [1, 3, 2],\n", | |
" [3, 1, 3]]),\n", | |
" array([[3, 2, 1],\n", | |
" [1, 3, 2],\n", | |
" [3, 1, 4]]),\n", | |
" array([[3, 3, 1],\n", | |
" [1, 4, 3],\n", | |
" [1, 1, 3]]))" | |
] | |
}, | |
"execution_count": 31, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"n1 = sample_joint_n()\n", | |
"n2 = sample_joint_n()\n", | |
"n3 = n - n1 - n2\n", | |
"\n", | |
"\n", | |
"n1, n2, n3" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "acbceda7", | |
"metadata": {}, | |
"source": [ | |
"# Problem 2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"id": "e55f8d38", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def get_cost(pairs, last_pair=None):\n", | |
" \n", | |
" cost = (pairs[1:] == pairs[:-1]).any(1).sum()\n", | |
"\n", | |
" if last_pair is not None:\n", | |
" \n", | |
" cost += (pairs[0] == last_pair).any()\n", | |
" \n", | |
" return cost" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 38, | |
"id": "727bbcbb", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def get_sequence(n, last_pair=None):\n", | |
" i = 0\n", | |
" \n", | |
" pairs = []\n", | |
" for i, j in zip(*map(np.ravel, np.meshgrid(range(3), range(3)))):\n", | |
" pairs.append([(i, j)] *int(n[i, j]))\n", | |
"\n", | |
" pairs = np.concatenate(pairs)\n", | |
" \n", | |
" np.random.shuffle(pairs)\n", | |
" \n", | |
" with tqdm(range(100000)) as pbar:\n", | |
"\n", | |
" # Get cost of current pairs\n", | |
" cost = get_cost(pairs) \n", | |
" for i in pbar:\n", | |
" pbar.set_description(f'Current cost: {cost}')\n", | |
"\n", | |
" new_pairs = pairs.copy()\n", | |
"\n", | |
" # Flip 3 pairs so we don't get stuck in a local minimum\n", | |
" ix0, ix1, ix2 = np.random.randint(0, len(new_pairs), 3)\n", | |
" new_pairs[ix0], new_pairs[ix1], new_pairs[ix2] = pairs[ix2], pairs[ix0], pairs[ix1]\n", | |
"\n", | |
" new_cost = get_cost(new_pairs, last_pair)\n", | |
"\n", | |
" if new_cost<cost:\n", | |
" pairs = new_pairs\n", | |
" cost = new_cost\n", | |
"\n", | |
" if cost == 0.0:\n", | |
" pbar.set_description(f'Current cost: {cost}')\n", | |
" return pairs\n", | |
" \n", | |
" raise Exception(f'Did not converge: {pairs}')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 45, | |
"id": "85cd25e4", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "5a6afd13075644a98715fdca613e26d9", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/100000 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "ee0dfcea04864ec68c04c00e488134b1", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/100000 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "d68672bd45e248999a4b743f94f0dffc", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/100000 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"seq1 = get_sequence(n1)\n", | |
"seq2 = get_sequence(n2, seq1[-1])\n", | |
"seq3 = get_sequence(n3, seq2[-1])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 46, | |
"id": "b33f304d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"seq = np.concatenate((seq1, seq2, seq3))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "6e690d89", | |
"metadata": {}, | |
"source": [ | |
"No repetitions..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 50, | |
"id": "73d78304", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0" | |
] | |
}, | |
"execution_count": 50, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"get_cost(seq)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "e4e7bf30", | |
"metadata": {}, | |
"source": [ | |
"This is what it looked like" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 51, | |
"id": "ef2ee1cb", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[2, 2],\n", | |
" [0, 0],\n", | |
" [1, 1],\n", | |
" [2, 2],\n", | |
" [1, 0],\n", | |
" [0, 2],\n", | |
" [2, 0],\n", | |
" [1, 1],\n", | |
" [0, 0],\n", | |
" [1, 1],\n", | |
" [2, 0],\n", | |
" [1, 2],\n", | |
" [2, 1],\n", | |
" [0, 0],\n", | |
" [2, 2],\n", | |
" [0, 0],\n", | |
" [1, 2],\n", | |
" [2, 0],\n", | |
" [1, 2],\n", | |
" [0, 1],\n", | |
" [2, 0],\n", | |
" [0, 2],\n", | |
" [1, 1],\n", | |
" [2, 0],\n", | |
" [0, 1],\n", | |
" [1, 2],\n", | |
" [2, 0],\n", | |
" [1, 1],\n", | |
" [2, 2],\n", | |
" [0, 1],\n", | |
" [1, 2],\n", | |
" [0, 0],\n", | |
" [2, 2],\n", | |
" [0, 0],\n", | |
" [2, 1],\n", | |
" [0, 0],\n", | |
" [2, 2],\n", | |
" [1, 1],\n", | |
" [2, 2],\n", | |
" [1, 0],\n", | |
" [0, 1],\n", | |
" [1, 2],\n", | |
" [0, 0],\n", | |
" [1, 1],\n", | |
" [2, 2],\n", | |
" [1, 1],\n", | |
" [0, 0],\n", | |
" [1, 2],\n", | |
" [0, 0],\n", | |
" [1, 1],\n", | |
" [2, 2],\n", | |
" [1, 1],\n", | |
" [0, 2],\n", | |
" [2, 1],\n", | |
" [1, 2],\n", | |
" [0, 1],\n", | |
" [2, 0],\n", | |
" [0, 1],\n", | |
" [1, 0],\n", | |
" [2, 2]])" | |
] | |
}, | |
"execution_count": 51, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"seq" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "32c8c653", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment