Skip to content

Instantly share code, notes, and snippets.

@michaelchughes
Created March 15, 2023 17:28
Show Gist options
  • Save michaelchughes/07492ecdc61df96d7fef0bd2744ade83 to your computer and use it in GitHub Desktop.
Save michaelchughes/07492ecdc61df96d7fef0bd2744ade83 to your computer and use it in GitHub Desktop.
Reflection on how to pick tracts to optimize BPR
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "00bf1b63",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import scipy.stats"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "15a561d5",
"metadata": {},
"outputs": [],
"source": [
"np.set_printoptions(precision=4, suppress=False)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "db6ed9af",
"metadata": {},
"outputs": [],
"source": [
"pd.options.display.float_format = '{:,.4g}'.format # show 4 digits of precision"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "7515eeac",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"sns.set_style(\"whitegrid\")\n",
"sns.set_context(\"notebook\", font_scale=1.25)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5ccf48f6",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 5,
"id": "456a43d2",
"metadata": {},
"outputs": [],
"source": [
"class ZeroInflatedDist(object):\n",
" \n",
" def __init__(self, dist, zero_proba):\n",
" self.dist = dist\n",
" self.zero_proba = float(zero_proba)\n",
" \n",
" def rvs(self, size=1, random_state=np.random):\n",
" vals = np.atleast_1d(np.round(self.dist.rvs(size=size, random_state=random_state)))\n",
" zmask = random_state.rand(size) < self.zero_proba\n",
" vals[zmask] = 0\n",
" return np.maximum(0, vals)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "a9f9c48f",
"metadata": {},
"outputs": [],
"source": [
"class QuantizedNormal(object):\n",
" \n",
" def __init__(self, loc, scale):\n",
" self.dist = scipy.stats.norm(loc, scale)\n",
" \n",
" def rvs(self, *args, **kwargs):\n",
" vals = np.atleast_1d(np.round(self.dist.rvs(*args, **kwargs)))\n",
" return np.maximum(0, vals)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "0bbb1832",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"''"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"consistent_3 = [QuantizedNormal(7, 0.1) for _ in range(3)]\n",
"\n",
"highvar_3 = [ZeroInflatedDist(QuantizedNormal(10, 0.1), 0.3) for _ in range(3)]\n",
"\n",
"powerball_3 = [ZeroInflatedDist(QuantizedNormal(100, 0.1), 0.9) for _ in range(3)]\n",
"\n",
"dist_N = consistent_3 + highvar_3 +powerball_3\n",
"\n",
"\n",
"'''\n",
"poisson_N = [scipy.stats.poisson(k) for k in range(1, 4)]\n",
"smallvar_norm_N = [QuantizedNormal(k + 0.5, 0.1) for k in range(1, 4)]\n",
"bigvar_norm_N = [QuantizedNormal(k + 0.5, 5.0) for k in range(1, 4)]\n",
"\n",
"consistent_3 = [QuantizedNormal(7, 0.1) for _ in range(3)]\n",
"\n",
"highvar_3 = [ZeroInflatedDist(QuantizedNormal(10, 0.1), 0.3) for _ in range(3)]\n",
"\n",
"powerball_3 = [ZeroInflatedDist(QuantizedNormal(100, 0.1), 0.9) for _ in range(3)]\n",
"\n",
"dist_N = poisson_N + smallvar_norm_N + bigvar_norm_N + consistent_3 + highvar_3 +powerball_3\n",
"'''\n",
";"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "450389f7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7.,\n",
" 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7.,\n",
" 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7.,\n",
" 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7.,\n",
" 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7.,\n",
" 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7.])"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"consistent_3[0].rvs(size=100)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "91ae5166",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 0., 10., 0., 10., 10., 10., 0., 0., 10., 10., 0., 0., 10.,\n",
" 0., 10., 10., 0., 0., 10., 10., 10., 10., 10., 0., 0., 0.,\n",
" 0., 10., 0., 0., 0., 10., 10., 10., 10., 10., 10., 10., 10.,\n",
" 10., 10., 10., 0., 0., 0., 10., 10., 10., 10., 10., 10., 10.,\n",
" 10., 10., 10., 0., 10., 0., 0., 10., 10., 0., 0., 0., 10.,\n",
" 10., 10., 10., 10., 10., 10., 0., 10., 10., 10., 10., 10., 0.,\n",
" 10., 10., 0., 0., 0., 10., 0., 10., 0., 10., 10., 10., 10.,\n",
" 10., 0., 0., 10., 0., 0., 10., 10., 10.])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"highvar_3[0].rvs(size=100)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "04e22840",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 0., 0., 0., 0., 100., 0., 0., 0., 0., 100., 0.,\n",
" 0., 0., 0., 0., 100., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 100., 0., 0., 100., 0., 0., 0., 0., 0., 0.,\n",
" 100., 0., 0., 0., 0., 0., 0., 0., 0., 100., 0.,\n",
" 0., 0., 100., 0., 0., 100., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 100., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 100., 100., 0., 0., 0., 0., 0., 0.,\n",
" 0.])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"powerball_3[0].rvs(size=100)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "b257f88a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"7.0"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"consistent_3[0].rvs(size=10000).mean()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "86ae273f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"9.22"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"powerball_3[0].rvs(size=10000).mean()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "eb45729d",
"metadata": {},
"outputs": [],
"source": [
"def calc_bpr_many_trials(\n",
" dist_N, K=3, n_trials=10000, seed=101,\n",
" strategy='pick_mean',\n",
" percentile_as_frac=0.95):\n",
" N = len(dist_N)\n",
" y_RN = np.zeros((n_trials, N), dtype=np.int32)\n",
" for n, dist in enumerate(dist_N):\n",
" random_state = np.random.RandomState(10000 * seed + n)\n",
" y_RN[:, n] = dist.rvs(size=n_trials, random_state=random_state)\n",
" if strategy.count('cross_ratio'):\n",
" S = 100*n_trials\n",
" y_SN = np.zeros((S, N))\n",
" sum_str_N = [None for _ in range(N)]\n",
" for n, dist in enumerate(dist_N):\n",
" random_state = np.random.RandomState(10000 * seed + n)\n",
" y_SN[:,n] = dist.rvs(size=S, random_state=random_state)\n",
" sum_str_N[n] = \" \".join(['%.1f' % np.percentile(y_SN[:,n], p)\n",
" for p in [0, 10, 50, 90, 100]])\n",
" ratio_N = np.mean(y_SN / np.sum(y_SN, axis=1, keepdims=1), axis=0)\n",
" assert ratio_N.shape == (N,)\n",
" selected_ids_K = np.argsort(-1 * ratio_N)[:K]\n",
" for kk in selected_ids_K:\n",
" print(sum_str_N[kk])\n",
" selected_ids_RK = np.tile(selected_ids_K, (n_trials,1))\n",
" if strategy.count('pick'):\n",
" score_N = np.zeros(N)\n",
" sum_str_N = [None for _ in range(N)]\n",
" for n, dist in enumerate(dist_N):\n",
" random_state = np.random.RandomState(10000 * seed + n)\n",
" y_samples_S = dist.rvs(size=100*n_trials, random_state=random_state)\n",
" sum_str_N[n] = \" \".join(['%.1f' % np.percentile(y_samples_S, p)\n",
" for p in [0, 10, 50, 90, 100]])\n",
" \n",
" if strategy == 'pick_mean':\n",
" score_N[n] = np.mean(y_samples_S)\n",
" elif strategy == 'pick_mean_of_squares':\n",
" score_N[n] = np.mean(np.square(y_samples_S))\n",
" elif strategy == 'pick_mean_of_sqrt':\n",
" score_N[n] = np.mean(np.sqrt(y_samples_S))\n",
" elif strategy == 'pick_max':\n",
" score_N[n] = np.max(y_samples_S)\n",
" elif strategy == 'pick_percentile':\n",
" score_N[n] = np.percentile(y_samples_S, percentile_as_frac) \n",
" else:\n",
" score_N[n] = np.median(y_samples_S)\n",
" selected_ids_K = np.argsort(-1 * score_N)[:K]\n",
" for kk in selected_ids_K:\n",
" print(sum_str_N[kk])\n",
" selected_ids_RK = np.tile(selected_ids_K, (n_trials,1))\n",
" if strategy == 'guess_random':\n",
" random_state = np.random.RandomState(10000 * seed)\n",
" selected_ids_RK = np.zeros((n_trials, K), dtype=np.int32)\n",
" for trial in range(n_trials):\n",
" selected_ids_RK[trial,:] = random_state.permutation(N)[:K]\n",
" \n",
" yselect_RK = np.take_along_axis(y_RN, selected_ids_RK, axis=1)\n",
" topk_ids_RK = np.argsort(-1 * y_RN, axis=1)[:, :K]\n",
" ytop_RK = np.take_along_axis(y_RN, topk_ids_RK, axis=1)\n",
"\n",
" numer_R = np.sum(yselect_RK, axis=1)\n",
" denom_R = np.sum(ytop_RK, axis=1)\n",
" \n",
" assert np.all(numer_R <= denom_R + 1e-10)\n",
" \n",
" return numer_R / denom_R"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "a1771804",
"metadata": {},
"outputs": [],
"source": [
"y_RN = np.random.poisson(5, size=40).reshape(10, 4)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "0de586dc",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 5, 9, 6, 7],\n",
" [ 6, 5, 7, 7],\n",
" [ 5, 12, 2, 3],\n",
" [ 7, 3, 7, 4],\n",
" [ 3, 5, 3, 4],\n",
" [ 6, 5, 4, 4],\n",
" [ 3, 3, 6, 7],\n",
" [ 8, 6, 6, 2],\n",
" [10, 4, 6, 3],\n",
" [ 5, 5, 5, 6]])"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_RN"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "14880209",
"metadata": {},
"outputs": [],
"source": [
"K = 2\n",
"topk_ids_RK = np.argsort(-1 * y_RN, axis=1)[:, :K]"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "6233e5ee",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 9, 7],\n",
" [ 7, 7],\n",
" [12, 5],\n",
" [ 7, 7],\n",
" [ 5, 4],\n",
" [ 6, 5],\n",
" [ 7, 6],\n",
" [ 8, 6],\n",
" [10, 6],\n",
" [ 6, 5]])"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"topk_ids_RK\n",
"np.take_along_axis(y_RN, topk_ids_RK, axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "9c545471",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.0 0.0 0.0 100.0 100.0\n",
"0.0 0.0 0.0 100.0 100.0\n",
"0.0 0.0 0.0 0.0 101.0\n"
]
},
{
"data": {
"text/plain": [
"0.22923552078356427"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.mean(calc_bpr_many_trials(dist_N, K=3, n_trials=100000, strategy='pick_mean'))"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "cee1f56c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.4780103310589275"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.mean(calc_bpr_many_trials(dist_N, K=3, n_trials=100000, strategy='guess_random'))"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "171d8006",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"6.0 7.0 7.0 7.0 8.0\n",
"6.0 7.0 7.0 7.0 8.0\n",
"6.0 7.0 7.0 7.0 8.0\n"
]
},
{
"data": {
"text/plain": [
"0.6116739632292829"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.mean(calc_bpr_many_trials(dist_N, K=3, n_trials=100000, strategy='cross_ratio'))"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "7026d9ee",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"7.0 7.0 7.0 7.0 8.0\n",
"7.0 7.0 7.0 7.0 7.0\n",
"6.0 7.0 7.0 7.0 7.0\n",
"0.5943714824065128\n",
"7.0 7.0 7.0 7.0 8.0\n",
"7.0 7.0 7.0 7.0 7.0\n",
"6.0 7.0 7.0 7.0 7.0\n",
"0.5943714824065128\n",
"0.0 0.0 10.0 10.0 10.0\n",
"0.0 0.0 10.0 10.0 10.0\n",
"0.0 0.0 10.0 10.0 10.0\n",
"0.5741076003753522\n",
"0.0 0.0 10.0 10.0 10.0\n",
"0.0 0.0 10.0 10.0 10.0\n",
"0.0 0.0 10.0 10.0 10.0\n",
"0.5741076003753522\n",
"0.0 0.0 10.0 10.0 10.0\n",
"0.0 0.0 10.0 10.0 10.0\n",
"0.0 0.0 10.0 10.0 10.0\n",
"0.5741076003753522\n",
"0.0 0.0 10.0 10.0 10.0\n",
"0.0 0.0 10.0 10.0 10.0\n",
"0.0 0.0 10.0 10.0 10.0\n",
"0.5741076003753522\n",
"0.0 0.0 10.0 10.0 10.0\n",
"0.0 0.0 10.0 10.0 10.0\n",
"0.0 0.0 10.0 10.0 10.0\n",
"0.5741076003753522\n",
"0.0 0.0 10.0 10.0 10.0\n",
"0.0 0.0 10.0 10.0 10.0\n",
"0.0 0.0 10.0 10.0 10.0\n",
"0.5741076003753522\n",
"0.0 0.0 3.0 10.0 27.0\n",
"0.0 0.0 10.0 10.0 10.0\n",
"0.0 0.0 10.0 10.0 10.0\n",
"0.49775269031932723\n"
]
}
],
"source": [
"for perc in [10, 20, 30, 40, 50, 60, 70, 80, 90]:\n",
" print(np.mean(calc_bpr_many_trials(\n",
" dist_N, K=3, strategy='pick_percentile',\n",
" percentile_as_frac=perc)))"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "568c8d44",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"7.0 7.0 7.0 7.0 8.0\n",
"7.0 7.0 7.0 7.0 7.0\n",
"6.0 7.0 7.0 7.0 7.0\n"
]
},
{
"data": {
"text/plain": [
"0.5943714824065128"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.mean(calc_bpr_many_trials(dist_N, K=3, strategy='pick_mean_of_sqrt'))"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "21dd9509",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"14.0"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a = np.asarray([1,2,3])\n",
"np.square(np.linalg.norm(a))"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "03fc9be3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"14"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.sum(np.square(a))"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "3d465e51",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[1., 0., 0.],\n",
" [0., 1., 0.],\n",
" [0., 0., 1.]])"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.identity(3)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "f5c29729",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"7.0 7.0 7.0 7.0 8.0\n",
"7.0 7.0 7.0 7.0 7.0\n",
"6.0 7.0 7.0 7.0 7.0\n",
"0.0 0.0 0.0 100.0 100.0\n",
"0.0 0.0 0.0 0.0 100.0\n",
"0.0 0.0 0.0 0.0 100.0\n",
"0.0 0.0 10.0 10.0 10.0\n",
"0.0 0.0 10.0 10.0 11.0\n",
"0.0 0.0 10.0 10.0 10.0\n"
]
}
],
"source": [
"mnames = ['cross_ratio', 'pick_mean', 'pick_median', 'guess_random']\n",
"R = 10000\n",
"\n",
"scores_MR = np.zeros((4, R))\n",
"for mm, method in enumerate(mnames):\n",
" scores_MR[mm] = calc_bpr_many_trials(\n",
" dist_N, K=3, n_trials=R, strategy=method)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "3a6c5ac8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.7778, 0.7778, 0.7 , 0.7778, 0.7 , 0.7 , 0.7 , 0.7778, 0.175 , 0.175 ],\n",
" [0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.8333, 0.8333],\n",
" [0.7407, 0.7407, 1. , 0.7407, 1. , 1. , 1. , 0.7407, 0.1667, 0.1667],\n",
" [0.3704, 0.6296, 0. , 0.5185, 0.8 , 0.6667, 0.9 , 0.2593, 0.1417, 0.1417]])"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.set_printoptions(precision=4, linewidth=120)\n",
"scores_MR[:, :10]"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "338bdbd2",
"metadata": {},
"outputs": [],
"source": [
"winscore_1R = scores_MR.max(axis=0, keepdims=1)\n",
"\n",
"winners_MR = np.abs(scores_MR - winscore_1R) < 0.02"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "7dfa0bc5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" cross_ratio won 4453/10000 trials\n",
" pick_mean won 1953/10000 trials\n",
" pick_median won 2514/10000 trials\n",
" guess_random won 1270/10000 trials\n"
]
}
],
"source": [
"for mm, mname in enumerate(mnames):\n",
" print(\"%13s won % 6d/%d trials\" % (mname, winners_MR[mm].sum(), R))"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "26ca43d1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.21 , 0. , 0.25 , 0.2333])"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.percentile(scores_MR / winscore_1R, 20, axis=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "593a5188",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment