Skip to content

Instantly share code, notes, and snippets.

@alexander-wei
Created August 26, 2022 07:29
Show Gist options
  • Save alexander-wei/2ba110a1dee51211d99fdd661b7f2599 to your computer and use it in GitHub Desktop.
Save alexander-wei/2ba110a1dee51211d99fdd661b7f2599 to your computer and use it in GitHub Desktop.
bootstrap_t_aug_2
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 448,
"id": "ab497cc2-d762-42ed-b90c-39d8eefb0bc4",
"metadata": {},
"outputs": [],
"source": [
"LAM = lambda **args: \\\n",
"args['alpha'] * keyword_data[args['keyword']]['effect'] + args['beta'] * keyword_data[args['keyword']]['delta']\n",
"\n",
"Filter = lambda functional: \\\n",
"np.argmax(functional)"
]
},
{
"cell_type": "code",
"execution_count": 449,
"id": "ea37de13-38a0-404f-9d94-6869ad24d155",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'programming': {'effect': array([ 0.52727697, 0.62275902, 0.31992764, -1.47398622, 0.54000594,\n",
" 0.69074962, -0.56053367, -0.26081456, -0.46688856, -0.47580749]),\n",
" 'delta': array([ 0.0670864 , 0.08354775, 0.04003509, -0.10704586, 0.0629095 ,\n",
" 0.05717983, -0.03446123, -0.01716745, -0.03431878, -0.03526842])},\n",
" 'iphone': {'effect': array([ 0.39874329, -0.22812587, -1.5355249 , 0.57258721, 0.04357991,\n",
" 0.61691486, 0.41334765, 0.23663399, -1.00647934, 0.33687891]),\n",
" 'delta': array([ 0.04906107, -0.02348979, -0.11533874, 0.07157464, 0.00465422,\n",
" 0.05769759, 0.03820275, 0.01912015, -0.06106605, 0.03389128])},\n",
" 'lessons': {'effect': array([ 0.92381862, -0.1606284 , 0.3732374 , 0.08254097, -0.810276 ,\n",
" 0.46417035, -0.63788485, -1.25698357, 0.57690274, 0.59065989]),\n",
" 'delta': array([ 0.17997197, -0.01967962, 0.06275833, 0.00991244, -0.06583493,\n",
" 0.04353218, -0.03588914, -0.06327196, 0.05551289, 0.09985146])},\n",
" 'law': {'effect': array([ 0.85737955, 0.14047853, 0.07776645, -0.24235926, -0.29275087,\n",
" -0.04523782, 0.14086686, -0.29405953, -0.13099893, 0.13346111]),\n",
" 'delta': array([ 0.1179123 , 0.01402793, 0.00810001, -0.02359755, -0.02626759,\n",
" -0.00352435, 0.00948065, -0.01827477, -0.00961103, 0.01083245])}}"
]
},
"execution_count": 449,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"keyword_data"
]
},
{
"cell_type": "code",
"execution_count": 460,
"id": "d7388fa2-1860-4a3e-a1ba-aedd6e52cabc",
"metadata": {},
"outputs": [],
"source": [
"# what does a randomly sampled vector of comment likelihoods look like?\n",
"N = 3 # as N->inf, we're more certain of our observed likelihoods and the filter will approximate it exactly\n",
"# each step, we feed the filter some average of observed likelihoods and prior likelihoods\n",
"\n",
"_M = len(words)\n",
"def get_y_sample(t=1,s=2.5,N=3,shift=0):\n",
" M = 500;\n",
" train_y = np.zeros((_M, M ,10))\n",
" for j in range(_M):\n",
" comparisonvect = word_stats[j]\n",
" # For simulating we'll assume that a keyword's commented distribution falls between 1) what we observed (and evaluated t scores for)\n",
" # and 2) a normal centered at 25% which is the average over the entire Hacker News dataset for all time slots\n",
" UNIFORM_PRIOR = .25\n",
" train_y[j,:,:] = \\\n",
" OHEnc.transform( \\\n",
" np.array(\n",
" [np.argmax(\n",
" np.array(\n",
" [1/(t+s)*(t*stats.t.rvs(loc=u[0],scale=u[1]**.5,df=30,size=N) + (.25+shift) * s*np.ones(N)) \\\n",
" # DRAW NEW SAMPLES \\\n",
" for u in comparisonvect]\n",
" ).mean(axis=1)) for _ in range(M)]).reshape(-1,1))\n",
" return train_y"
]
},
{
"cell_type": "code",
"execution_count": 465,
"id": "c05b4c20-4b9b-4e43-a61a-b9ac04bea497",
"metadata": {},
"outputs": [],
"source": [
"def train_step(alpha,beta,Y,verbose=False,ret_matrix=False):\n",
" train_x = np.zeros((_M ,500, 10))\n",
" i = 0;\n",
" for i in range(_M ):\n",
" # get the \"most likely\" time point from the Filter\n",
" predicted = Filter(LAM(alpha=alpha,beta=beta,keyword=words[i]))\n",
" # Mark our predictions, and take a difference from \"actual\"/sampled best times to get an error\n",
" train_x[i,:,predicted] = 1\n",
" delt = np.abs(train_x - Y)\n",
" return (np.mean(delt,axis=1) ** 2).mean(axis=1).mean()"
]
},
{
"cell_type": "code",
"execution_count": 528,
"id": "ab3613b0-2c59-405e-bade-86b33eab661e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0.5, 0, 'alpha parameter')"
]
},
"execution_count": 528,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# if our train data were 100% accurate, we would want to filter based solely on the highest recorded comment rate\n",
"# and we would always seek to minimize the training set error\n",
"y_train = get_y_sample(t=1,s=0,N=200,shift=-.2)\n",
"I = np.linspace(0,.4,30)\n",
"plt.plot(I, [train_step(v,1., y_train) for v in I])\n",
"plt.plot(I, [train_step(v,1., y_validate ) for v in I])\n",
"plt.title(\"Error: Training set (blue) vs Augmented set (orange)\")\n",
"plt.xlabel(\"alpha parameter\")"
]
},
{
"cell_type": "code",
"execution_count": 511,
"id": "9868d645-60ba-4031-99ba-aaa349af44c4",
"metadata": {},
"outputs": [],
"source": [
"import plotly.express as px"
]
},
{
"cell_type": "code",
"execution_count": 525,
"id": "0ed8b368-5a16-46bc-bf56-056cf00f7e4c",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 360x360 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(5,5,));\n",
"plt.pcolormesh(np.squeeze(q));\n",
"plt.xticks(np.arange(0,30,30/4), np.round(I[::8],1))\n",
"plt.yticks(np.arange(0,30,30/4), np.round(I[::8],1)); plt.colorbar();\n",
"plt.xlabel(\"Effect weight\"); plt.ylabel(\"raw delta weight\");\n",
"plt.title(\"Time/Comment rate Mode Identification Error\");"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment