Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save michaelchughes/5345a740124578062ed744032354faa8 to your computer and use it in GitHub Desktop.
Save michaelchughes/5345a740124578062ed744032354faa8 to your computer and use it in GitHub Desktop.
Power Law Fit to the trend in Error vs Dev Set size in View Classifier Performance
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "forty-kenya",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import scipy.stats"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "perceived-duplicate",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"# Plot styling: white-grid background, notebook context with fonts scaled by 1.25\n",
"sns.set_style(style=\"whitegrid\")\n",
"sns.set_context(context=\"notebook\", font_scale=1.25)"
]
},
{
"cell_type": "markdown",
"id": "thousand-advisory",
"metadata": {},
"source": [
"There is compelling prior theoretical and experimental work (see Hestness et al. 2017; https://arxiv.org/pdf/1712.00409.pdf) suggesting that the loss incurred by a model, as a function of training set size $n$, follows a power-law relationship\n",
"\n",
"$$\n",
"\\ell(n) = \\alpha n^{-\\beta} \n",
"$$\n",
"\n",
"for $\\alpha > 0$ and $\\beta \\in (0, 1)$."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "agreed-suggestion",
"metadata": {},
"outputs": [],
"source": [
"# Raw data from Zhe: view-classifier accuracy (in percent) at three dev-set sizes\n",
"\n",
"### X: size of the development set for each measurement\n",
"n_P = np.asarray([56, 165, 479])\n",
"\n",
"### Y: view classification accuracy, converted from percent to a fraction\n",
"acc_P = np.asarray([90.34, 95.77, 97.03]) / 100.0\n",
"\n",
"### Error rate (1 - accuracy) at each dev-set size\n",
"loss_P = 1.0 - acc_P\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "opened-vertex",
"metadata": {},
"outputs": [],
"source": [
"def project_loss(n, alpha, beta):\n",
"    \"\"\"Power-law projection of the loss at size n: alpha * n^(-beta).\n",
"\n",
"    Args:\n",
"        n: size(s) at which to evaluate the curve (scalar or array).\n",
"        alpha: multiplicative scale of the curve.\n",
"        beta: decay exponent applied to n.\n",
"\n",
"    Returns:\n",
"        alpha * n^(-beta), computed with np.power (elementwise for arrays).\n",
"    \"\"\"\n",
"    decay = np.power(n, -beta)\n",
"    return alpha * decay"
]
},
{
"cell_type": "markdown",
"id": "amended-preference",
"metadata": {},
"source": [
"Let's show that this class of projection models might reasonably fit our data"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "instructional-package",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Overlay the observed error rates on a few candidate power-law curves\n",
"# (alpha fixed at 1, several decay exponents beta) to eyeball plausibility.\n",
"plt.plot(n_P, loss_P, 's', label='observed');\n",
"\n",
"alpha = 1.0\n",
"\n",
"# 30 sizes, log-spaced between 10 and 10000\n",
"n_G = np.logspace(np.log10(10), np.log10(10000), 30)\n",
"\n",
"for beta in [0.3, 0.5, 0.7]:\n",
"    plt.plot(n_G, project_loss(n_G, alpha, beta), label=r'$\\beta=%s$' % beta);\n",
"\n",
"# Label the axes so the figure can be read on its own\n",
"plt.xlabel('size of dev set');\n",
"plt.ylabel('error rate');\n",
"plt.legend();"
]
},
{
"cell_type": "markdown",
"id": "sustained-consortium",
"metadata": {},
"source": [
"## Now, let's do a least-squares (sum of squared errors) fit to our data\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "solar-nutrition",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 6,
"id": "isolated-suite",
"metadata": {},
"outputs": [],
"source": [
"import autograd"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "brazilian-bicycle",
"metadata": {},
"outputs": [],
"source": [
"import autograd.numpy as ag_np\n",
"import autograd.scipy.special as ag_special"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "virgin-chocolate",
"metadata": {},
"outputs": [],
"source": [
"def ag_softplus(r_alpha):\n",
"    \"\"\"Softplus log(1 + exp(r_alpha)): maps real inputs to positive outputs.\n",
"\n",
"    Computed as logaddexp(0, r_alpha), i.e. log(exp(0) + exp(r_alpha)), which\n",
"    is the same function but does not overflow to inf when r_alpha is large.\n",
"\n",
"    Args:\n",
"        r_alpha: unconstrained real scalar or array.\n",
"\n",
"    Returns:\n",
"        Softplus of r_alpha (elementwise for arrays).\n",
"    \"\"\"\n",
"    return ag_np.logaddexp(0.0, r_alpha)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "adjusted-bruce",
"metadata": {},
"outputs": [],
"source": [
"def ag_sigmoid(x):\n",
"    \"\"\"Logistic sigmoid, delegating to ag_special.expit.\"\"\"\n",
"    return ag_special.expit(x)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "recreational-former",
"metadata": {},
"outputs": [],
"source": [
"def ag_project_loss(n, alpha, beta):\n",
"    \"\"\"Power law alpha * n^(-beta), evaluated through autograd.numpy.\n",
"\n",
"    Same formula as project_loss, but built from ag_np operations so that\n",
"    autograd can differentiate through it.\n",
"\n",
"    Args:\n",
"        n: size(s) at which to evaluate the curve (scalar or array).\n",
"        alpha: multiplicative scale of the curve.\n",
"        beta: decay exponent applied to n.\n",
"\n",
"    Returns:\n",
"        alpha * n^(-beta) (elementwise for arrays).\n",
"    \"\"\"\n",
"    scaled_down = ag_np.power(n, -beta)\n",
"    return alpha * scaled_down"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "numerous-fellowship",
"metadata": {},
"outputs": [],
"source": [
"def calc_cost_at_param_vec(pvec):\n",
"    \"\"\"Sum of squared errors between the power-law prediction and loss_P.\n",
"\n",
"    Args:\n",
"        pvec: unconstrained parameters. pvec[0] is passed through softplus to\n",
"            give alpha; pvec[1] is passed through the sigmoid to give beta.\n",
"\n",
"    Returns:\n",
"        Scalar cost: sum over the sizes in n_P of (predicted - observed)^2.\n",
"    \"\"\"\n",
"    alpha = ag_softplus(pvec[0])\n",
"    beta = ag_sigmoid(pvec[1])\n",
"    residual_P = ag_project_loss(n_P, alpha, beta) - loss_P\n",
"    return ag_np.sum(ag_np.square(residual_P))"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "sacred-andrews",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.0047561726148123065"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sanity check: evaluate the cost at the parameter vector [0, 1]\n",
"calc_cost_at_param_vec(np.asarray([0.0, 1.0]))\n"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "weekly-orlando",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.0001556678872876396"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Cost at the all-zeros parameter vector\n",
"calc_cost_at_param_vec(np.zeros(2))"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "wrapped-primary",
"metadata": {},
"outputs": [],
"source": [
"# Gradient function of the cost, built by autograd\n",
"calc_grad_at_param_vec = autograd.grad(\n",
"    calc_cost_at_param_vec,\n",
"    argnum=0,  # differentiate with respect to pvec, the first argument\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "killing-stocks",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 0.0004668 , -0.00105819])"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Gradient at the all-zeros parameter vector\n",
"calc_grad_at_param_vec(ag_np.zeros(2))"
]
},
{
"cell_type": "markdown",
"id": "loaded-fifty",
"metadata": {},
"source": [
"# Minimize via scipy optimize"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "endless-leisure",
"metadata": {},
"outputs": [],
"source": [
"import scipy.optimize"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "atmospheric-supervision",
"metadata": {},
"outputs": [],
"source": [
"# Minimize the cost with BFGS, starting from the all-zeros parameter vector\n",
"# and supplying the autograd gradient as the Jacobian.\n",
"ans = scipy.optimize.minimize(\n",
"    fun=calc_cost_at_param_vec,\n",
"    x0=np.zeros(2),\n",
"    jac=calc_grad_at_param_vec,\n",
"    method='bfgs')"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "brutal-sequence",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" fun: 6.214769899777294e-05\n",
" hess_inv: array([[4226.43851196, 2321.53517882],\n",
" [2321.53517882, 1312.89223456]])\n",
" jac: array([-3.41309401e-07, -6.30261846e-07])\n",
" message: 'Optimization terminated successfully.'\n",
" nfev: 20\n",
" nit: 17\n",
" njev: 20\n",
" status: 0\n",
" success: True\n",
" x: array([0.89378172, 0.56242587])"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ans"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "crucial-feature",
"metadata": {},
"outputs": [],
"source": [
"# Check the optimizer's boolean success flag rather than searching the\n",
"# human-readable message text; show that message if the check fails.\n",
"assert ans.success, ans.message"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "aerial-lawrence",
"metadata": {},
"outputs": [],
"source": [
"# Map the first optimized (unconstrained) parameter through softplus to get alpha\n",
"raw_alpha = ans.x[0]\n",
"best_alpha = ag_softplus(raw_alpha)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "tested-queens",
"metadata": {},
"outputs": [],
"source": [
"# Map the second optimized (unconstrained) parameter through the sigmoid to get beta\n",
"raw_beta = ans.x[1]\n",
"best_beta = ag_sigmoid(raw_beta)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "settled-chorus",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"alpha = 1.2367\n",
"beta = 0.6370\n"
]
}
],
"source": [
"# Report the fitted parameters, one per line, to four decimal places\n",
"print(\"alpha = %.4f\" % best_alpha, \"beta = %.4f\" % best_beta, sep=\"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "disciplinary-latino",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot the observed error rates and the fitted power-law curve against\n",
"# log10(dev set size).\n",
"plt.plot(np.log10(n_P), loss_P, 's', label='observed');\n",
"\n",
"# 30 sizes, log-spaced between 10 and 10000, at which to draw the fitted curve\n",
"n_G = np.logspace(np.log10(10), np.log10(10000), 30)\n",
"\n",
"plt.plot(np.log10(n_G), project_loss(n_G, best_alpha, best_beta), label=r'$\\alpha$=%.3f, $\\beta$=%.3f' % (best_alpha,best_beta));\n",
"\n",
"# Four ticks, log-spaced between 10 and 10000; label them with integer\n",
"# sizes instead of float reprs such as 10.0.\n",
"nticks_G = np.logspace(np.log10(10), np.log10(10000), 4)\n",
"plt.gca().set_xticks(np.log10(nticks_G))\n",
"plt.gca().set_xticklabels(['%d' % round(t) for t in nticks_G])\n",
"plt.ylabel('error rate');\n",
"plt.xlabel('size of dev set (log10 scale)');\n",
"\n",
"plt.legend();"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "junior-frame",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "settled-strengthening",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
@michaelchughes
Copy link
Author

Just in case, here's the final plot

image

This says that with a dev set of size 1000 we can expect an error rate below 2% (about 1.5%), and a very small error rate (about 0.35%) at size 10000.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment