@camriddell
Last active May 22, 2024 14:19
{
"cells": [
{
"cell_type": "markdown",
"id": "16d04fc1",
"metadata": {},
"source": [
"# The Central Limit Theorem - Visualized\n",
"\n",
"For this week, I'm finally sharing the code I wrote to produce my visualization demonstrating the Central Limit Theorem! But before we get to the code, I wanted to discuss the impact of this visualization and how it can be interpreted.\n",
"\n",
"## What is the Central Limit Theorem?\n",
"\n",
"*This is a very brief background & example of the Central Limit Theorem and\n",
"is not intended to be comprehensive.*\n",
"\n",
"The Central Limit Theorem (CLT) states that when independent, identically distributed random variables with finite variance are added together, the distribution of their normalized sum approaches a normal shape, even if the original random variables are not normally distributed themselves.\n",
"\n",
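"For reference, the classic (Lindeberg–Lévy) form of the theorem: if $X_1, \\dots, X_n$ are independent, identically distributed random variables with mean $\\mu$ and finite variance $\\sigma^2$, then\n",
"\n",
"$$\\sqrt{n}\\left(\\bar{X}_n - \\mu\\right) \\xrightarrow{d} \\mathcal{N}\\left(0, \\sigma^2\\right), \\qquad \\bar{X}_n = \\frac{1}{n}\\sum_{i=1}^{n} X_i,$$\n",
"\n",
"which, loosely, says that for a large enough $n$ the sample mean $\\bar{X}_n$ is approximately distributed as $\\mathcal{N}(\\mu, \\sigma^2 / n)$.\n",
"\n",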
"If you’re not familiar with statistics, that sentence may make absolutely no sense. In that case, think of it this way: say you want to estimate the average height of a given adult (human) population. It would be nearly impossible to measure everyone’s height, so you instead take a sample of 30 people and calculate their average height. You then want to generalize and claim that the average height of these 30 people represents the average height of the entire population.\n",
"\n",
"Are we sure this claim can be made? If we were to sample another 30 people and\n",
"calculate their average height, would we get the exact same answer? While it might be close, we almost certainly would not get the \"same\" answer, which indicates that there is some amount of *sampling variability* (changes in the statistic of interest due to randomness in which people end up in our sample).\n",
"\n",
"So where does the CLT come into play? The theorem argues that if we were to\n",
"repeat the sampling process many times and then take the average height of each of\n",
"those samples, those averages would be approximately normally\n",
"distributed. This distribution of averages is called the sampling distribution\n",
"of the mean (a sampling distribution is just a distribution made up of statistics\n",
"computed on many samples from the same population) and can be used to reliably\n",
"estimate the true average height of our population! The best part of the CLT is\n",
"that even if the population of heights is *not* normally distributed, the\n",
"sampling distribution of the mean will be approximately normal for a large enough sample size. The upshot here is that the mean of sampled means\n",
"will be a reliable estimator of the population mean regardless of the shape of\n",
"the population distribution.\n",
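"\n",
"To see this numerically, here is a minimal NumPy sketch, separate from the figure code below. It assumes a hypothetical right-skewed population: a gamma distribution shifted to start at 65 (shape 2, scale 2, so its mean is 69). The sketch repeatedly draws samples of 30 and compares the population with the distribution of the sample means:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"\n",
"# Hypothetical skewed population: 65 + Gamma(shape=2, scale=2)\n",
"shape, scale, loc = 2.0, 2.0, 65.0\n",
"pop_mean = loc + shape * scale        # 69.0\n",
"pop_std = np.sqrt(shape) * scale      # about 2.83\n",
"\n",
"# 10,000 independent samples of 30 people each; average each sample\n",
"n_samples, sample_size = 10_000, 30\n",
"samples = loc + rng.gamma(shape, scale, size=(n_samples, sample_size))\n",
"sample_means = samples.mean(axis=1)\n",
"\n",
"def skewness(x):\n",
"    # Sample skewness: 0 for a symmetric (e.g. normal) distribution\n",
"    z = (x - x.mean()) / x.std()\n",
"    return (z ** 3).mean()\n",
"\n",
"print(f'mean of sample means: {sample_means.mean():.2f} vs population mean {pop_mean}')\n",
"print(f'std of sample means: {sample_means.std():.3f} vs sigma/sqrt(n) {pop_std / np.sqrt(sample_size):.3f}')\n",
"print(f'skewness of population draws: {skewness(samples.ravel()):.2f}')\n",
"print(f'skewness of sample means: {skewness(sample_means):.2f}')\n",
"```\n",
"\n",
"The sample means cluster around the population mean with a spread close to $\\sigma / \\sqrt{n}$. Their skewness is also much closer to 0 than the population's (in theory it shrinks by a factor of $\\sqrt{n}$), which is the CLT at work.\n",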
"\n",
"However, we typically cannot sample 30 random people hundreds of times to\n",
"accurately estimate the average height of the population, so we often rely on\n",
"the CLT to argue that our single sample should be *good enough*.\n",
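"\n",
"In practice, that usually means estimating the standard error of the mean from the single sample and using the normal approximation to build a confidence interval. A minimal sketch, assuming a hypothetical sample of 30 heights (in inches) drawn from a normal distribution purely for illustration:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(1)\n",
"heights = rng.normal(loc=67, scale=4, size=30)  # hypothetical sample of 30 heights\n",
"\n",
"mean = heights.mean()\n",
"# Standard error of the mean: sample std (ddof=1) / sqrt(n)\n",
"sem = heights.std(ddof=1) / np.sqrt(heights.size)\n",
"\n",
"# Approximate 95% confidence interval via the normal approximation (z = 1.96)\n",
"low, high = mean - 1.96 * sem, mean + 1.96 * sem\n",
"print(f'estimated mean height: {mean:.1f} (95% CI {low:.1f} to {high:.1f})')\n",
"```\n",
"\n",
"With only 30 observations, a Student's t critical value (about 2.05 for 29 degrees of freedom) is the more conventional choice. The normal value 1.96 is used here because it is what the CLT directly justifies.\n",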
"\n",
"## Interpreting the Figure\n",
"\n",
"![](https://pbs.twimg.com/media/FZ0A52sUUAEl36y?format=jpg&name=large)\n",
"\n",
"Applied to the above figure, you can think of the population of heights as the\n",
"first row of plots. If we simulate samples from those populations (blue points\n",
"in the second row of plots) and overlay the means of each of those samples\n",
"(dark orange line), you can clearly see the sampling variability of the mean: the amount the dark orange line zigs and zags from sample to sample. Then, in the final\n",
"row of the image, we plot the sampling distribution, essentially a histogram of\n",
"all of those means we calculated in the second row of plots. The CLT argues that\n",
"the final row of plots will approach a normal shape as the sample size grows, even if the population\n",
"(first row of plots) is not normally distributed.\n",
"\n",
"Whew! That’s enough statistics talk; let’s get to the actual code!\n",
"\n",
"## Creating The Visualization\n",
"\n",
"How would we go about creating the above visualization? I’ll walk you through the steps and code I wrote to make it, highlighting some of the important `matplotlib` concepts you’ll need to understand to produce high-quality visualizations."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "59ca3e20",
"metadata": {
"tags": [
"remove-input"
]
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"%config InlineBackend.print_figure_kwargs = {'bbox_inches':None}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f34a4159",
"metadata": {
"tags": [
"remove-input",
"remove-output"
]
},
"outputs": [],
"source": [
"from matplotlib import pyplot as plt\n",
"plt.ioff()"
]
},
{
"cell_type": "markdown",
"id": "c044e2a4",
"metadata": {},
"source": [
"Before we start: in each cell of this notebook I import the specific functions that cell uses,\n",
"which keeps each section clean and lets you readily map every function\n",
"back to its import.\n",
"\n",
"First, I set some defaults:\n",
"1. Bump up the font size\n",
"2. Force the figure to have a white background"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "54a30894",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from matplotlib.pyplot import rc, rcdefaults\n",
"from IPython.display import display\n",
"\n",
"rcdefaults()\n",
"rc('figure', facecolor='white')\n",
"rc('font', size=16)"
]
},
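{
"cell_type": "markdown",
"id": "sketch-rc-context",
"metadata": {},
"source": [
"If you would rather not change the global defaults for the whole session, `matplotlib` also offers `rc_context`, which applies settings only inside a `with` block. A minimal sketch of an equivalent configuration (an alternative, not what this notebook uses):\n",
"\n",
"```python\n",
"from matplotlib import rc_context, rcParams\n",
"from matplotlib.figure import Figure\n",
"\n",
"# Same settings as the rc() calls above, expressed as rcParams keys\n",
"style = {'figure.facecolor': 'white', 'font.size': 16}\n",
"\n",
"with rc_context(style):\n",
"    # Figures created inside the block pick up the temporary settings\n",
"    fig = Figure()\n",
"    inside_size = rcParams['font.size']\n",
"\n",
"# Outside the block, the previous settings are restored\n",
"outside_size = rcParams['font.size']\n",
"print(inside_size, outside_size)\n",
"```"
]
},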
{
"cell_type": "markdown",
"id": "2e01e3a1",
"metadata": {},
"source": [
"### Figure Layout & Distribution Set Up\n",
"\n",
"With the plotting defaults out of the way, let's set up our populations\n",
"and create a layout for the visualization. First, I'll create the distributions in\n",
"`scipy`. From there, I'll set up my grid. I should note that the `gridspec_kw`\n",
"values were determined post hoc (after I finished the final version of the figure) to ensure the layout had no overlapping text and everything was spaced correctly.\n",
"\n",
"Note that you do not need to take this manual approach: `matplotlib` can clean\n",
"up your layout for you, and it is extremely easy with `matplotlib`'s [tight layout](https://matplotlib.org/stable/tutorials/intermediate/tight_layout_guide.html)\n",
"and [constrained layout](https://matplotlib.org/stable/tutorials/intermediate/constrainedlayout_guide.html).\n",
"I manually specified my layout to ensure my plot is 100% reproducible without any further tweaks.\n",
"\n",
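"For comparison, here is a minimal sketch of the automatic alternative: the same 3-row grid built with constrained layout. It assumes `matplotlib` 3.5 or newer for the `layout` keyword. The hand-tuned `hspace`/`wspace` and margin values are dropped and the engine pads the grid instead, so the rows will generally no longer sit flush against each other the way they do with `'hspace': 0`:\n",
"\n",
"```python\n",
"from matplotlib.pyplot import subplots\n",
"\n",
"# Let the constrained layout engine size the margins and gaps between Axes\n",
"fig, axes = subplots(\n",
"    3, 3, figsize=(20, 16),\n",
"    sharex=True, sharey='row',\n",
"    gridspec_kw={'height_ratios': [1, 5, 1]},\n",
"    layout='constrained',\n",
")\n",
"```\n",
"\n",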
"By establishing my population distributions first, I can easily add more distributions later without changing my plotting code, since the `subplots` grid always has `len(populations)` columns.\n",
"\n",
"Additionally, I can readily zip my distributions together with the first row\n",
"of `Axes` to create plots on just the top row!"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a8950c6c",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from numpy import linspace\n",
"from scipy.stats import norm, uniform, gamma\n",
"from matplotlib.pyplot import subplots\n",
"\n",
"\n",
"## All populations should have central tendency ~70, \n",
"# keeps things comparable across distributions\n",
"populations = [\n",
"    norm(loc=70, scale=3), uniform(60, 20), gamma(1.99, loc=65, scale=2)\n",
"]\n",
"\n",
"fig, axes = subplots(\n",
"    3, len(populations), figsize=(20, 16),\n",
"    sharex=True, sharey='row',\n",
"    gridspec_kw={\n",
"        'height_ratios': [1, 5, 1], 'hspace': 0, 'wspace': .1,\n",
"        'bottom': .08, 'right': .9, 'left': .3, 'top': .9\n",
"    },\n",
"    dpi=52\n",
")\n",
"\n",
"for pop, ax in zip(populations, axes[0]):\n",
"    xs = linspace(*pop.ppf([.001, .999]), 4_000)\n",
"    ax.fill_between(xs, pop.pdf(xs), alpha=.5, label='Distribution')\n",
"    ax.axvline(\n",
"        pop.mean(), 0, .95, ls='dashed', color='k', lw=2, label=r'Mean $\\mu$'\n",
"    )\n",
"    ax.set_title(pop.dist.name.title(), fontsize='large')\n",
"\n",
"    ax.xaxis.set_tick_params(length=0)\n",
"    ax.yaxis.set_tick_params(labelleft=False, length=0)\n",
"    ax.margins(y=0)\n",
"\n",
"display(fig)"
]
},
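{
"cell_type": "markdown",
"id": "sketch-frozen-dists",
"metadata": {},
"source": [
"Each entry in `populations` is a *frozen* `scipy.stats` distribution. Its parameters are fixed, so the plotting loop can call `.pdf()`, `.ppf()`, and `.mean()` uniformly. A quick standalone check of the three population means and of the x-range that `ppf([.001, .999])` selects (the central 99.8% of each distribution):\n",
"\n",
"```python\n",
"from scipy.stats import norm, uniform, gamma\n",
"\n",
"populations = [\n",
"    norm(loc=70, scale=3), uniform(60, 20), gamma(1.99, loc=65, scale=2)\n",
"]\n",
"\n",
"for pop in populations:\n",
"    lo, hi = pop.ppf([.001, .999])  # x-range holding the central 99.8%\n",
"    print(f'{pop.dist.name:>8}: mean={pop.mean():.2f}, range=({lo:.1f}, {hi:.1f})')\n",
"```\n",
"\n",
"The gamma population's mean is `loc + a * scale` (65 + 1.99 * 2 = 68.98), which is why the code comment says the central tendencies are only *approximately* 70."
]
},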
{
"cell_type": "markdown",
"id": "2fc04905",
"metadata": {},
"source": [
"### Running & Displaying the Simulation\n",
"\n",
"I then set up my samples from each population distribution. Each sample has a\n",
"pre-specified size and is re-drawn many times to capture the sampling \n",
"variability. I use vectorized NumPy operations wherever possible to keep the code\n",
"fast, and then zip those samples together with the second row of plots to\n",
"perform the actual plotting. I also use a few NumPy broadcasting tricks\n",
"to ensure that my arrays line up with each other when plotting."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "15e7d147",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from numpy.random import default_rng\n",
"from numpy import arange, broadcast_to\n",
"\n",
"rng = default_rng(0)\n",
"\n",
"n_samples, sample_size = 100, 25\n",
"samples = [\n",
"    p.rvs(size=(n_samples, sample_size), random_state=rng)\n",
"    for p in populations\n",
"]\n",
"\n",
"# Row i of scatter_ys holds the y-position i for every observation in sample i\n",
"ys = arange(n_samples)\n",
"scatter_ys = broadcast_to(ys[:, None], (n_samples, sample_size))\n",
"\n",
"for s, samp_ax in zip(samples, axes[1, :]):\n",
"    smean, sdev = s.mean(axis=1), s.std(axis=1)\n",
"\n",
"    samp_ax.scatter(\n",
"        s, y=scatter_ys, s=4, c='tab:blue', alpha=.7, label='Observations'\n",
"    )\n",
"    # Shade one standard deviation on either side of each sample's mean\n",
"    samp_ax.fill_betweenx(\n",
"        ys, x1=smean - sdev, x2=smean + sdev, color='tab:orange',\n",
"        alpha=.3, label='Std. Dev. $S_x$'\n",
"    )\n",
"    samp_ax.plot(smean, ys, color='tab:orange', label=r'Mean $\\bar{x}$')\n",
"    samp_ax.xaxis.set_tick_params(length=0)\n",
"\n",
"display(fig)"
]
},
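{
"cell_type": "markdown",
"id": "sketch-broadcast",
"metadata": {},
"source": [
"The `broadcast_to` call is one of the NumPy tricks in the cell above. Every observation in a sample needs a y-coordinate, and all observations from sample `i` should share the y-value `i`. A toy-sized sketch (the sizes here are illustrative only):\n",
"\n",
"```python\n",
"from numpy import arange, broadcast_to\n",
"\n",
"n_samples, sample_size = 3, 4\n",
"ys = arange(n_samples)  # one y-position per simulated sample\n",
"\n",
"# ys[:, None] has shape (3, 1); broadcasting repeats it across 4 columns\n",
"scatter_ys = broadcast_to(ys[:, None], (n_samples, sample_size))\n",
"print(scatter_ys)\n",
"# [[0 0 0 0]\n",
"#  [1 1 1 1]\n",
"#  [2 2 2 2]]\n",
"```\n",
"\n",
"`broadcast_to` returns a read-only view rather than a copy, so no extra memory is allocated for the repeated values. That is fine here because `scatter` only reads them."
]
},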
{
"cell_type": "markdown",
"id": "3ab2fcd2",
"metadata": {},
"source": [
"### Creating the Sampling Distributions\n",
"\n",
"In the final row of my plot, I create a histogram of the sample means and mirror it with a fitted Gaussian\n",
"drawn below the x-axis (by plotting its negated density). This is intended to directly map observed values to their smoothed\n",
"counterpart. I plot the mean of the sampling distribution as a vertical dashed black\n",
"line that represents the estimated population average."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e4a05f7e",
"metadata": {},
"outputs": [],
"source": [
"for s, mean_ax in zip(samples, axes[2, :]):\n",
"    smean = s.mean(axis=1)\n",
"\n",
"    # Dashed line at the mean of the sample means (the estimated population mean)\n",
"    mean_ax.axvline(\n",
"        smean.mean(), ymin=.5, ymax=.9, c='k', ls='dashed',\n",
"        lw=2, label=r'Est. Pop. Mean $\\hat{\\mu}$'\n",
"    )\n",
"\n",
"    # Fit a normal distribution to the sample means and draw it below the\n",
"    # x-axis by negating its density, mirroring the histogram above the axis\n",
"    norm_density = norm(*norm.fit(smean))\n",
"    xs = linspace(*norm_density.ppf([.001, .999]), 4000)\n",
"    mean_ax.fill_between(\n",
"        xs, -norm_density.pdf(xs), label='Fitted Gaussian',\n",
"        alpha=.5, color='tab:orange'\n",
"    )\n",
"    mean_ax.hist(\n",
"        smean, bins='auto', label=r'Sample Means $\\bar{x}$',\n",
"        density=True, color='tab:orange', ec='white'\n",
"    )\n",
"    # Hide the density axis and move the x-axis spine to y=0, between the two halves\n",
"    mean_ax.yaxis.set_visible(False)\n",
"    mean_ax.spines['bottom'].set_position('zero')\n",
"\n",
"display(fig)"
]
},
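{
"cell_type": "markdown",
"id": "sketch-norm-fit",
"metadata": {},
"source": [
"For reference, `norm.fit` returns maximum-likelihood estimates. For a normal distribution these are simply the sample mean and the standard deviation with `ddof=0`. A small standalone check, using stand-in data rather than the simulated means above:\n",
"\n",
"```python\n",
"import numpy as np\n",
"from scipy.stats import norm\n",
"\n",
"rng = np.random.default_rng(0)\n",
"data = rng.normal(loc=70, scale=0.6, size=100)  # stand-in for one row of sample means\n",
"\n",
"loc, scale = norm.fit(data)\n",
"print(np.isclose(loc, data.mean()), np.isclose(scale, data.std(ddof=0)))\n",
"# True True\n",
"```\n",
"\n",
"So `norm(*norm.fit(smean))` is the normal curve whose mean and spread match the histogram of sample means. This is what makes the mirrored comparison in the bottom row meaningful."
]
},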
{
"cell_type": "markdown",
"id": "6be5f377",
"metadata": {},
"source": [
"### Cleaning Aesthetics\n",
"\n",
"All we have left to do now is a little cleanup. First, I force major ticks to appear every\n",
"ten units along the y-axis of the second row of plots. These plots all share a y-axis,\n",
"so I only need to make the change on one of them. Since the figure is designed\n",
"to be read from top to bottom, I invert the y-axis so that the simulation number increases as the\n",
"reader's eyes move downwards. Instead of a separate y label, I format each tick label\n",
"as `Simulation N` so the count reads in line with the rows.\n",
"\n",
"Next, I tidy the bottom row of `Axes`. Since all plots share an x-axis, I can\n",
"set the x-limits once for every plot, centered on the expected population mean (~70), and drop the 70 tick\n",
"for visibility. I also flip that row's y-axis so the fitted Gaussian sits above its histogram, and add some extra padding to the y-margins of\n",
"these plots so the fitted distributions and histograms don’t bump up against their\n",
"`Axes` limits."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "df2eeb2d",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from matplotlib.ticker import MultipleLocator\n",
"\n",
"## Set custom ticks and limits to the Simulation & Sampling Dist. Axes\n",
"# The y-axes are shared across rows of the figure, \n",
"# so we only need to invert 1 y-axis out of the row of sample Axes\n",
"samp_ax = axes[1, 0]\n",
"samp_ax.yaxis.set_major_locator(MultipleLocator(10))\n",
"samp_ax.yaxis.set_major_formatter('Simulation {x:.0f}')\n",
"samp_ax.invert_yaxis()\n",
"# samp_ax.set_ylabel('Simulation', size='large', rotation=-90, va='top')\n",
"\n",
"# Manually set the xlimits, they're shared across all Axes\n",
"# the population means hover ~70, so we drop that xtick for visibility\n",
"mean_ax = axes[2, 0]\n",
"mean_ax.set(xlim=(60, 80), xticks=[60, 65, 75, 80])\n",
"mean_ax.invert_yaxis() # flip the row so the fitted Gaussian sits above the histogram\n",
"mean_ax.margins(y=0.1)\n",
"\n",
"display(fig)"
]
},
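{
"cell_type": "markdown",
"id": "sketch-str-formatter",
"metadata": {},
"source": [
"The `set_major_formatter('Simulation {x:.0f}')` call above passes a plain format string. In `matplotlib` 3.3 and newer, a string is wrapped in a `StrMethodFormatter` that fills in the tick value as `x`. A minimal standalone check of the resulting tick labels:\n",
"\n",
"```python\n",
"from matplotlib.ticker import StrMethodFormatter\n",
"\n",
"formatter = StrMethodFormatter('Simulation {x:.0f}')\n",
"# Formatters are callables taking (tick value, tick position)\n",
"labels = [formatter(value, pos) for pos, value in enumerate([0, 10, 20])]\n",
"print(labels)\n",
"# ['Simulation 0', 'Simulation 10', 'Simulation 20']\n",
"```"
]
},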
{
"cell_type": "markdown",
"id": "2d80115b",
"metadata": {},
"source": [
"### Adding Text Annotations & Legends\n",
"\n",
"Now I want to add some descriptive legends I'll use to annotate my figure.\n",
"To do this, I rely on `matplotlib`'s internal legend generation via `Figure.legend`,\n",
"fed with each row's `Axes.get_legend_handles_labels()`. However, I want to add custom titles and subtitles to these legends.\n",
"For that, I use a little trick: I decompose each legend via its `.get_children()`\n",
"method and pack its body together with my own `TextArea`s, which gives me explicit control over multiple\n",
"fonts, the title spacing, and alignment with the legend."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bf165c0e",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from matplotlib.transforms import blended_transform_factory\n",
"from matplotlib.offsetbox import VPacker, TextArea, AnchoredOffsetbox\n",
"\n",
"## Create left-aligned titles for each row\n",
"axes_titles = [\n",
"    ('Population', ''),\n",
"    ('Samples', f'n = {sample_size}'),\n",
"    ('Sampling Distribution', 'of the sample mean')\n",
"]\n",
"for (title, subtitle), ax in zip(axes_titles, axes[:, 0], strict=True):\n",
"    titlebox = [\n",
"        TextArea(title, textprops={'size': 'large', 'weight': 'semibold'})\n",
"    ]\n",
"    if subtitle:\n",
"        titlebox.append(TextArea(subtitle, textprops={'style': 'italic'}))\n",
"    title_packer = VPacker(pad=0, sep=5, children=titlebox)\n",
"\n",
"    legend = fig.legend(\n",
"        *ax.get_legend_handles_labels(), markerscale=4, scatterpoints=4\n",
"    )\n",
"    legend.remove()\n",
"\n",
"    # Legends are composed of two children: VPacker & FancyBboxPatch\n",
"    # We can extract the VPacker and add it to our own for a very custom title\n",
"    legend_body, _ = legend.get_children()\n",
"    transform = blended_transform_factory(fig.transFigure, ax.transAxes)\n",
"    fig.add_artist(\n",
"        AnchoredOffsetbox(\n",
"            loc='upper left',\n",
"            child=VPacker(\n",
"                align='left', pad=10, sep=10,\n",
"                children=[title_packer, legend_body]\n",
"            ),\n",
"            bbox_to_anchor=(0.05, 1), bbox_transform=transform,\n",
"            borderpad=0, frameon=False\n",
"        )\n",
"    )\n",
"\n",
"display(fig)"
]
},
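{
"cell_type": "markdown",
"id": "sketch-offsetbox-titles",
"metadata": {},
"source": [
"The `offsetbox` building blocks used above also work on their own. Here is a minimal standalone sketch that stacks a bold title over an italic subtitle and pins the pair to the upper-left corner of an `Axes`. The title and subtitle strings are placeholders, and a fresh `Figure` is used rather than the notebook's `fig`:\n",
"\n",
"```python\n",
"from matplotlib.figure import Figure\n",
"from matplotlib.offsetbox import AnchoredOffsetbox, TextArea, VPacker\n",
"\n",
"fig = Figure(figsize=(4, 3))\n",
"ax = fig.add_subplot()\n",
"\n",
"# Each TextArea carries its own font properties; VPacker stacks them vertically\n",
"title_box = VPacker(\n",
"    align='left', pad=0, sep=5,\n",
"    children=[\n",
"        TextArea('Samples', textprops={'weight': 'semibold'}),\n",
"        TextArea('n = 25', textprops={'style': 'italic'}),\n",
"    ],\n",
")\n",
"\n",
"# AnchoredOffsetbox positions the packed box using a legend-style loc string\n",
"anchored = AnchoredOffsetbox(loc='upper left', child=title_box, frameon=False)\n",
"ax.add_artist(anchored)\n",
"```\n",
"\n",
"The cell above swaps `ax.add_artist` for `fig.add_artist` and supplies `bbox_to_anchor`/`bbox_transform`. That is what lets its titles sit outside the `Axes`, in the left margin of the figure."
]
},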
{
"cell_type": "markdown",
"id": "a42a6e46",
"metadata": {},
"source": [
"Now, I want to add some horizontal separation to help the viewer distinguish the\n",
"three stages of these plots. I leverage `matplotlib`'s transforms to place a horizontal\n",
"line that aligns with the bottom of the first and second rows of my `Axes` and\n",
"spans most of the `Figure`'s width (from 5% to 95%).\n",
"\n",
"I then use a similar approach to the one for the legend titles and subtitles to add a figure title\n",
"to the top of my `Figure`. I could have centered this over my `GridSpec` instead of\n",
"the `Figure`, but I felt that centering on the `Figure` worked better\n",
"as a title for the entire `Figure`, not just the 3 columns of plots."
]
},
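{
"cell_type": "markdown",
"id": "sketch-blended-transform",
"metadata": {},
"source": [
"Both this step and the legend placement above rely on `blended_transform_factory`. It builds a transform that interprets the x-coordinate in one coordinate system and the y-coordinate in another. A minimal standalone sketch, using a fresh `Figure` with an illustrative size and dpi, shows that the point `(0.5, 1.0)` lands halfway across the figure and exactly on the top edge of the `Axes`:\n",
"\n",
"```python\n",
"import numpy as np\n",
"from matplotlib.figure import Figure\n",
"from matplotlib.transforms import blended_transform_factory\n",
"\n",
"fig = Figure(figsize=(6, 4), dpi=100)\n",
"ax = fig.add_subplot()\n",
"\n",
"# x is interpreted as a figure fraction, y as an axes fraction\n",
"transform = blended_transform_factory(fig.transFigure, ax.transAxes)\n",
"x_disp, y_disp = transform.transform((0.5, 1.0))\n",
"\n",
"# Compare against each coordinate system on its own (display pixels)\n",
"print(x_disp, fig.transFigure.transform((0.5, 0.0))[0])\n",
"print(y_disp, ax.transAxes.transform((0.0, 1.0))[1])\n",
"```"
]
},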
{
"cell_type": "code",
"execution_count": null,
"id": "ddfa06d7",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from matplotlib.lines import Line2D\n",
"\n",
"## Add lines to separate rows of plots\n",
"for ax in axes[1:, 0]:\n",
"    transform = blended_transform_factory(fig.transFigure, ax.transAxes)\n",
"    fig.add_artist(\n",
"        Line2D([.05, .95], [1, 1], color='lightgray', transform=transform)\n",
"    )\n",
"\n",
"for ax in axes[:, 1:].flat:\n",
"    ax.yaxis.set_tick_params(left=False)\n",
"\n",
"## Figure title - VPacker with TextAreas enables great control\n",
"# over alignment & different fonts\n",
"figure_title = VPacker(\n",
"    align='center', pad=0, sep=5,\n",
"    children=[\n",
"        TextArea(\n",
"            'Visualizing the Central Limit Theorem',\n",
"            textprops={'size': 'x-large', 'weight': 'bold'}\n",
"        ),\n",
"        TextArea(\n",
"            'Sampling Distributions of the Sample Mean',\n",
"            textprops={'size': 'large', 'style': 'italic'}\n",
"        )\n",
"    ]\n",
")\n",
"\n",
"fig.add_artist(\n",
"    AnchoredOffsetbox(\n",
"        loc='upper center', child=figure_title,\n",
"        bbox_to_anchor=(0.5, 1.0), bbox_transform=fig.transFigure,\n",
"        frameon=False\n",
"    )\n",
")\n",
"\n",
"# Remove spines from charts (the bottom row keeps its bottom spine)\n",
"for ax in fig.axes:\n",
"    sides = ['bottom', 'left', 'top', 'right']\n",
"    if ax in axes[-1, :]:\n",
"        sides.remove('bottom')\n",
"    ax.spines[sides].set_visible(False)\n",
"\n",
"display(fig)"
]
},
{
"cell_type": "markdown",
"id": "4c97c51f",
"metadata": {},
"source": [
"## Wrap Up\n",
"\n",
"And that takes us to the end of our Central Limit Theorem visualization. Hopefully you\n",
"learned a little bit of statistics, as well as some tips you can use to take\n",
"your `matplotlib` game to the next level and create refined, communicative data\n",
"visualizations programmatically. Talk to you all next time!"
]
}
],
"metadata": {
"author": "Cameron Riddell",
"blogpost": true,
"date": "2022-09-28",
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.2"
},
"tags": "statistics,matplotlib,advanced"
},
"nbformat": 4,
"nbformat_minor": 5
}