@ColCarroll
Created May 30, 2018 12:23
Sample implementation of `to_xarray`
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Implementation of *to_xarray*\n",
"\n",
"The implementation consists of three functions:\n",
"\n",
"- `to_xarray` does the real work, and is intended to be used directly;\n",
"- `default_varnames_coords_dims` extracts or sets common defaults;\n",
"- `verify_coords_dims` does some extra work to produce a helpful error message.\n",
"\n",
"Strictly speaking, the last function could be left out, but I found it confusing at first to work out exactly what `xarray` was doing and how its `coords`/`dims` syntax works. The extra work on the error message is meant to ease that transition; examples of the messages are shown below."
]
},
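{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal, standalone sketch of the error-chaining trick that `to_xarray` uses below. Writing `raise exc from TypeError(hint)` stores the `TypeError` as the `__cause__` of the `KeyError`, so the traceback prints the hint first, then *The above exception was the direct cause of the following exception*, and finally the original `KeyError`. The helper `lookup_dims` and its hint text are hypothetical, for illustration only."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def lookup_dims(coords, dim_names, hint):\n",
"    \"\"\"Hypothetical helper: fetch the coordinates for each dimension name,\n",
"    attaching `hint` as the cause if one of them is missing.\"\"\"\n",
"    try:\n",
"        return {name: coords[name] for name in dim_names}\n",
"    except KeyError as exc:\n",
"        # Same pattern as in `to_xarray`: the hint becomes exc.__cause__\n",
"        raise exc from TypeError(hint)\n",
"\n",
"\n",
"try:\n",
"    lookup_dims({'chain': [0, 1]}, ['chain', 'school'], 'Try adding a school coordinate')\n",
"except KeyError as err:\n",
"    print(repr(err))            # the original KeyError\n",
"    print(repr(err.__cause__))  # the TypeError carrying the hint"
]
},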
{
"cell_type": "code",
"execution_count": 455,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pymc3 as pm\n",
"import xarray as xr\n",
"\n",
"\n",
"def to_xarray(trace, coords=None, dims=None):\n",
"    \"\"\"Convert a pymc3 trace to an xarray dataset.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    trace : pymc3 trace\n",
"    coords : dict\n",
"        A dictionary containing the values that are used as index. The key\n",
"        is the name of the dimension, the values are the index values.\n",
"    dims : dict[str, list[str]]\n",
"        A mapping from pymc3 variable names to a list (or tuple) with one\n",
"        entry per axis of a single draw of the variable, where the elements\n",
"        are the names of the coordinate dimensions.\n",
"    \"\"\"\n",
"\n",
"    varnames, coords, dims = default_varnames_coords_dims(trace, coords, dims, include_transformed=False)\n",
"\n",
"    verified, warning = verify_coords_dims(varnames, trace, coords, dims)\n",
"\n",
"    data = xr.Dataset(coords=coords)\n",
"    base_dims = ['chain', 'sample']\n",
"    for key in varnames:\n",
"        # One array per chain; stacking gives shape (chain, sample, *var_shape)\n",
"        vals = trace.get_values(key, combine=False, squeeze=False)\n",
"        vals = np.array(vals)\n",
"        dims_str = base_dims + list(dims[key])\n",
"        try:\n",
"            data[key] = xr.DataArray(vals, coords={v: coords[v] for v in dims_str}, dims=dims_str)\n",
"        except KeyError as exc:\n",
"            if not verified:\n",
"                # Attach the suggested arguments as the cause of the KeyError\n",
"                raise exc from TypeError(warning)\n",
"            else:\n",
"                raise exc\n",
"\n",
"    return data\n",
"\n",
"\n",
"def default_varnames_coords_dims(trace, coords, dims, include_transformed):\n",
"    \"\"\"Set up varnames, coordinates, and dimensions for the to_xarray function.\"\"\"\n",
"    varnames = pm.utils.get_default_varnames(trace.varnames, include_transformed=include_transformed)\n",
"    # Copy so that the caller's dictionaries are not modified\n",
"    coords = {} if coords is None else dict(coords)\n",
"\n",
"    coords['sample'] = np.arange(len(trace))\n",
"    coords['chain'] = np.arange(trace.nchains)\n",
"    coords = {key: xr.IndexVariable((key,), data=vals) for key, vals in coords.items()}\n",
"\n",
"    dims = {} if dims is None else {key: list(val) for key, val in dims.items()}\n",
"\n",
"    for varname in varnames:\n",
"        dims.setdefault(varname, [])\n",
"\n",
"    return varnames, coords, dims\n",
"\n",
"\n",
"def verify_coords_dims(varnames, trace, coords, dims):\n",
"    \"\"\"Light checking and guessing on the structure of an xarray for a PyMC3 trace\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    varnames : iterable[string]\n",
"        names of the variables to include in the xarray\n",
"    trace : pymc3.MultiTrace\n",
"        trace from pymc3 run\n",
"    coords : dict\n",
"        output of `default_varnames_coords_dims`\n",
"    dims : dict\n",
"        output of `default_varnames_coords_dims`\n",
"\n",
"    Returns\n",
"    -------\n",
"    bool\n",
"        Whether it passes the check\n",
"    str\n",
"        Warning string in case it does not pass\n",
"    \"\"\"\n",
"    inferred_coords = coords.copy()\n",
"    # Shallow copy: appending to inferred_dims[varname] below also updates\n",
"    # dims[varname], so a bad call fails in to_xarray with a KeyError on the\n",
"    # guessed name, which is where the warning gets attached.\n",
"    inferred_dims = dims.copy()\n",
"    for key in ('sample', 'chain'):\n",
"        inferred_coords.pop(key)\n",
"    global_coords = {}\n",
"    throw = False\n",
"\n",
"    for varname in varnames:\n",
"        vals = trace.get_values(varname, combine=False, squeeze=False)\n",
"        # Lengths of all known coordinates, including chain and sample\n",
"        shapes = [d for shape in coords.values() for d in shape.shape]\n",
"        # Skip the sample axis of each per-chain array\n",
"        for idx, shape in enumerate(vals[0].shape[1:], 1):\n",
"            try:\n",
"                shapes.remove(shape)\n",
"            except ValueError:\n",
"                # No coordinate of this length: guess a name, reusing one\n",
"                # already assigned to another axis of the same length\n",
"                throw = True\n",
"                if shape not in global_coords:\n",
"                    global_coords[shape] = f'{varname}_dim_{idx}'\n",
"                key = global_coords[shape]\n",
"                inferred_dims[varname].append(key)\n",
"                if key not in inferred_coords:\n",
"                    inferred_coords[key] = f'np.arange({shape})'\n",
"    if throw:\n",
"        inferred_dims = {k: v for k, v in inferred_dims.items() if v}\n",
"        return False, f'Bad arguments! Try setting\\ncoords={inferred_coords}\\ndims={inferred_dims}'\n",
"    return True, ''"
]
},
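{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the layout concrete without running a sampler, here is a minimal sketch of what `to_xarray` assembles for one variable. It assumes, as the code above does, that `trace.get_values(key, combine=False, squeeze=False)` returns one array per chain, each shaped `(sample, *variable_shape)`; random numbers stand in for the posterior draws, and the dimension name `school` is chosen by hand."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import xarray as xr\n",
"\n",
"n_chains, n_samples, n_schools = 4, 500, 8\n",
"\n",
"# Stand-in for trace.get_values('theta', combine=False, squeeze=False):\n",
"# a list holding one (sample, school) array per chain\n",
"per_chain = [np.random.randn(n_samples, n_schools) for _ in range(n_chains)]\n",
"\n",
"# np.array stacks the list along a new leading axis: (chain, sample, school)\n",
"vals = np.array(per_chain)\n",
"\n",
"layout_coords = {\n",
"    'chain': np.arange(n_chains),\n",
"    'sample': np.arange(n_samples),\n",
"    'school': np.arange(n_schools),\n",
"}\n",
"layout_dims = ['chain', 'sample', 'school']\n",
"theta_layout = xr.DataArray(vals, coords={d: layout_coords[d] for d in layout_dims}, dims=layout_dims)\n",
"print(theta_layout.dims, theta_layout.shape)"
]
},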
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Example: eight schools model"
]
},
{
"cell_type": "code",
"execution_count": 451,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Auto-assigning NUTS sampler...\n",
"INFO:pymc3:Auto-assigning NUTS sampler...\n",
"Initializing NUTS using jitter+adapt_diag...\n",
"INFO:pymc3:Initializing NUTS using jitter+adapt_diag...\n",
"Multiprocess sampling (4 chains in 4 jobs)\n",
"INFO:pymc3:Multiprocess sampling (4 chains in 4 jobs)\n",
"NUTS: [theta_tilde, tau_log__, mu]\n",
"INFO:pymc3:NUTS: [theta_tilde, tau_log__, mu]\n",
"100%|██████████| 1000/1000 [00:01<00:00, 633.65it/s]\n",
"There were 2 divergences after tuning. Increase `target_accept` or reparameterize.\n",
"ERROR:pymc3:There were 2 divergences after tuning. Increase `target_accept` or reparameterize.\n",
"There were 1 divergences after tuning. Increase `target_accept` or reparameterize.\n",
"ERROR:pymc3:There were 1 divergences after tuning. Increase `target_accept` or reparameterize.\n",
"There were 2 divergences after tuning. Increase `target_accept` or reparameterize.\n",
"ERROR:pymc3:There were 2 divergences after tuning. Increase `target_accept` or reparameterize.\n"
]
}
],
"source": [
"# Data of the Eight Schools Model\n",
"J = 8\n",
"y = np.array([28., 8., -3., 7., -1., 1., 18., 12.])\n",
"sigma = np.array([15., 10., 16., 11., 9., 11., 10., 18.])\n",
"\n",
"with pm.Model() as non_centered:\n",
"    mu = pm.Normal('mu', mu=0, sd=5)\n",
"    tau = pm.HalfCauchy('tau', beta=5)\n",
"    theta_tilde = pm.Normal('theta_tilde', mu=0, sd=1, shape=J)\n",
"    theta = pm.Deterministic('theta', mu + tau * theta_tilde)\n",
"    obs = pm.Normal('obs', mu=theta, sd=sigma, observed=y)\n",
"    non_centered_eight_trace = pm.sample()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Bad invocation\n",
"\n",
"The error message gets the structure of the arguments right; only the guessed dimension names need changing."
]
},
{
"cell_type": "code",
"execution_count": 456,
"metadata": {},
"outputs": [
{
"ename": "KeyError",
"evalue": "'theta_tilde_dim_1'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;31mTypeError\u001b[0m: Bad arguments! Try setting\ncoords={'theta_tilde_dim_1': 'np.arange(8)'}\ndims={'theta_tilde': ['theta_tilde_dim_1'], 'theta': ['theta_tilde_dim_1']}",
"\nThe above exception was the direct cause of the following exception:\n",
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-456-9cf944f0cc2a>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mto_xarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnon_centered_eight_trace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-455-9a4106d779b3>\u001b[0m in \u001b[0;36mto_xarray\u001b[0;34m(trace, coords, dims)\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mKeyError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mexc\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mverified\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 33\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mexc\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mwarning\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 34\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mexc\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-455-9a4106d779b3>\u001b[0m in \u001b[0;36mto_xarray\u001b[0;34m(trace, coords, dims)\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mdims_str\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbase_dims\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mdims\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 30\u001b[0;31m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mxr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDataArray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcoords\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mcoords\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mv\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdims_str\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdims\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdims_str\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 31\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mKeyError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mexc\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mverified\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-455-9a4106d779b3>\u001b[0m in \u001b[0;36m<dictcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mdims_str\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbase_dims\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mdims\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 30\u001b[0;31m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mxr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDataArray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcoords\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mcoords\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mv\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdims_str\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdims\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdims_str\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 31\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mKeyError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mexc\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mverified\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyError\u001b[0m: 'theta_tilde_dim_1'"
]
}
],
"source": [
"to_xarray(non_centered_eight_trace, {}, {})"
]
},
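{
"cell_type": "markdown",
"metadata": {},
"source": [
"The suggested `coords` values are printed as strings such as `'np.arange(8)'`, so the suggestion cannot be pasted back verbatim. A minimal sketch of turning it into usable arguments, assuming every hint has the form `np.arange(n)` (which is what `verify_coords_dims` produces) and renaming the guessed dimension to `school` by hand; the result has the same structure as the arguments in the good invocation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"import numpy as np\n",
"\n",
"# Suggestion copied from the error message above; coords values are strings\n",
"suggested_coords = {'theta_tilde_dim_1': 'np.arange(8)'}\n",
"suggested_dims = {'theta_tilde': ['theta_tilde_dim_1'], 'theta': ['theta_tilde_dim_1']}\n",
"\n",
"# Hand-picked, more meaningful names for the guessed dimensions\n",
"rename = {'theta_tilde_dim_1': 'school'}\n",
"\n",
"\n",
"def length_from_hint(hint):\n",
"    \"\"\"Extract n from a hint of the form 'np.arange(n)'.\"\"\"\n",
"    return int(re.fullmatch(r'np\\.arange\\((\\d+)\\)', hint).group(1))\n",
"\n",
"\n",
"school_coords = {rename[name]: np.arange(length_from_hint(hint))\n",
"                 for name, hint in suggested_coords.items()}\n",
"school_dims = {var: [rename[d] for d in var_dims]\n",
"               for var, var_dims in suggested_dims.items()}\n",
"print(school_coords)\n",
"print(school_dims)"
]
},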
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Good invocation"
]
},
{
"cell_type": "code",
"execution_count": 473,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.Dataset>\n",
"Dimensions: (chain: 4, sample: 500, school: 8)\n",
"Coordinates:\n",
" * school (school) int64 0 1 2 3 4 5 6 7\n",
" * sample (sample) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 ...\n",
" * chain (chain) int64 0 1 2 3\n",
"Data variables:\n",
" mu (chain, sample) float64 5.373 11.18 -0.9847 10.68 -2.413 ...\n",
" theta_tilde (chain, sample, school) float64 -0.8934 -0.9305 -0.0413 ...\n",
" tau (chain, sample) float64 1.118 2.065 2.029 3.968 0.6049 ...\n",
" theta (chain, sample, school) float64 4.375 4.333 5.327 5.727 ..."
]
},
"execution_count": 473,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = to_xarray(non_centered_eight_trace,\n",
"                 coords={'school': np.arange(8)},\n",
"                 dims={'theta_tilde': ['school'], 'theta': ['school']})\n",
"data"
]
},
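{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once the trace is a `Dataset`, summaries can refer to dimensions by name rather than by axis number. A minimal sketch, with random numbers standing in for the posterior of `theta` so that it runs on its own; with the real result above, `data['theta']` would take the place of `theta_draws`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import xarray as xr\n",
"\n",
"# Random stand-in with the same layout as data['theta'] above\n",
"theta_draws = xr.DataArray(\n",
"    np.random.randn(4, 500, 8),\n",
"    coords={'chain': np.arange(4), 'sample': np.arange(500), 'school': np.arange(8)},\n",
"    dims=['chain', 'sample', 'school'],\n",
")\n",
"\n",
"# Posterior mean per school: reduce over the named sampling dimensions\n",
"school_means = theta_draws.mean(dim=['chain', 'sample'])\n",
"\n",
"# All draws for one school, selected by its coordinate label\n",
"school_0 = theta_draws.sel(school=0)\n",
"\n",
"print(school_means.dims, school_0.dims)"
]
},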
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Example: oncology data\n",
"\n",
"Taken from [aseyboldt](https://gist.github.com/aseyboldt/7eb724a21b21165a4d10c936ae1bcdba)'s gist."
]
},
{
"cell_type": "code",
"execution_count": 475,
"metadata": {},
"outputs": [],
"source": [
"N = 200\n",
"\n",
"coords = {\n",
"    'subject': np.array([f'm{i:03}' for i in range(N)]),\n",
"    'treatment': np.array(['Sorafenib', 'Lurbinectedin']),\n",
"    'oncogene': np.array(['P19', 'MYC', 'AKT']),\n",
"}\n",
"\n",
"data = xr.Dataset(coords=coords)\n",
"\n",
"data['treated_idx'] = (\n",
"    'subject',\n",
"    np.random.randint(2, size=N))\n",
"data['treated'] = (\n",
"    'subject',\n",
"    data['treatment'].isel_points(data.subject, treatment=data.treated_idx))\n",
"\n",
"data['genotype_idx'] = (\n",
"    'subject',\n",
"    np.random.randint(3, size=N))\n",
"data['genotype'] = (\n",
"    'subject',\n",
"    data['oncogene'].isel_points(data.subject, oncogene=data.genotype_idx))\n",
"\n",
"data.set_coords(data.variables, inplace=True)\n",
"data['true_treatment_effect'] = (\n",
"    'treatment',\n",
"    0.08 * np.random.randn(data.dims['treatment']))\n",
"\n",
"data['true_interaction'] = (\n",
"    ('oncogene', 'treatment'),\n",
"    0.05 * np.random.randn(data.dims['oncogene'], data.dims['treatment']))\n",
"\n",
"data['true_intercept'] = np.log(30.)\n",
"data['true_sigma'] = 0.13\n",
"data['true_expected_survival'] = (\n",
"    'subject',\n",
"    data['true_intercept']\n",
"    + data.true_treatment_effect.sel_points(\n",
"        data.subject,\n",
"        treatment=data.treated)\n",
"    + data.true_interaction.sel_points(\n",
"        data.subject,\n",
"        oncogene=data.genotype,\n",
"        treatment=data.treated))\n",
"\n",
"data['survival'] = (\n",
"    'subject',\n",
"    data['true_expected_survival']\n",
"    + data['true_sigma'].values * np.random.randn(N))"
]
},
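{
"cell_type": "markdown",
"metadata": {},
"source": [
"`isel_points` and `sel_points` have since been deprecated in xarray in favour of vectorized indexing, where the indexer is itself a `DataArray` carrying a `subject` dimension. A minimal sketch of the same pointwise selection on a small stand-in; it is a rewrite for illustration, not code from aseyboldt's gist, and `effect` is a made-up name."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import xarray as xr\n",
"\n",
"n_subjects = 5  # small stand-in for the N subjects above\n",
"treatments = ['Sorafenib', 'Lurbinectedin']\n",
"\n",
"# One effect per treatment, indexed by the 'treatment' dimension\n",
"effect = xr.DataArray(0.08 * np.random.randn(2), coords={'treatment': treatments}, dims='treatment')\n",
"\n",
"# Integer indexer that carries its own 'subject' dimension\n",
"treated_idx = xr.DataArray(np.random.randint(2, size=n_subjects), dims='subject')\n",
"\n",
"# Vectorized (pointwise) indexing: one value per subject, result has dim 'subject'\n",
"effect_by_position = effect.isel(treatment=treated_idx)\n",
"\n",
"# Label-based equivalent, using a DataArray of treatment names\n",
"treated = xr.DataArray(np.array(treatments)[treated_idx.values], dims='subject')\n",
"effect_by_label = effect.sel(treatment=treated)\n",
"\n",
"print(effect_by_position.dims, effect_by_label.dims)"
]
},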
{
"cell_type": "code",
"execution_count": 476,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Auto-assigning NUTS sampler...\n",
"INFO:pymc3:Auto-assigning NUTS sampler...\n",
"Initializing NUTS using jitter+adapt_diag...\n",
"INFO:pymc3:Initializing NUTS using jitter+adapt_diag...\n",
"Multiprocess sampling (4 chains in 4 jobs)\n",
"INFO:pymc3:Multiprocess sampling (4 chains in 4 jobs)\n",
"NUTS: [sigma_log__, interaction, interaction_sd_log__, treatment_effect, treatment_sd_log__, intercept]\n",
"INFO:pymc3:NUTS: [sigma_log__, interaction, interaction_sd_log__, treatment_effect, treatment_sd_log__, intercept]\n",
"100%|██████████| 1000/1000 [00:18<00:00, 55.18it/s]\n",
"There were 4 divergences after tuning. Increase `target_accept` or reparameterize.\n",
"ERROR:pymc3:There were 4 divergences after tuning. Increase `target_accept` or reparameterize.\n",
"There were 9 divergences after tuning. Increase `target_accept` or reparameterize.\n",
"ERROR:pymc3:There were 9 divergences after tuning. Increase `target_accept` or reparameterize.\n",
"There were 12 divergences after tuning. Increase `target_accept` or reparameterize.\n",
"ERROR:pymc3:There were 12 divergences after tuning. Increase `target_accept` or reparameterize.\n",
"The acceptance probability does not match the target. It is 0.7084448760353492, but should be close to 0.8. Try to increase the number of tuning steps.\n",
"WARNING:pymc3:The acceptance probability does not match the target. It is 0.7084448760353492, but should be close to 0.8. Try to increase the number of tuning steps.\n",
"There were 41 divergences after tuning. Increase `target_accept` or reparameterize.\n",
"ERROR:pymc3:There were 41 divergences after tuning. Increase `target_accept` or reparameterize.\n",
"The acceptance probability does not match the target. It is 0.623557745909284, but should be close to 0.8. Try to increase the number of tuning steps.\n",
"WARNING:pymc3:The acceptance probability does not match the target. It is 0.623557745909284, but should be close to 0.8. Try to increase the number of tuning steps.\n",
"The gelman-rubin statistic is larger than 1.05 for some parameters. This indicates slight problems during sampling.\n",
"INFO:pymc3:The gelman-rubin statistic is larger than 1.05 for some parameters. This indicates slight problems during sampling.\n",
"The estimated number of effective samples is smaller than 200 for some parameters.\n",
"ERROR:pymc3:The estimated number of effective samples is smaller than 200 for some parameters.\n"
]
}
],
"source": [
"with pm.Model() as model:\n",
"    intercept = pm.Flat('intercept')\n",
"\n",
"    treat_sd = pm.HalfStudentT('treatment_sd', nu=3, sd=0.1)\n",
"    treatment = pm.Normal('treatment_effect', shape=data.treatment.shape)\n",
"\n",
"    interact_sd = pm.HalfStudentT('interaction_sd', nu=3, sd=0.1)\n",
"    interaction = pm.Normal('interaction', shape=data.oncogene.shape + data.treatment.shape)\n",
"\n",
"    mu = (intercept\n",
"          + treat_sd * treatment[data.treated_idx.values]\n",
"          + interact_sd * interaction[data.genotype_idx.values, data.treated_idx.values])\n",
"    sigma = pm.HalfStudentT('sigma', nu=3, sd=0.2)\n",
"    pm.Normal('y', mu=mu, sd=sigma, observed=data.survival.values)\n",
"    new_trace = pm.sample()"
]
},
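{
"cell_type": "markdown",
"metadata": {},
"source": [
"The expression for `mu` relies on integer-array indexing: `treatment[treated_idx]` picks one treatment effect per subject, and `interaction[genotype_idx, treated_idx]` pairs the two index arrays elementwise, picking one (oncogene, treatment) entry per subject rather than forming an outer product. A minimal NumPy sketch with made-up numbers; the Theano tensors in the model are assumed to follow the same NumPy indexing semantics."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"effects = np.array([-0.1, 0.2])          # one effect per treatment\n",
"interactions = np.array([[0.01, 0.02],   # rows: oncogenes\n",
"                         [0.03, 0.04],   # columns: treatments\n",
"                         [0.05, 0.06]])\n",
"\n",
"treated_idx = np.array([0, 1, 1, 0])     # treatment of each subject\n",
"genotype_idx = np.array([2, 0, 1, 2])    # oncogene of each subject\n",
"\n",
"per_subject_effect = effects[treated_idx]\n",
"# Pairs genotype_idx[i] with treated_idx[i], one entry per subject\n",
"per_subject_interaction = interactions[genotype_idx, treated_idx]\n",
"\n",
"print(per_subject_effect.shape, per_subject_interaction.shape)"
]
},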
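{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Bad invocation\n",
"\n",
"Again, the error message does a decent job of guessing the structure of the data, but the guess can definitely go wrong: axes are matched to coordinates by length alone, so unrelated axes that happen to share a length get the same guessed name, and an axis whose length matches any known coordinate (even the number of chains or samples) is accepted without further checks."
]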
},
{
"cell_type": "code",
"execution_count": 477,
"metadata": {},
"outputs": [
{
"ename": "KeyError",
"evalue": "'treatment_effect_dim_1'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;31mTypeError\u001b[0m: Bad arguments! Try setting\ncoords={'treatment_effect_dim_1': 'np.arange(2)', 'interaction_dim_1': 'np.arange(3)'}\ndims={'treatment_effect': ['treatment_effect_dim_1'], 'interaction': ['interaction_dim_1', 'treatment_effect_dim_1']}",
"\nThe above exception was the direct cause of the following exception:\n",
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-477-517305d5e90c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mto_xarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnew_trace\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-455-9a4106d779b3>\u001b[0m in \u001b[0;36mto_xarray\u001b[0;34m(trace, coords, dims)\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mKeyError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mexc\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mverified\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 33\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mexc\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mwarning\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 34\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mexc\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-455-9a4106d779b3>\u001b[0m in \u001b[0;36mto_xarray\u001b[0;34m(trace, coords, dims)\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mdims_str\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbase_dims\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mdims\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 30\u001b[0;31m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mxr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDataArray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcoords\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mcoords\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mv\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdims_str\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdims\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdims_str\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 31\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mKeyError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mexc\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mverified\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-455-9a4106d779b3>\u001b[0m in \u001b[0;36m<dictcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mdims_str\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbase_dims\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mdims\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 30\u001b[0;31m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mxr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDataArray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcoords\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mcoords\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mv\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdims_str\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdims\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdims_str\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 31\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mKeyError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mexc\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mverified\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyError\u001b[0m: 'treatment_effect_dim_1'"
]
}
],
"source": [
"to_xarray(new_trace)"
]
},
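{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal, dependency-free sketch of the length-matching rule in `verify_coords_dims`, taking a hypothetical `shapes` mapping from variable name to the shape of one draw instead of a trace. The chain and sample lengths, which the original also counts as known, are left out here. The first call reproduces the `dims` suggested above; the second shows two axes of equal length collapsing onto one guessed name."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def guess_dims(shapes, known_lengths=()):\n",
"    \"\"\"Sketch of the length-matching rule in `verify_coords_dims`.\n",
"\n",
"    shapes: dict mapping variable name -> shape of a single draw\n",
"    known_lengths: lengths of the coordinates already supplied\n",
"    \"\"\"\n",
"    names_by_length = {}\n",
"    guessed = {}\n",
"    for varname, shape in shapes.items():\n",
"        available = list(known_lengths)\n",
"        for idx, length in enumerate(shape, 1):\n",
"            if length in available:\n",
"                available.remove(length)  # a supplied coordinate fits this axis\n",
"                continue\n",
"            # No coordinate fits: reuse any name already given to this length\n",
"            names_by_length.setdefault(length, f'{varname}_dim_{idx}')\n",
"            guessed.setdefault(varname, []).append(names_by_length[length])\n",
"    return guessed\n",
"\n",
"\n",
"# Reproduces the dims suggested in the error message above\n",
"print(guess_dims({'treatment_effect': (2,), 'interaction': (3, 2)}))\n",
"# Two axes of equal length collapse onto a single guessed name\n",
"print(guess_dims({'interaction': (2, 2)}))"
]
},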
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Good invocation\n",
"\n",
"Adrian adds more coords than this, though I am not sure exactly what effect that has."
]
},
{
"cell_type": "code",
"execution_count": 482,
"metadata": {},
"outputs": [],
"source": [
"trace_xr = to_xarray(new_trace,\n",
"                     coords={'treatment': data.treatment.values, 'oncogene': data.oncogene.values},\n",
"                     dims={'treatment_effect': ['treatment'], 'interaction': ['oncogene', 'treatment']})"
]
},
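{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of one likely effect of extra coords: since `to_xarray` starts from `xr.Dataset(coords=coords)`, a coordinate that no variable uses should still appear in the result as its own, unused dimension. This uses plain xarray rather than a trace, and the extra `subject` coordinate is only an example."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import xarray as xr\n",
"\n",
"extra_coords = {\n",
"    'treatment': ['Sorafenib', 'Lurbinectedin'],\n",
"    'subject': [f'm{i:03}' for i in range(3)],  # not used by any variable below\n",
"}\n",
"ds = xr.Dataset(coords=extra_coords)\n",
"ds['treatment_effect'] = ('treatment', np.random.randn(2))\n",
"\n",
"print(dict(ds.sizes))               # 'subject' is still listed as a dimension\n",
"print(ds['treatment_effect'].dims)  # the variable only uses 'treatment'"
]
},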
{
"cell_type": "code",
"execution_count": 483,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.Dataset>\n",
"Dimensions: (chain: 4, oncogene: 3, sample: 500, treatment: 2)\n",
"Coordinates:\n",
" * treatment (treatment) <U13 'Sorafenib' 'Lurbinectedin'\n",
" * oncogene (oncogene) <U3 'P19' 'MYC' 'AKT'\n",
" * sample (sample) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 ...\n",
" * chain (chain) int64 0 1 2 3\n",
"Data variables:\n",
" intercept (chain, sample) float64 3.408 3.31 3.205 3.572 3.604 ...\n",
" treatment_effect (chain, sample, treatment) float64 -1.29 -1.433 ...\n",
" interaction (chain, sample, oncogene, treatment) float64 -0.6027 ...\n",
" treatment_sd (chain, sample) float64 0.0577 0.1139 0.06507 0.04131 ...\n",
" interaction_sd (chain, sample) float64 0.1007 0.0456 0.08703 0.1225 ...\n",
" sigma (chain, sample) float64 0.1367 0.1288 0.1283 0.1376 ..."
]
},
"execution_count": 483,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trace_xr"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "arviz3.6",
"language": "python",
"name": "arviz3_6"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}