Skip to content

Instantly share code, notes, and snippets.

@rlvaugh
Created April 22, 2023 23:03
Show Gist options
  • Save rlvaugh/f1c293c7b31fa878b0cdeb256c6e99ab to your computer and use it in GitHub Desktop.
Jupyter Notebook for the article: "Predict the Limits of Human Performance with Python" by Lee Vaughan
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "1f5c0276-ce95-4fa9-95cb-138627023018",
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"import scipy.optimize\n",
"from sklearn.metrics import mean_squared_error\n",
"\n",
"# Silence only RuntimeWarnings (e.g. 'overflow encountered in exp'), which\n",
"# np.exp() emits when curve_fit tries extreme parameter values. A blanket\n",
"# 'ignore' would also hide library deprecation warnings (FutureWarning etc.)\n",
"# that flag code which will break on newer pandas/scikit-learn releases.\n",
"warnings.filterwarnings('ignore', category=RuntimeWarning)\n",
"\n",
"# Set default run configuration for plots:\n",
"plt.rcParams['figure.figsize'] = (6, 4)\n",
"plt.rc('font', size=12)\n",
"plt.rc('axes', titlesize=14)\n",
"plt.rc('axes', labelsize=12)\n",
"plt.rc('xtick', labelsize=11)\n",
"plt.rc('ytick', labelsize=11)\n",
"plt.rc('legend', fontsize=11)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "236c0543-5955-4866-94bb-f1dc44af5760",
"metadata": {},
"outputs": [],
"source": [
"# Men's 100 m world records in seconds, keyed by year.\n",
"# Where one year saw several records, only the final (fastest) one is kept:\n",
"records = {2009: 9.58, 2008: 9.69, 2007: 9.74, 2005: 9.77, 2002: 9.78,\n",
"           1999: 9.79, 1996: 9.84, 1994: 9.85, 1991: 9.86, 1988: 9.92,\n",
"           1983: 9.93, 1968: 9.95, 1960: 10, 1956: 10.1, 1936: 10.2,\n",
"           1930: 10.3, 1921: 10.4, 1912: 10.6}\n",
"\n",
"# Tabulate chronologically, then add the elapsed years since the 1912 record:\n",
"df = (pd.DataFrame(list(records.items()), columns=['year', 'time'])\n",
"        .sort_values('year')\n",
"        .reset_index(drop=True))\n",
"df['years'] = df['year'] - 1912\n",
"display(df)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "61968a7a-879b-4d00-a797-34075411b53f",
"metadata": {},
"outputs": [],
"source": [
"# Graph the world records against calendar year.\n",
"# The y-axis unit label matches the other plots in this notebook:\n",
"plt.stem(df.year, df.time)\n",
"plt.title(\"Men's 100 m Sprint World Records\")\n",
"plt.xlabel('Year')\n",
"plt.ylabel('Time (s)')\n",
"plt.ylim(9.5, 10.8)\n",
"plt.grid(True);"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3217ac09-b86e-4517-88b1-c08c2279c920",
"metadata": {},
"outputs": [],
"source": [
"def expo(x, a, b, c):\n",
"    \"\"\"Return the exponential decay curve a * exp(-b * x) + c.\n",
"\n",
"    a scales the decaying term, b is the decay rate, and c is the value\n",
"    the curve approaches as x grows (when b > 0).\n",
"    \"\"\"\n",
"    return a * np.exp(-b * x) + c\n",
"\n",
"\n",
"def optimize_curve_fit(a_func, x, y, p0=None, **kwargs):\n",
"    \"\"\"Return the optimized parameters of a_func fitted to the points (x, y).\n",
"\n",
"    p0 is the initial parameter guess; None lets scipy start every parameter\n",
"    at 1. Any extra keyword arguments (e.g. bounds, method, maxfev) are\n",
"    forwarded to scipy.optimize.curve_fit. The covariance matrix and any\n",
"    full_output diagnostics that curve_fit also returns are discarded.\n",
"    \"\"\"\n",
"    # curve_fit returns (popt, pcov) or, with full_output=True, a longer\n",
"    # tuple; the optimal parameters are always the first element.\n",
"    result = scipy.optimize.curve_fit(a_func, x, y, p0=p0, **kwargs)\n",
"    return result[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0518805-1d2b-48df-a0cd-9d538b647eff",
"metadata": {},
"outputs": [],
"source": [
"def curve_fit_stats(x, y, a, b, c):\n",
"    \"\"\"Return goodness-of-fit statistics for expo(x, a, b, c) against y.\n",
"\n",
"    Returns the tuple (r_squared, mse, rmse, nrmse):\n",
"      r_squared -- coefficient of determination, 1 - SS_res / SS_tot\n",
"      mse       -- mean squared error\n",
"      rmse      -- root mean squared error, sqrt(mse)\n",
"      nrmse     -- rmse divided by the range of y (max - min)\n",
"    \"\"\"\n",
"    y_predicted = expo(x, a, b, c)\n",
"    diffs_squared = np.square(y - y_predicted)\n",
"    diffs_squared_from_mean = np.square(y - np.mean(y))\n",
"    r_squared = 1 - np.sum(diffs_squared) / np.sum(diffs_squared_from_mean)\n",
"    mse = mean_squared_error(y, y_predicted)\n",
"    # Take the square root directly: the squared=False option of\n",
"    # mean_squared_error was deprecated in scikit-learn 1.4 and removed in 1.6.\n",
"    rmse = np.sqrt(mse)\n",
"    nrmse = rmse / (y.max() - y.min())\n",
"    return r_squared, mse, rmse, nrmse"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4b28fedd-98ee-4bd7-abd6-bc67e8d89292",
"metadata": {},
"outputs": [],
"source": [
"def print_fit_stats(label, r2, mse, rmse, nrmse):\n",
"    \"\"\"Print one labelled block of curve-fitting statistics.\"\"\"\n",
"    print(f\"Curve-fitting statistics {label} Bolt data:\")\n",
"    print(f\"R² = {r2: .3f}\")\n",
"    print(f\"Mean Square Error = {mse: .4f}\")\n",
"    print(f\"Root MSE = {rmse: .4f}\")\n",
"    print(f\"Normalized RMSE = {nrmse: .4f}\\n\")\n",
"\n",
"\n",
"# Generate datasets with and without Bolt's times (nB = No Bolt).\n",
"# Select Bolt's record years by value rather than slicing off the last two\n",
"# rows, so the split does not silently depend on the DataFrame's row order:\n",
"bolt_years = [2008, 2009]\n",
"is_bolt = df.year.isin(bolt_years)\n",
"x_all, y_all = df.years, df.time\n",
"x_nB, y_nB = x_all[~is_bolt], y_all[~is_bolt]\n",
"\n",
"# Find optimized parameters for fitting curve to points for no Bolt and Bolt.\n",
"params_nB = optimize_curve_fit(expo, x_nB, y_nB)\n",
"params_all = optimize_curve_fit(expo, x_all, y_all)\n",
"print(f\"Parameters without Bolt (a, b, c) = {params_nB}\")\n",
"print(f\"Parameters with Bolt (a, b, c) = {params_all}\")\n",
"\n",
"# Calculate and report fitting statistics:\n",
"r2, mse, rmse, nrmse = curve_fit_stats(x_all, y_all, *params_all)\n",
"r2_nB, mse_nB, rmse_nB, nrmse_nB = curve_fit_stats(x_nB, y_nB, *params_nB)\n",
"print()\n",
"print_fit_stats('WITHOUT', r2_nB, mse_nB, rmse_nB, nrmse_nB)\n",
"print_fit_stats('WITH', r2, mse, rmse, nrmse)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "78b8b97f-29af-427e-bdff-0daac0ae4d23",
"metadata": {},
"outputs": [],
"source": [
"# Plot exponential curves for data with and without Bolt's times.\n",
"# Evaluate each fit on a dense grid, so it is drawn as a smooth curve rather\n",
"# than as straight segments joining the irregularly spaced record years:\n",
"x_fit_nB = np.linspace(x_nB.min(), x_nB.max(), 200)\n",
"x_fit_all = np.linspace(x_all.min(), x_all.max(), 200)\n",
"plt.plot(x_all, y_all, '.', label='measured data', c='k')\n",
"plt.plot(x_fit_nB, expo(x_fit_nB, *params_nB),\n",
"         '-', label='fitted without Bolt')\n",
"plt.plot(x_fit_all, expo(x_fit_all, *params_all), '--',\n",
"         label='fitted with Bolt', c='red')\n",
"plt.title(\"Men's 100 m World Record\")\n",
"plt.xlabel('Years Since First Record (1912)')\n",
"plt.ylabel('Time (s)')\n",
"plt.grid(True)\n",
"plt.legend(framealpha=1);"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9f4ae4f2-e2e5-4a1b-843d-cbec3a17e867",
"metadata": {},
"outputs": [],
"source": [
"# Extrapolate exponential curves to predict future performance:\n",
"x_extrap = np.arange(-20, 550)  # Curve flattens at ~550 years since 1912.\n",
"y_nB_extrap = expo(x_extrap, *params_nB)  # Without Bolt.\n",
"y_B_extrap = expo(x_extrap, *params_all)  # With Bolt.\n",
"\n",
"# Create a plot of the world record times and the extrapolated curves.\n",
"fig, ax = plt.subplots()\n",
"ax.plot(x_all, y_all, '.', label='data', c='k')\n",
"ax.plot(x_extrap, y_nB_extrap, '-', label='fitted without Bolt')\n",
"ax.plot(x_extrap, y_B_extrap, '--', c='red', label='fitted with Bolt')\n",
"ax.set(title=\"Men's 100 m World Record Extrapolated\",\n",
"       xlabel='Years Since First Record (1912)',\n",
"       ylabel='Time (s)',\n",
"       yticks=np.arange(9.0, 11.0, 0.2))\n",
"ax.grid(True)\n",
"ax.legend(framealpha=1)\n",
"\n",
"# Add a dotted horizontal line for each of Bolt's world record times.\n",
"bolt_times = {2009: 9.58, 2008: 9.69}\n",
"for bolt_year, bolt_time in bolt_times.items():\n",
"    ax.axhline(bolt_time, ls=':', linewidth=1.3, color='red')\n",
"    ax.text(0, bolt_time + 0.01, f\"Bolt {bolt_year}\", color='red',\n",
"            horizontalalignment='left', size=9)\n",
"\n",
"# Secondary x-axis showing calendar years. secondary_xaxis needs both the\n",
"# forward (years since 1912 -> year) and inverse (year -> years since 1912)\n",
"# transforms to keep the two scales aligned.\n",
"def axis_transform(years_since):\n",
"    return years_since + 1912\n",
"\n",
"\n",
"def axis_inverse(year):\n",
"    return year - 1912\n",
"\n",
"\n",
"ax2 = ax.secondary_xaxis('top', functions=(axis_transform, axis_inverse))\n",
"ax2.set_xlabel('Year')\n",
"\n",
"print(f\"\\nMinimum predicted time without Bolt data = {np.min(y_nB_extrap):.2f} sec.\")\n",
"print(f\"Minimum predicted time with Bolt data = {np.min(y_B_extrap):.2f} sec.\\n\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment