Gilles86/fe94395b53844fc35cdd7031bbd4b57d
{
"cells": [
{
"cell_type": "code",
"execution_count": 11,
"id": "147f4020",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from tqdm.notebook import tqdm"
]
},
{
"cell_type": "markdown",
"id": "b7e9732c",
"metadata": {},
"source": [
"We split the problem into two subproblems.\n",
"\n",
"Problem 1: we need to find 3 sets of 20 pairs whose co-occurrence counts roughly follow the given joint probability matrix. Moreover, over all 60 pairs, the counts should _exactly_ match the joint probability matrix (scaled to 60 pairs).\n",
"\n",
"Problem 2: we need to order the pairs so that there are _no_ repetitions in either dimension, i.e. no two consecutive pairs share the same first element or the same second element."
]
},
{
"cell_type": "markdown",
"id": "a3ef90c5",
"metadata": {},
"source": [
"# Problem 1"
]
},
{
"cell_type": "markdown",
"id": "4e827ce8",
"metadata": {},
"source": [
"We need 3 co-occurrence sub-matrices (one per set of 20 pairs) that add up to the following total count matrix for all 60 pairs:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "be4513a7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[10, 7, 3],\n",
" [ 3, 10, 7],\n",
" [ 7, 3, 10]])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
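"# Each row of jp sums to 1; dividing by 3 makes all 9 cells sum to 1\n",
"jp = np.array([[0.5, 0.35, 0.15],\n",
"               [0.15, 0.5, 0.35],\n",
"               [0.35, 0.15, 0.5]])\n",
"jp = jp / 3.\n",
"# Expected number of occurrences of each pair among 60 pairs, rounded to integers\n",
"n = jp * 60\n",
"n = np.round(n).astype(int)\n",
"n"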
]
},
{
"cell_type": "markdown",
"id": "7688399c",
"metadata": {},
"source": [
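"This function samples a 3×3 count matrix with only 20 pairs, whose co-occurrences are _roughly_ proportional to the joint probabilities: the cells with 0.5 in the original matrix get the high counts (10 of the 20 pairs), the 0.35 cells get the medium counts (7 pairs) and the 0.15 cells get one pair each."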
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "9d26d507",
"metadata": {},
"outputs": [],
"source": [
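"def sample_joint_n():\n",
"    high = [4, 3, 3]    # cells with 0.5 in jp: 10 of the 20 pairs\n",
"    medium = [2, 2, 3]  # cells with 0.35 in jp: 7 of the 20 pairs\n",
"    low = [1, 1, 1]     # cells with 0.15 in jp: 3 of the 20 pairs\n",
"\n",
"    # Randomize which cells get the extra count\n",
"    # (`low` is all ones, so it needs no shuffling)\n",
"    np.random.shuffle(high)\n",
"    np.random.shuffle(medium)\n",
"\n",
"    n = [[high[0], medium[0], low[0]],\n",
"         [low[1], high[1], medium[1]],\n",
"         [medium[2], low[2], high[2]]]\n",
"\n",
"    return np.array(n)"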
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "071e6044",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([[4, 2, 1],\n",
" [1, 3, 2],\n",
" [3, 1, 3]]),\n",
" array([[3, 2, 1],\n",
" [1, 3, 2],\n",
" [3, 1, 4]]),\n",
" array([[3, 3, 1],\n",
" [1, 4, 3],\n",
" [1, 1, 3]]))"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"n1 = sample_joint_n()\n",
"n2 = sample_joint_n()\n",
"# The third matrix is the remainder, so n1 + n2 + n3 == n exactly\n",
"n3 = n - n1 - n2\n",
"\n",
"n1, n2, n3"
]
},
{
"cell_type": "markdown",
"id": "acbceda7",
"metadata": {},
"source": [
"# Problem 2"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "e55f8d38",
"metadata": {},
"outputs": [],
"source": [
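"def get_cost(pairs, last_pair=None):\n",
"    # Number of consecutive pairs that share a value in either dimension\n",
"    cost = (pairs[1:] == pairs[:-1]).any(axis=1).sum()\n",
"\n",
"    # Also penalize a repetition across the boundary with the previous sequence\n",
"    if last_pair is not None:\n",
"        cost += (pairs[0] == last_pair).any()\n",
"\n",
"    return cost"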
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "727bbcbb",
"metadata": {},
"outputs": [],
"source": [
"def get_sequence(n, last_pair=None, n_iter=100000):\n",
"    # Expand the 3x3 count matrix into an array of (i, j) pairs,\n",
"    # in which pair (i, j) occurs exactly n[i, j] times\n",
"    pairs = np.array([(i, j) for i in range(3) for j in range(3)\n",
"                      for _ in range(int(n[i, j]))])\n",
"\n",
"    np.random.shuffle(pairs)\n",
"\n",
"    # Cost of the initial order, including the boundary with last_pair\n",
"    cost = get_cost(pairs, last_pair)\n",
"\n",
"    with tqdm(range(n_iter)) as pbar:\n",
"        for _ in pbar:\n",
"            pbar.set_description(f'Current cost: {cost}')\n",
"\n",
"            new_pairs = pairs.copy()\n",
"\n",
"            # Rotate 3 pairs so we don't get stuck in a local minimum. The\n",
"            # indices must be distinct: otherwise one pair is overwritten by\n",
"            # a copy of another and the counts in n are no longer preserved\n",
"            ix0, ix1, ix2 = np.random.choice(len(pairs), 3, replace=False)\n",
"            new_pairs[ix0], new_pairs[ix1], new_pairs[ix2] = pairs[ix2], pairs[ix0], pairs[ix1]\n",
"\n",
"            new_cost = get_cost(new_pairs, last_pair)\n",
"\n",
"            # Only accept moves that strictly lower the cost\n",
"            if new_cost < cost:\n",
"                pairs = new_pairs\n",
"                cost = new_cost\n",
"\n",
"            if cost == 0:\n",
"                pbar.set_description(f'Current cost: {cost}')\n",
"                return pairs\n",
"\n",
"    raise RuntimeError(f'Did not converge: {pairs}')"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "85cd25e4",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5a6afd13075644a98715fdca613e26d9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/100000 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ee0dfcea04864ec68c04c00e488134b1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/100000 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d68672bd45e248999a4b743f94f0dffc",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/100000 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"seq1 = get_sequence(n1)\n",
"seq2 = get_sequence(n2, seq1[-1])\n",
"seq3 = get_sequence(n3, seq2[-1])"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "b33f304d",
"metadata": {},
"outputs": [],
"source": [
"seq = np.concatenate((seq1, seq2, seq3))"
]
},
{
"cell_type": "markdown",
"id": "6e690d89",
"metadata": {},
"source": [
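"No repetitions: the cost of the full 60-pair sequence is 0, so no two consecutive pairs share a value in either dimension."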
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "73d78304",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_cost(seq)"
]
},
{
"cell_type": "markdown",
"id": "e4e7bf30",
"metadata": {},
"source": [
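"This is what the resulting sequence looked like.\n",
"\n",
"Note that in this stored output the first 20 pairs contain `(0, 1)` once and `(1, 2)` three times, whereas `n1` above has two of each. Over all 60 pairs, `(0, 1)` therefore occurs 6 times and `(1, 2)` 8 times, instead of 7 each. The counts can only change in the three-way rotation in `get_sequence`: if its indices are drawn with replacement (as with `np.random.randint`) and two of them coincide, one pair is overwritten by a copy of another. Drawing three distinct indices (`np.random.choice(len(pairs), 3, replace=False)`) keeps the counts exact."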
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "ef2ee1cb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[2, 2],\n",
" [0, 0],\n",
" [1, 1],\n",
" [2, 2],\n",
" [1, 0],\n",
" [0, 2],\n",
" [2, 0],\n",
" [1, 1],\n",
" [0, 0],\n",
" [1, 1],\n",
" [2, 0],\n",
" [1, 2],\n",
" [2, 1],\n",
" [0, 0],\n",
" [2, 2],\n",
" [0, 0],\n",
" [1, 2],\n",
" [2, 0],\n",
" [1, 2],\n",
" [0, 1],\n",
" [2, 0],\n",
" [0, 2],\n",
" [1, 1],\n",
" [2, 0],\n",
" [0, 1],\n",
" [1, 2],\n",
" [2, 0],\n",
" [1, 1],\n",
" [2, 2],\n",
" [0, 1],\n",
" [1, 2],\n",
" [0, 0],\n",
" [2, 2],\n",
" [0, 0],\n",
" [2, 1],\n",
" [0, 0],\n",
" [2, 2],\n",
" [1, 1],\n",
" [2, 2],\n",
" [1, 0],\n",
" [0, 1],\n",
" [1, 2],\n",
" [0, 0],\n",
" [1, 1],\n",
" [2, 2],\n",
" [1, 1],\n",
" [0, 0],\n",
" [1, 2],\n",
" [0, 0],\n",
" [1, 1],\n",
" [2, 2],\n",
" [1, 1],\n",
" [0, 2],\n",
" [2, 1],\n",
" [1, 2],\n",
" [0, 1],\n",
" [2, 0],\n",
" [0, 1],\n",
" [1, 0],\n",
" [2, 2]])"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"seq"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
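The Problem 1 construction depends on two properties. First, every sampled sub-matrix holds exactly 20 pairs. Second, the remainder `n - n1 - n2` never goes negative: in the worst case a high cell keeps 10 - 4 - 4 = 2 and a medium cell keeps 7 - 3 - 3 = 1. Below is a minimal standalone sketch that re-implements the sampler and checks both properties over many random draws. `sample_sub_counts` and `split_counts` are hypothetical names. Shuffling the medium counts (rather than the all-ones low counts) is an assumption about the intended sampler.

```python
import numpy as np

# Target counts for 60 pairs, computed as in the notebook: jp / 3 * 60, rounded
jp = np.array([[0.5, 0.35, 0.15],
               [0.15, 0.5, 0.35],
               [0.35, 0.15, 0.5]])
n = np.round(jp / 3. * 60).astype(int)


def sample_sub_counts(rng):
    """Hypothetical variant of sample_joint_n: one 3x3 count matrix with 20 pairs."""
    high = rng.permutation([4, 3, 3])    # cells with 0.5 in jp: 10 of 20 pairs
    medium = rng.permutation([2, 2, 3])  # cells with 0.35 in jp: 7 of 20 pairs
    low = [1, 1, 1]                      # cells with 0.15 in jp: 3 of 20 pairs
    return np.array([[high[0], medium[0], low[0]],
                     [low[1], high[1], medium[1]],
                     [medium[2], low[2], high[2]]])


def split_counts(n, rng):
    """Split n into three sub-matrices; the third is the remainder."""
    n1 = sample_sub_counts(rng)
    n2 = sample_sub_counts(rng)
    n3 = n - n1 - n2  # n1 + n2 + n3 == n holds by construction
    return n1, n2, n3


# Check the properties the construction relies on over many random draws
rng = np.random.default_rng(0)
for _ in range(1000):
    n1, n2, n3 = split_counts(n, rng)
    assert (n1 + n2 + n3 == n).all()
    assert n1.sum() == n2.sum() == n3.sum() == 20
    assert n3.min() >= 1  # worst cases: 10 - 4 - 4 = 2, 7 - 3 - 3 = 1, 3 - 1 - 1 = 1
```

The loop runs silently when all checks pass. Any draw that produced a negative remainder or a block without exactly 20 pairs would raise an `AssertionError`.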
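To make the cost in `get_cost` concrete, here is a hand-checkable example. It uses a standalone copy of the function, with results wrapped in `int` for readable output. Values are compared position-wise, so `(1, 0)` followed by `(0, 1)` does not count as a repetition.

```python
import numpy as np


def get_cost(pairs, last_pair=None):
    # Number of consecutive pairs that share a value in either dimension
    cost = int((pairs[1:] == pairs[:-1]).any(axis=1).sum())
    # Plus one if the first pair repeats a value of the previous block's last pair
    if last_pair is not None:
        cost += int((pairs[0] == last_pair).any())
    return cost


pairs = np.array([[0, 1],
                  [0, 2],   # same first element as [0, 1]
                  [1, 0],
                  [2, 0],   # same second element as [1, 0]
                  [1, 2]])
print(get_cost(pairs))                              # 2
print(get_cost(pairs, last_pair=np.array([2, 1])))  # 3: [2, 1] and [0, 1] share the 1
print(get_cost(np.array([[1, 0], [0, 1]])))         # 0: comparison is position-wise
```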
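The search in `get_sequence` is a stochastic local search over orderings of a fixed multiset of pairs. Below is a minimal sketch of the same idea. It departs from the notebook in three assumed ways:

- The three rotated positions are always distinct, so every move is a permutation and the pair counts are preserved exactly.
- Moves that leave the cost unchanged are also accepted (`<=` instead of `<`), which lets the search drift across plateaus.
- The search restarts from a fresh shuffle if it stalls.

`local_search` and `expand_counts` are hypothetical names. The sketch uses NumPy's `Generator` API rather than the global `np.random` state.

```python
import numpy as np


def get_cost(pairs, last_pair=None):
    # Consecutive pairs sharing a value in either dimension, plus boundary penalty
    cost = int((pairs[1:] == pairs[:-1]).any(axis=1).sum())
    if last_pair is not None:
        cost += int((pairs[0] == last_pair).any())
    return cost


def expand_counts(n):
    # Array of (i, j) pairs in which pair (i, j) occurs exactly n[i, j] times
    return np.array([(i, j)
                     for i in range(n.shape[0])
                     for j in range(n.shape[1])
                     for _ in range(int(n[i, j]))])


def local_search(n, last_pair=None, n_iter=20_000, n_restarts=10, seed=None):
    rng = np.random.default_rng(seed)
    base = expand_counts(n)
    for _ in range(n_restarts):
        pairs = rng.permutation(base)  # random initial order (shuffles rows)
        cost = get_cost(pairs, last_pair)
        for _ in range(n_iter):
            if cost == 0:
                break
            # Rotate three *distinct* positions: always a permutation,
            # so the multiset of pairs (and hence the counts) is unchanged
            ix = rng.choice(len(pairs), size=3, replace=False)
            new_pairs = pairs.copy()
            new_pairs[ix] = pairs[np.roll(ix, 1)]
            new_cost = get_cost(new_pairs, last_pair)
            # Accept moves that do not increase the cost (sideways moves allowed)
            if new_cost <= cost:
                pairs, cost = new_pairs, new_cost
        if cost == 0:
            return pairs
    raise RuntimeError(f'Did not converge after {n_restarts} restarts (cost {cost})')


# n2 as printed in the notebook; chain blocks by passing the previous last pair
n2 = np.array([[3, 2, 1], [1, 3, 2], [3, 1, 4]])
seq2 = local_search(n2, seed=0)
print(get_cost(seq2))  # 0 (otherwise a RuntimeError is raised)
```

`np.roll(ix, 1)` reproduces the notebook's rotation: position `ix[0]` receives the pair from `ix[2]`, `ix[1]` the pair from `ix[0]`, and `ix[2]` the pair from `ix[1]`.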
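A checker that tests both constraints at once makes it easy to validate a generated sequence before using it. It combines the repetition cost from Problem 2 with the exact per-block counts from Problem 1. Below is a minimal sketch with hypothetical helper names `count_matrix` and `check_sequence`. It is applied to the first 20 pairs of the stored output and to `n1` as printed in the notebook. That block has no repetitions, but its counts differ from `n1` in the `(0, 1)` and `(1, 2)` cells.

```python
import numpy as np


def count_matrix(pairs, k=3):
    # k x k matrix whose (i, j) entry counts how often pair (i, j) occurs
    counts = np.zeros((k, k), dtype=int)
    np.add.at(counts, (pairs[:, 0], pairs[:, 1]), 1)
    return counts


def check_sequence(seq, blocks, block_size=20):
    """Return a list of problems; an empty list means both constraints hold."""
    problems = []
    # Problem 2: no value repeated in either dimension by consecutive pairs
    n_reps = int((seq[1:] == seq[:-1]).any(axis=1).sum())
    if n_reps:
        problems.append(f'{n_reps} repetition(s)')
    # Problem 1: block b must contain exactly the pairs counted in blocks[b]
    for b, expected in enumerate(blocks):
        got = count_matrix(seq[b * block_size:(b + 1) * block_size])
        if not (got == expected).all():
            problems.append(f'block {b}: counts {got.tolist()} != {expected.tolist()}')
    return problems


# First 20 pairs of the stored output and n1 as printed in the notebook
seq1 = np.array([[2, 2], [0, 0], [1, 1], [2, 2], [1, 0], [0, 2], [2, 0],
                 [1, 1], [0, 0], [1, 1], [2, 0], [1, 2], [2, 1], [0, 0],
                 [2, 2], [0, 0], [1, 2], [2, 0], [1, 2], [0, 1]])
n1 = np.array([[4, 2, 1], [1, 3, 2], [3, 1, 3]])
print(check_sequence(seq1, [n1]))
# ['block 0: counts [[4, 1, 1], [1, 3, 3], [3, 1, 3]] != [[4, 2, 1], [1, 3, 2], [3, 1, 3]]']
```

For the full sequence, pass all 60 pairs together with `[n1, n2, n3]`. The repetition check then also covers the boundaries between blocks.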