Created
March 15, 2023 17:28
-
-
Save michaelchughes/07492ecdc61df96d7fef0bd2744ade83 to your computer and use it in GitHub Desktop.
Reflection on how to pick tracts to optimize BPR
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "00bf1b63", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import pandas as pd\n", | |
"import scipy.stats" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "15a561d5", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"np.set_printoptions(precision=4, suppress=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "db6ed9af", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"pd.options.display.float_format = '{:,.4g}'.format # show 4 digits of precision" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "7515eeac", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import matplotlib.pyplot as plt\n", | |
"import seaborn as sns\n", | |
"sns.set_style(\"whitegrid\")\n", | |
"sns.set_context(\"notebook\", font_scale=1.25)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "5ccf48f6", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "456a43d2", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class ZeroInflatedDist(object):\n", | |
" \n", | |
" def __init__(self, dist, zero_proba):\n", | |
" self.dist = dist\n", | |
" self.zero_proba = float(zero_proba)\n", | |
" \n", | |
" def rvs(self, size=1, random_state=np.random):\n", | |
" vals = np.atleast_1d(np.round(self.dist.rvs(size=size, random_state=random_state)))\n", | |
" zmask = random_state.rand(size) < self.zero_proba\n", | |
" vals[zmask] = 0\n", | |
" return np.maximum(0, vals)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "a9f9c48f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class QuantizedNormal(object):\n", | |
" \n", | |
" def __init__(self, loc, scale):\n", | |
" self.dist = scipy.stats.norm(loc, scale)\n", | |
" \n", | |
" def rvs(self, *args, **kwargs):\n", | |
" vals = np.atleast_1d(np.round(self.dist.rvs(*args, **kwargs)))\n", | |
" return np.maximum(0, vals)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "0bbb1832", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"''" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"consistent_3 = [QuantizedNormal(7, 0.1) for _ in range(3)]\n", | |
"\n", | |
"highvar_3 = [ZeroInflatedDist(QuantizedNormal(10, 0.1), 0.3) for _ in range(3)]\n", | |
"\n", | |
"powerball_3 = [ZeroInflatedDist(QuantizedNormal(100, 0.1), 0.9) for _ in range(3)]\n", | |
"\n", | |
"dist_N = consistent_3 + highvar_3 +powerball_3\n", | |
"\n", | |
"\n", | |
"'''\n", | |
"poisson_N = [scipy.stats.poisson(k) for k in range(1, 4)]\n", | |
"smallvar_norm_N = [QuantizedNormal(k + 0.5, 0.1) for k in range(1, 4)]\n", | |
"bigvar_norm_N = [QuantizedNormal(k + 0.5, 5.0) for k in range(1, 4)]\n", | |
"\n", | |
"consistent_3 = [QuantizedNormal(7, 0.1) for _ in range(3)]\n", | |
"\n", | |
"highvar_3 = [ZeroInflatedDist(QuantizedNormal(10, 0.1), 0.3) for _ in range(3)]\n", | |
"\n", | |
"powerball_3 = [ZeroInflatedDist(QuantizedNormal(100, 0.1), 0.9) for _ in range(3)]\n", | |
"\n", | |
"dist_N = poisson_N + smallvar_norm_N + bigvar_norm_N + consistent_3 + highvar_3 +powerball_3\n", | |
"'''\n", | |
";" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "450389f7", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7.,\n", | |
" 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7.,\n", | |
" 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7.,\n", | |
" 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7.,\n", | |
" 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7.,\n", | |
" 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7.])" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"consistent_3[0].rvs(size=100)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "91ae5166", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([ 0., 10., 0., 10., 10., 10., 0., 0., 10., 10., 0., 0., 10.,\n", | |
" 0., 10., 10., 0., 0., 10., 10., 10., 10., 10., 0., 0., 0.,\n", | |
" 0., 10., 0., 0., 0., 10., 10., 10., 10., 10., 10., 10., 10.,\n", | |
" 10., 10., 10., 0., 0., 0., 10., 10., 10., 10., 10., 10., 10.,\n", | |
" 10., 10., 10., 0., 10., 0., 0., 10., 10., 0., 0., 0., 10.,\n", | |
" 10., 10., 10., 10., 10., 10., 0., 10., 10., 10., 10., 10., 0.,\n", | |
" 10., 10., 0., 0., 0., 10., 0., 10., 0., 10., 10., 10., 10.,\n", | |
" 10., 0., 0., 10., 0., 0., 10., 10., 10.])" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"highvar_3[0].rvs(size=100)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "04e22840", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([ 0., 0., 0., 0., 100., 0., 0., 0., 0., 100., 0.,\n", | |
" 0., 0., 0., 0., 100., 0., 0., 0., 0., 0., 0.,\n", | |
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |
" 0., 100., 0., 0., 100., 0., 0., 0., 0., 0., 0.,\n", | |
" 100., 0., 0., 0., 0., 0., 0., 0., 0., 100., 0.,\n", | |
" 0., 0., 100., 0., 0., 100., 0., 0., 0., 0., 0.,\n", | |
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |
" 100., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |
" 0., 0., 0., 100., 100., 0., 0., 0., 0., 0., 0.,\n", | |
" 0.])" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"powerball_3[0].rvs(size=100)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "b257f88a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"7.0" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"consistent_3[0].rvs(size=10000).mean()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "86ae273f", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"9.22" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"powerball_3[0].rvs(size=10000).mean()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "eb45729d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def calc_bpr_many_trials(\n", | |
" dist_N, K=3, n_trials=10000, seed=101,\n", | |
" strategy='pick_mean',\n", | |
" percentile_as_frac=0.95):\n", | |
" N = len(dist_N)\n", | |
" y_RN = np.zeros((n_trials, N), dtype=np.int32)\n", | |
" for n, dist in enumerate(dist_N):\n", | |
" random_state = np.random.RandomState(10000 * seed + n)\n", | |
" y_RN[:, n] = dist.rvs(size=n_trials, random_state=random_state)\n", | |
" if strategy.count('cross_ratio'):\n", | |
" S = 100*n_trials\n", | |
" y_SN = np.zeros((S, N))\n", | |
" sum_str_N = [None for _ in range(N)]\n", | |
" for n, dist in enumerate(dist_N):\n", | |
" random_state = np.random.RandomState(10000 * seed + n)\n", | |
" y_SN[:,n] = dist.rvs(size=S, random_state=random_state)\n", | |
" sum_str_N[n] = \" \".join(['%.1f' % np.percentile(y_SN[:,n], p)\n", | |
" for p in [0, 10, 50, 90, 100]])\n", | |
" ratio_N = np.mean(y_SN / np.sum(y_SN, axis=1, keepdims=1), axis=0)\n", | |
" assert ratio_N.shape == (N,)\n", | |
" selected_ids_K = np.argsort(-1 * ratio_N)[:K]\n", | |
" for kk in selected_ids_K:\n", | |
" print(sum_str_N[kk])\n", | |
" selected_ids_RK = np.tile(selected_ids_K, (n_trials,1))\n", | |
" if strategy.count('pick'):\n", | |
" score_N = np.zeros(N)\n", | |
" sum_str_N = [None for _ in range(N)]\n", | |
" for n, dist in enumerate(dist_N):\n", | |
" random_state = np.random.RandomState(10000 * seed + n)\n", | |
" y_samples_S = dist.rvs(size=100*n_trials, random_state=random_state)\n", | |
" sum_str_N[n] = \" \".join(['%.1f' % np.percentile(y_samples_S, p)\n", | |
" for p in [0, 10, 50, 90, 100]])\n", | |
" \n", | |
" if strategy == 'pick_mean':\n", | |
" score_N[n] = np.mean(y_samples_S)\n", | |
" elif strategy == 'pick_mean_of_squares':\n", | |
" score_N[n] = np.mean(np.square(y_samples_S))\n", | |
" elif strategy == 'pick_mean_of_sqrt':\n", | |
" score_N[n] = np.mean(np.sqrt(y_samples_S))\n", | |
" elif strategy == 'pick_max':\n", | |
" score_N[n] = np.max(y_samples_S)\n", | |
" elif strategy == 'pick_percentile':\n", | |
" score_N[n] = np.percentile(y_samples_S, percentile_as_frac) \n", | |
" else:\n", | |
" score_N[n] = np.median(y_samples_S)\n", | |
" selected_ids_K = np.argsort(-1 * score_N)[:K]\n", | |
" for kk in selected_ids_K:\n", | |
" print(sum_str_N[kk])\n", | |
" selected_ids_RK = np.tile(selected_ids_K, (n_trials,1))\n", | |
" if strategy == 'guess_random':\n", | |
" random_state = np.random.RandomState(10000 * seed)\n", | |
" selected_ids_RK = np.zeros((n_trials, K), dtype=np.int32)\n", | |
" for trial in range(n_trials):\n", | |
" selected_ids_RK[trial,:] = random_state.permutation(N)[:K]\n", | |
" \n", | |
" yselect_RK = np.take_along_axis(y_RN, selected_ids_RK, axis=1)\n", | |
" topk_ids_RK = np.argsort(-1 * y_RN, axis=1)[:, :K]\n", | |
" ytop_RK = np.take_along_axis(y_RN, topk_ids_RK, axis=1)\n", | |
"\n", | |
" numer_R = np.sum(yselect_RK, axis=1)\n", | |
" denom_R = np.sum(ytop_RK, axis=1)\n", | |
" \n", | |
" assert np.all(numer_R <= denom_R + 1e-10)\n", | |
" \n", | |
" return numer_R / denom_R" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "a1771804", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"y_RN = np.random.poisson(5, size=40).reshape(10, 4)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "0de586dc", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[ 5, 9, 6, 7],\n", | |
" [ 6, 5, 7, 7],\n", | |
" [ 5, 12, 2, 3],\n", | |
" [ 7, 3, 7, 4],\n", | |
" [ 3, 5, 3, 4],\n", | |
" [ 6, 5, 4, 4],\n", | |
" [ 3, 3, 6, 7],\n", | |
" [ 8, 6, 6, 2],\n", | |
" [10, 4, 6, 3],\n", | |
" [ 5, 5, 5, 6]])" | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"y_RN" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"id": "14880209", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"K = 2\n", | |
"topk_ids_RK = np.argsort(-1 * y_RN, axis=1)[:, :K]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"id": "6233e5ee", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[ 9, 7],\n", | |
" [ 7, 7],\n", | |
" [12, 5],\n", | |
" [ 7, 7],\n", | |
" [ 5, 4],\n", | |
" [ 6, 5],\n", | |
" [ 7, 6],\n", | |
" [ 8, 6],\n", | |
" [10, 6],\n", | |
" [ 6, 5]])" | |
] | |
}, | |
"execution_count": 20, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"topk_ids_RK\n", | |
"np.take_along_axis(y_RN, topk_ids_RK, axis=1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"id": "9c545471", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"0.0 0.0 0.0 100.0 100.0\n", | |
"0.0 0.0 0.0 100.0 100.0\n", | |
"0.0 0.0 0.0 0.0 101.0\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"0.22923552078356427" | |
] | |
}, | |
"execution_count": 22, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.mean(calc_bpr_many_trials(dist_N, K=3, n_trials=100000, strategy='pick_mean'))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"id": "cee1f56c", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.4780103310589275" | |
] | |
}, | |
"execution_count": 23, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.mean(calc_bpr_many_trials(dist_N, K=3, n_trials=100000, strategy='guess_random'))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"id": "171d8006", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"6.0 7.0 7.0 7.0 8.0\n", | |
"6.0 7.0 7.0 7.0 8.0\n", | |
"6.0 7.0 7.0 7.0 8.0\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"0.6116739632292829" | |
] | |
}, | |
"execution_count": 24, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.mean(calc_bpr_many_trials(dist_N, K=3, n_trials=100000, strategy='cross_ratio'))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "7026d9ee", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"7.0 7.0 7.0 7.0 8.0\n", | |
"7.0 7.0 7.0 7.0 7.0\n", | |
"6.0 7.0 7.0 7.0 7.0\n", | |
"0.5943714824065128\n", | |
"7.0 7.0 7.0 7.0 8.0\n", | |
"7.0 7.0 7.0 7.0 7.0\n", | |
"6.0 7.0 7.0 7.0 7.0\n", | |
"0.5943714824065128\n", | |
"0.0 0.0 10.0 10.0 10.0\n", | |
"0.0 0.0 10.0 10.0 10.0\n", | |
"0.0 0.0 10.0 10.0 10.0\n", | |
"0.5741076003753522\n", | |
"0.0 0.0 10.0 10.0 10.0\n", | |
"0.0 0.0 10.0 10.0 10.0\n", | |
"0.0 0.0 10.0 10.0 10.0\n", | |
"0.5741076003753522\n", | |
"0.0 0.0 10.0 10.0 10.0\n", | |
"0.0 0.0 10.0 10.0 10.0\n", | |
"0.0 0.0 10.0 10.0 10.0\n", | |
"0.5741076003753522\n", | |
"0.0 0.0 10.0 10.0 10.0\n", | |
"0.0 0.0 10.0 10.0 10.0\n", | |
"0.0 0.0 10.0 10.0 10.0\n", | |
"0.5741076003753522\n", | |
"0.0 0.0 10.0 10.0 10.0\n", | |
"0.0 0.0 10.0 10.0 10.0\n", | |
"0.0 0.0 10.0 10.0 10.0\n", | |
"0.5741076003753522\n", | |
"0.0 0.0 10.0 10.0 10.0\n", | |
"0.0 0.0 10.0 10.0 10.0\n", | |
"0.0 0.0 10.0 10.0 10.0\n", | |
"0.5741076003753522\n", | |
"0.0 0.0 3.0 10.0 27.0\n", | |
"0.0 0.0 10.0 10.0 10.0\n", | |
"0.0 0.0 10.0 10.0 10.0\n", | |
"0.49775269031932723\n" | |
] | |
} | |
], | |
"source": [ | |
"for perc in [10, 20, 30, 40, 50, 60, 70, 80, 90]:\n", | |
" print(np.mean(calc_bpr_many_trials(\n", | |
" dist_N, K=3, strategy='pick_percentile',\n", | |
" percentile_as_frac=perc)))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"id": "568c8d44", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"7.0 7.0 7.0 7.0 8.0\n", | |
"7.0 7.0 7.0 7.0 7.0\n", | |
"6.0 7.0 7.0 7.0 7.0\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"0.5943714824065128" | |
] | |
}, | |
"execution_count": 19, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.mean(calc_bpr_many_trials(dist_N, K=3, strategy='pick_mean_of_sqrt'))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"id": "21dd9509", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"14.0" | |
] | |
}, | |
"execution_count": 20, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"a = np.asarray([1,2,3])\n", | |
"np.square(np.linalg.norm(a))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"id": "03fc9be3", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"14" | |
] | |
}, | |
"execution_count": 21, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.sum(np.square(a))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"id": "3d465e51", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[1., 0., 0.],\n", | |
" [0., 1., 0.],\n", | |
" [0., 0., 1.]])" | |
] | |
}, | |
"execution_count": 22, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.identity(3)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"id": "f5c29729", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"7.0 7.0 7.0 7.0 8.0\n", | |
"7.0 7.0 7.0 7.0 7.0\n", | |
"6.0 7.0 7.0 7.0 7.0\n", | |
"0.0 0.0 0.0 100.0 100.0\n", | |
"0.0 0.0 0.0 0.0 100.0\n", | |
"0.0 0.0 0.0 0.0 100.0\n", | |
"0.0 0.0 10.0 10.0 10.0\n", | |
"0.0 0.0 10.0 10.0 11.0\n", | |
"0.0 0.0 10.0 10.0 10.0\n" | |
] | |
} | |
], | |
"source": [ | |
"mnames = ['cross_ratio', 'pick_mean', 'pick_median', 'guess_random']\n", | |
"R = 10000\n", | |
"\n", | |
"scores_MR = np.zeros((4, R))\n", | |
"for mm, method in enumerate(mnames):\n", | |
" scores_MR[mm] = calc_bpr_many_trials(\n", | |
" dist_N, K=3, n_trials=R, strategy=method)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"id": "3a6c5ac8", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[0.7778, 0.7778, 0.7 , 0.7778, 0.7 , 0.7 , 0.7 , 0.7778, 0.175 , 0.175 ],\n", | |
" [0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.8333, 0.8333],\n", | |
" [0.7407, 0.7407, 1. , 0.7407, 1. , 1. , 1. , 0.7407, 0.1667, 0.1667],\n", | |
" [0.3704, 0.6296, 0. , 0.5185, 0.8 , 0.6667, 0.9 , 0.2593, 0.1417, 0.1417]])" | |
] | |
}, | |
"execution_count": 26, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.set_printoptions(precision=4, linewidth=120)\n", | |
"scores_MR[:, :10]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"id": "338bdbd2", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"winscore_1R = scores_MR.max(axis=0, keepdims=1)\n", | |
"\n", | |
"winners_MR = np.abs(scores_MR - winscore_1R) < 0.02" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"id": "7dfa0bc5", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" cross_ratio won 4453/10000 trials\n", | |
" pick_mean won 1953/10000 trials\n", | |
" pick_median won 2514/10000 trials\n", | |
" guess_random won 1270/10000 trials\n" | |
] | |
} | |
], | |
"source": [ | |
"for mm, mname in enumerate(mnames):\n", | |
" print(\"%13s won % 6d/%d trials\" % (mname, winners_MR[mm].sum(), R))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"id": "26ca43d1", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([0.21 , 0. , 0.25 , 0.2333])" | |
] | |
}, | |
"execution_count": 29, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.percentile(scores_MR / winscore_1R, 20, axis=1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "593a5188", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment