Skip to content

Instantly share code, notes, and snippets.

@ricardoV94
Last active June 30, 2021 13:09
Show Gist options
  • Save ricardoV94/b2e3c43b0fba03845cc7bc80e4bd2dd8 to your computer and use it in GitHub Desktop.
Save ricardoV94/b2e3c43b0fba03845cc7bc80e4bd2dd8 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"outputs": [],
"source": [
"import numpy as np\n",
"import aesara\n",
"import aesara.tensor as at\n",
"from aesara.graph.basic import Apply\n",
"from aesara.tensor.random.op import RandomVariable"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [],
"source": [
"class TruncatedRV(RandomVariable):\n",
"    \"\"\"RandomVariable that redraws ``base_op`` samples until they lie in ``[lower, upper]``.\n",
"\n",
"    The last two inputs of the node are the truncation bounds; all other\n",
"    inputs are handled exactly as for ``base_op``.\n",
"    \"\"\"\n",
"\n",
"    # Maximum number of rejection-sampling rounds. Bounds that hold (almost)\n",
"    # no probability mass would otherwise make ``rng_fn`` loop forever.\n",
"    max_n_steps = 10000\n",
"\n",
"    def __init__(self, *args, base_op=None, **kwargs):\n",
"        # Untruncated RandomVariable whose ``rng_fn`` generates the proposals.\n",
"        self.base_op = base_op\n",
"        super().__init__(*args, **kwargs)\n",
"\n",
"    def make_node(self, *args, **kwargs):\n",
"        # The last two positional inputs are the bounds; everything before\n",
"        # them is forwarded unchanged to ``RandomVariable.make_node``.\n",
"        *dist_params, lower, upper = args\n",
"\n",
"        lower = at.as_tensor_variable(lower)\n",
"        upper = at.as_tensor_variable(upper)\n",
"        base_apply = super().make_node(*dist_params, **kwargs)\n",
"\n",
"        # Same inputs/outputs as the untruncated node, plus the two bounds.\n",
"        return Apply(\n",
"            self,\n",
"            [*base_apply.inputs, lower, upper],\n",
"            [out.type() for out in base_apply.outputs],\n",
"        )\n",
"\n",
"    def rng_fn(self, *args):\n",
"        \"\"\"Draw from ``base_op.rng_fn``, redrawing values outside ``[lower, upper]``.\n",
"\n",
"        Raises ``RuntimeError`` if some values are still out of bounds after\n",
"        ``max_n_steps`` redraws.\n",
"        \"\"\"\n",
"        rng, *dist_args, lower, upper, size = args\n",
"        base_rng_fn = self.base_op.rng_fn\n",
"\n",
"        res = base_rng_fn(rng, *dist_args, size)\n",
"        # A scalar draw cannot be updated by boolean-mask assignment.\n",
"        is_scalar = np.isscalar(res)\n",
"\n",
"        outside_bounds = (res < lower) | (res > upper)\n",
"        n_steps = 0\n",
"        while np.any(outside_bounds):\n",
"            if n_steps >= self.max_n_steps:\n",
"                raise RuntimeError(\n",
"                    f'TruncatedRV: values still outside bounds after {self.max_n_steps} redraws'\n",
"                )\n",
"            n_steps += 1\n",
"            new_res = base_rng_fn(rng, *dist_args, size)\n",
"            if is_scalar:\n",
"                res = new_res\n",
"            else:\n",
"                # Only replace the rejected entries; accepted ones are kept.\n",
"                res[outside_bounds] = new_res[outside_bounds]\n",
"            outside_bounds = (res < lower) | (res > upper)\n",
"\n",
"        return res\n",
"\n",
"def truncate(rv, lower, upper):\n",
"    \"\"\"Return a version of ``rv`` truncated to ``[lower, upper]``.\n",
"\n",
"    ``rv`` is expected to be the output of a RandomVariable node; its op and\n",
"    distribution parameters are reused, and out-of-bounds draws are\n",
"    rejected and redrawn by ``TruncatedRV``.\n",
"    \"\"\"\n",
"    base_op = rv.owner.op\n",
"    truncated_rv = TruncatedRV(\n",
"        # Name after the base distribution; ``rv.name`` is usually None,\n",
"        # which used to produce the op name 'truncated_None'.\n",
"        f'truncated_{base_op.name}',\n",
"        base_op.ndim_supp,\n",
"        # The two bounds get the same ndim as the distribution's support.\n",
"        list(base_op.ndims_params) + [base_op.ndim_supp] * 2,\n",
"        base_op.dtype,\n",
"        inplace=False,\n",
"        base_op=base_op,\n",
"    )\n",
"    # Node inputs are (rng, size, dtype, *dist_params); only size and the\n",
"    # distribution parameters are passed on to the truncated variable.\n",
"    _, size, _, *dist_params = rv.owner.inputs\n",
"    return truncated_rv(*dist_params, lower, upper, size=size)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [],
"source": [
"x = at.random.uniform(0, 1, size=10)\n",
"y = truncate(x, lower=0.4, upper=0.45)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"truncated_None_rv.1 [id A] '' \n",
" |RandomStateSharedVariable(<RandomState(MT19937) at 0x7FCE59180640>) [id B]\n",
" |TensorConstant{(1,) of 10} [id C]\n",
" |TensorConstant{11} [id D]\n",
" |TensorConstant{0} [id E]\n",
" |TensorConstant{1} [id F]\n",
" |TensorConstant{0.4} [id G]\n",
" |TensorConstant{0.45} [id H]\n"
]
}
],
"source": [
"aesara.dprint(y)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 5,
"outputs": [
{
"data": {
"text/plain": "array([0.42512075, 0.43904175, 0.44705009, 0.43412428, 0.42623413,\n 0.42201117, 0.42752124, 0.40940581, 0.43109232, 0.43450091])"
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y.eval()"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment