Skip to content

Instantly share code, notes, and snippets.

@velikodniy
Last active November 29, 2019 12:35
Show Gist options
  • Save velikodniy/7bab6f236b45a70361506186a1120fbf to your computer and use it in GitHub Desktop.
Save velikodniy/7bab6f236b45a70361506186a1120fbf to your computer and use it in GitHub Desktop.
assocplot from R implemented in Python
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"inputHidden": false,
"outputHidden": false
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"inputHidden": false,
"outputHidden": false
},
"outputs": [],
"source": [
"def assocplot(x, rownames, colnames, col=['k', 'r'], space=0.3, ax=None):\n",
"    \"\"\"Draw an association plot for a two-way contingency table.\n",
"\n",
"    Python port of R's ``assocplot``. Every cell of the table becomes a\n",
"    rectangle whose height is the Pearson residual\n",
"    ``(observed - expected) / sqrt(expected)`` and whose width is\n",
"    ``sqrt(expected)``, so its signed area equals ``observed - expected``.\n",
"    Rows of ``x`` are laid out along the x-axis; columns are stacked along\n",
"    the y-axis with the first column at the top.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    x : array-like\n",
"        2-d table of nonnegative, finite counts.\n",
"    rownames : sequence\n",
"        One tick label per row of ``x`` (x-axis).\n",
"    colnames : sequence\n",
"        One tick label per column of ``x`` (y-axis).\n",
"    col : sequence of length 2\n",
"        ``col[0]`` colors rectangles whose residual is >= 0, ``col[1]``\n",
"        those whose residual is negative.\n",
"    space : float\n",
"        Gap between neighbouring slots, as a fraction of the mean slot\n",
"        width (x direction) and mean slot height (y direction).\n",
"    ax : matplotlib Axes, optional\n",
"        Axes to draw on; the current axes (``plt.gca()``) when omitted.\n",
"\n",
"    Returns\n",
"    -------\n",
"    The axes the plot was drawn on.\n",
"\n",
"    Raises\n",
"    ------\n",
"    ValueError\n",
"        If ``x`` is not 2-d, contains negative or non-finite entries or\n",
"        sums to zero, if ``col`` does not have exactly two entries, or if\n",
"        the label lengths do not match the shape of ``x``.\n",
"    \"\"\"\n",
"    # Accept any array-like (nested lists included), not only ndarrays.\n",
"    x = np.asarray(x, dtype=float)\n",
"    rownames = list(rownames)\n",
"    colnames = list(colnames)\n",
"\n",
"    if x.ndim != 2:\n",
"        raise ValueError('\"x\" must be a 2-d contingency table')\n",
"    # isfinite rejects NaN and +/-inf; an isnan-only check lets inf through.\n",
"    if not np.isfinite(x).all() or (x < 0).any():\n",
"        raise ValueError('all entries of \"x\" must be nonnegative and finite')\n",
"    if len(col) != 2:\n",
"        raise ValueError('incorrect \"col\": must be length 2')\n",
"    if len(rownames) != x.shape[0]:\n",
"        raise ValueError('the length of \"rownames\" must be equal to the number of rows')\n",
"    if len(colnames) != x.shape[1]:\n",
"        raise ValueError('the length of \"colnames\" must be equal to the number of columns')\n",
"\n",
"    n = x.sum()\n",
"    if n == 0:\n",
"        raise ValueError('at least one entry of \"x\" must be positive')\n",
"\n",
"    # Reverse the columns so that the first column ends up at the top of the y-axis.\n",
"    f = x[:, ::-1]\n",
"    colnames = colnames[::-1]\n",
"\n",
"    # Expected counts under independence: outer product of the margins over the total.\n",
"    expected = f.sum(axis=1, keepdims=True) @ f.sum(axis=0, keepdims=True) / n\n",
"    root_e = np.sqrt(expected)\n",
"    # An all-zero row or column gives expected == 0 (and observed == 0 there);\n",
"    # use a residual of 0 for those cells instead of computing 0 / 0 = NaN.\n",
"    resid = np.divide(f - expected, root_e,\n",
"                      out=np.zeros_like(expected), where=root_e > 0)\n",
"\n",
"    slot_w = root_e.max(axis=1)                      # one slot per row\n",
"    slot_h = resid.max(axis=0) - resid.min(axis=0)   # one slot per column\n",
"\n",
"    x_gap = slot_w.mean() * space\n",
"    y_gap = slot_h.mean() * space\n",
"\n",
"    x_right = (slot_w + x_gap).cumsum()\n",
"    # Slot centres computed directly: np.convolve(..., 'same') would return\n",
"    # two midpoints for a table with a single row.\n",
"    x_mid = x_right - (slot_w + x_gap) / 2\n",
"\n",
"    y_top = (slot_h + y_gap).cumsum()\n",
"    # Zero line of each column: leave room above it for the largest positive\n",
"    # residual plus half a gap.\n",
"    y_zero = y_top - np.maximum(resid, 0).max(axis=0) - y_gap / 2\n",
"\n",
"    if ax is None:\n",
"        ax = plt.gca()\n",
"\n",
"    ax.set_xlim(0, x_right[-1])\n",
"    ax.set_ylim(0, y_top[-1])\n",
"\n",
"    ax.set_xticks(x_mid)\n",
"    ax.set_xticklabels(rownames)\n",
"\n",
"    ax.set_yticks(y_zero)\n",
"    ax.set_yticklabels(colnames)\n",
"    ax.grid(True, axis='y')\n",
"\n",
"    # Loop variables are named so that they do not rebind the parameter ``x``.\n",
"    for j, base in enumerate(y_zero):\n",
"        for i, centre in enumerate(x_mid):\n",
"            width = root_e[i, j]\n",
"            height = resid[i, j]\n",
"            color = col[int(height < 0)]\n",
"            ax.add_patch(plt.Rectangle((centre - width / 2, base), width, height,\n",
"                                       color=color, alpha=0.6))\n",
"\n",
"    return ax"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Sample contingency table, entered one column at a time;\n",
"# the index supplies the row labels.\n",
"df = pd.DataFrame(\n",
"    {\n",
"        'Brown': [68, 119, 26, 7],\n",
"        'Blue': [20, 84, 17, 94],\n",
"        'Hazel': [15, 54, 14, 10],\n",
"        'Green': [5, 29, 14, 16],\n",
"    },\n",
"    index=['Black', 'Brown', 'Red', 'Blond'],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Brown</th>\n",
" <th>Blue</th>\n",
" <th>Hazel</th>\n",
" <th>Green</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>Black</th>\n",
" <td>68</td>\n",
" <td>20</td>\n",
" <td>15</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Brown</th>\n",
" <td>119</td>\n",
" <td>84</td>\n",
" <td>54</td>\n",
" <td>29</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Red</th>\n",
" <td>26</td>\n",
" <td>17</td>\n",
" <td>14</td>\n",
" <td>14</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Blond</th>\n",
" <td>7</td>\n",
" <td>94</td>\n",
" <td>10</td>\n",
" <td>16</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Brown Blue Hazel Green\n",
"Black 68 20 15 5\n",
"Brown 119 84 54 29\n",
"Red 26 17 14 14\n",
"Blond 7 94 10 16"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 576x432 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Create the figure and its single axes explicitly; assocplot draws on the\n",
"# current axes, which is the one created here.\n",
"fig, ax = plt.subplots(figsize=(8, 6))\n",
"ax.set_title('Relation between hair and eye color')\n",
"ax.set_xlabel('Hair')\n",
"ax.set_ylabel('Eye')\n",
"assocplot(df.values, df.index, df.columns)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernel_info": {
"name": "python3"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
},
"nteract": {
"version": "0.15.0"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment