@mbjoseph
Created June 4, 2020 14:26
Zero one inflated beta distribution in PyTorch
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.distributions import Beta\n",
"from torch.distributions import constraints\n",
"from torch.distributions.exp_family import ExponentialFamily\n",
"from torch.distributions.utils import broadcast_all\n",
"from torch.distributions.dirichlet import Dirichlet\n",
"from numbers import Number"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"class ZOIBeta(ExponentialFamily):\n",
" \"\"\" Zero one inflated Beta distribution\n",
" \n",
" Args: \n",
" p (float or Tensor): Pr(y = 0)\n",
" q (float or Tensor): Pr(y = 1 | y != 0)\n",
" concentration1 (float or Tensor): 1st Beta dist. parameter \n",
" (often referred to as alpha)\n",
" concentration0 (float or Tensor): 2nd Beta dist. parameter\n",
" (often referred to as beta)\n",
" \"\"\"\n",
" \n",
" arg_constraints = {\n",
" 'p': constraints.unit_interval, \n",
" 'q': constraints.unit_interval, \n",
" 'concentration1': constraints.positive, \n",
" 'concentration0': constraints.positive\n",
" }\n",
" support = constraints.unit_interval # does this include 0 and 1?\n",
" has_rsample = False\n",
" \n",
" def __init__(self, p, q, concentration1, concentration0, validate_args=None):\n",
" if isinstance(concentration1, Number) and isinstance(concentration0, Number):\n",
" concentration1_concentration0 = torch.tensor([float(concentration1), float(concentration0)])\n",
" else:\n",
" concentration1, concentration0 = broadcast_all(concentration1, concentration0)\n",
" concentration1_concentration0 = torch.stack([concentration1, concentration0], -1)\n",
" self._dirichlet = Dirichlet(concentration1_concentration0)\n",
" self.log_p = torch.log(p)\n",
" self.log1m_p = torch.log(1 - p)\n",
" self.log_q = torch.log(q)\n",
" self.log1m_q = torch.log(1 - q)\n",
" super(ZOIBeta, self).__init__(self._dirichlet._batch_shape, validate_args=validate_args)\n",
" \n",
" def beta_lp(self, value):\n",
" if self._validate_args:\n",
" self._validate_sample(value)\n",
" heads_tails = torch.stack([value, 1.0 - value], -1)\n",
" return self._dirichlet.log_prob(heads_tails)\n",
" \n",
" def log_prob(self, value):\n",
" lp = torch.zeros_like(value, dtype = torch.float)\n",
" if any (0. < value < 1.): \n",
" beta_idx = torch.where(0. < value < 1.)\n",
" lp[beta_idx] = self.log1m_p + self.log1m_q + self.beta_lp(value[beta_idx])\n",
" lp[torch.where(value == 0.)] = self.log_p\n",
" lp[torch.where(value == 1.)] = self.log1m_p + self.log_q\n",
" return lp\n",
" "
]
},
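{
"cell_type": "markdown",
"metadata": {},
"source": [
"The three branches of `log_prob` above correspond to the mixture density\n",
"\n",
"$$\n",
"f(y) = \\begin{cases} p & y = 0 \\\\ (1 - p)\\, q & y = 1 \\\\ (1 - p)(1 - q)\\, \\mathrm{Beta}(y \\mid \\alpha, \\beta) & 0 < y < 1, \\end{cases}\n",
"$$\n",
"\n",
"with $\\alpha$ = `concentration1` and $\\beta$ = `concentration0`."
]
},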
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"zoib = ZOIBeta(p=torch.tensor(.5), \n",
" q=torch.tensor(.3), \n",
" concentration1=torch.tensor(1.), \n",
" concentration0=torch.tensor(1.))"
]
},
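{
"cell_type": "markdown",
"metadata": {},
"source": [
"With these parameters we can work out the expected log probabilities by hand: zeros should give $\\log(0.5) \\approx -0.6931$, ones should give $\\log(0.5 \\times 0.3) \\approx -1.8971$, and, since $\\mathrm{Beta}(1, 1)$ is uniform with density 1 on $(0, 1)$, any value strictly between 0 and 1 should give $\\log(0.5 \\times 0.7) \\approx -1.0498$."
]
},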
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Check that the log probabilities returned are what we expect\n",
"\n",
"For zeros:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.6931]])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"zoib.log_prob(torch.tensor(0.).view(1, 1))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(-0.6931)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"zoib.log_p"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For ones:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-1.8971]])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"zoib.log_prob(torch.tensor(1.).view(1, 1))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(-1.8971)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"zoib.log1m_p + zoib.log_q"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For proportions (if concentration1 and concentration0 are both 1, then we have a uniform Beta prior and we should get the same log probability for all values between 0 and 1:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-1.0498]])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"zoib.log_prob(torch.tensor(.4).view(1, 1))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-1.0498]])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"zoib.log_prob(torch.tensor(.9).view(1, 1))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-1.0498]])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"zoib.log_prob(torch.tensor(.2).view(1, 1))"
]
},
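{
"cell_type": "markdown",
"metadata": {},
"source": [
"As one more check (an added cell, with arbitrary illustrative values), a single batched call exercises all three branches of `log_prob` at once; the expected results follow from the hand calculation above: $-0.6931$ for 0, $-1.0498$ for the interior values, and $-1.8971$ for 1."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# a batch mixing y = 0, interior values, and y = 1 hits all three branches\n",
"values = torch.tensor([0., .25, .5, 1.]).view(-1, 1)\n",
"zoib.log_prob(values)"
]
},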
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}