{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "sGbFykg31R1O"
},
"source": [
"# Pseudo random number generation in PyTensor"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"PyTensor has native support for [pseudo random number generation (PRNG)](https://en.wikipedia.org/wiki/Pseudorandom_number_generator). \n",
"This document describes how PRNGs are implemented in PyTensor, \n",
"via the RandomVariable Operator. \n",
"\n",
"We also discuss how initial seeding and seeding updates are implemented, \n",
"and some harder cases such as using RandomVariables inside Scan, \n",
"or with other backends like JAX.\n",
"\n",
"We will use PRNG and RNG interchangeably, \n",
"keeping in mind we are always talking about PRNGs."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The basics"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Numpy"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To start off, let's recall how PRNGs works in NumPy"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "30xA3h1rtC-k"
},
"outputs": [],
"source": [
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ZNwU5Hln1ZIk",
"outputId": "8409dc56-8233-4558-aab4-125b3f8b69b3"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.68235186 0.05382102] [0.22035987 0.18437181]\n"
]
}
],
"source": [
"rng = np.random.default_rng(seed=123)\n",
"print(rng.uniform(size=2), rng.uniform(size=2))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the first line `np.random.default_rng(seed)` creates a random Generator."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Generator(PCG64) at 0x7F6C04535820"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rng"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Every numpy Generator holds a BitGenerator, which is able to generate high-quality sequences of pseudo random bits. Numpy generators convert these sequences of bits into sequences of numbers that follow a specific statistical distribution. For more details, you can read https://numpy.org/doc/stable/reference/random"
]
},
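{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sketch (assuming the default BitGenerator is PCG64, as the output above shows), \n",
"we can build a Generator explicitly from a seeded PCG64 BitGenerator \n",
"and reproduce the draws we got from `np.random.default_rng(123)`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Build a Generator explicitly on top of a seeded PCG64 BitGenerator.\n",
"# Seeded the same way, it should reproduce the draws from `default_rng(123)` above,\n",
"# because the Generator is just a deterministic mapping over the stream of bits.\n",
"explicit_rng = np.random.Generator(np.random.PCG64(seed=123))\n",
"print(explicit_rng.uniform(size=2), explicit_rng.uniform(size=2))"
]
},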
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<numpy.random._pcg64.PCG64 at 0x7f6c045030f0>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rng.bit_generator"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'bit_generator': 'PCG64',\n",
" 'state': {'state': 143289216567205249174526524509312027761,\n",
" 'inc': 17686443629577124697969402389330893883},\n",
" 'has_uint32': 0,\n",
" 'uinteger': 0}"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rng.bit_generator.state"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When we call `rng.uniform(size=2)`, the Generator class requested a new array of pseudo random bits (state) from the BitGenerator, and used a deterministic mapping function to convert those into a float64 numbers. It did this twice, because we requested two draws via the `size` argument. In the long-run this deterministic mapping function should produce draws that are statistically indistinguishable from a true uniform distribution.\n",
"\n",
"For illustration we implement a very bad mapping function from a bit generator to uniform draws."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def bad_uniform_rng(rng, size):\n",
" bit_generator = rng.bit_generator\n",
" \n",
" uniform_draws = np.empty(size)\n",
" for i in range(size):\n",
" bit_generator.advance(1)\n",
" state = rng.bit_generator.state[\"state\"][\"state\"]\n",
" last_3_digits = state % 1_000\n",
" uniform_draws[i] = (last_3_digits + 1) / 1_000\n",
" return uniform_draws"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.033, 0.972, 0.459, 0.71 , 0.765])"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bad_uniform_rng(rng, size=5)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Scipy"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Scipy wraps these Numpy routines in a slightly different API."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"import scipy.stats as st"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "SEiSAo1e3BBN",
"outputId": "5dc4be27-e1e3-4dd9-8dc8-d1c048bdef54"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.68235186 0.05382102] [0.22035987 0.18437181]\n"
]
}
],
"source": [
"rng = np.random.default_rng(seed=123)\n",
"print(st.uniform.rvs(size=2, random_state=rng), st.uniform.rvs(size=2, random_state=rng))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Pytensor"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"PyTensor does not implement its own bit/generators methods. \n",
"Just like Scipy, it borrows NumPy routines directly.\n",
"\n",
"The low-level API of PyTensor RNGs is similar to that of SciPy, \n",
"whereas the higher-level API of RandomStreams is more like that of NumPy. \n",
"We will look at RandomStreams shortly, but we will start with the low-level API."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"import pytensor\n",
"import pytensor.tensor as pt"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "u3ZZr_jts9QJ"
},
"outputs": [],
"source": [
"rng = pt.random.type.RandomGeneratorType()(\"rng\")\n",
"x = pt.random.uniform(size=2, rng=rng)\n",
"f = pytensor.function([rng], x)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We created a function that takes a Numpy RandomGenerator and returns two uniform draws. Let's evaluate it"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "xHNx9Qdv6ACP",
"outputId": "22af5f3b-c054-4829-9cdc-3dc1bb88a5e0"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.68235186 0.05382102] [0.68235186 0.05382102]\n"
]
}
],
"source": [
"rng_val = np.random.default_rng(123)\n",
"print(f(rng_val), f(rng_val))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The first numbers were exactly the same as the numpy and scipy calls, \n",
"because we are using the very same routines.\n",
"\n",
"Perhaps surprisingly, we got the same results when we called the function the second time!\n",
"This is because PyTensor functions do not hold an internal state \n",
"and do not modify inputs inplace unless requested to.\n",
"\n",
"We made sure that the `rng_val` was not modified when calling our Pytensor function,\n",
"by copying it before using it. \n",
"This may feel inneficient (and it is), but PyTensor is built on a pure functional approach,\n",
"which is not allowed to have side-effects (such as changing global variables) by default. \n",
"\n",
"We will later see how we can get around this issue by making the inputs mutable\n",
"or using shared variables with explicit update rules. \n",
"\n",
"Before that, let's convince ourselves we can actually get different draws, \n",
"when we modify the bit generator of our input RNG."
]
},
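{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a rough sketch of that copy-before-use behaviour in plain NumPy \n",
"(just for intuition, not PyTensor's actual implementation): \n",
"if we draw from a copy of the Generator, the original is never advanced, \n",
"so repeated calls return the same numbers."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import copy\n",
"\n",
"def draw_without_side_effects(rng, size):\n",
"    # Draw from a copy so the caller's Generator is never advanced,\n",
"    # mimicking how the PyTensor function left `rng_val` untouched.\n",
"    return copy.deepcopy(rng).uniform(size=size)\n",
"\n",
"rng_demo = np.random.default_rng(123)\n",
"# Same draws every time, because `rng_demo` itself never changes\n",
"print(draw_without_side_effects(rng_demo, 2), draw_without_side_effects(rng_demo, 2))"
]
},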
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "S6hlUqljtb9C",
"outputId": "8d0db6d8-2040-4e49-9dc0-b164013e160d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.05382102 0.22035987] [0.05382102 0.22035987]\n"
]
}
],
"source": [
"rng_val.bit_generator.advance(1)\n",
"print(f(rng_val), f(rng_val))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "yMpQVFOybo80",
"outputId": "547cd9b6-166f-48fb-bc9c-c216ee89a200"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.22035987 0.18437181] [0.22035987 0.18437181]\n"
]
}
],
"source": [
"rng_val.bit_generator.advance(1)\n",
"print(f(rng_val), f(rng_val))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Updating the bit generator manually is not a good practice.\n",
"For starters, it may be unclear how much we have to advance it!\n",
"\n",
"In this case we had to advance it twice to get two completely new draws, \n",
"because the inner function uses two states. \n",
"But other distributions could need more states for a single draw, \n",
"or they could be clever and reuse the same state for multiple draws.\n",
"\n",
"Because it is not in generally possible to know how much one should modify the generator's bit generator,\n",
"PyTensor RandomVariables actually return the copied generator as a hidden output. \n",
"This copied generator can be safely used again because it contains the bit generator \n",
"that was already modified when taking draws."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(RandomGeneratorType, TensorType(float64, (2,)))"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"next_rng, x = x.owner.outputs\n",
"next_rng.type, x.type"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "UOPQx4iOtMYy",
"outputId": "4c5a0dc2-e0f4-4923-a96a-70f57c777c5b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"uniform_rv{0, (0, 0), floatX, False}.0 [id A] <RandomGeneratorType> 'next_rng'\n",
" |rng [id B] <RandomGeneratorType>\n",
" |TensorConstant{(1,) of 2} [id C] <TensorType(int64, (1,))>\n",
" |TensorConstant{11} [id D] <TensorType(int64, ())>\n",
" |TensorConstant{0.0} [id E] <TensorType(float32, ())>\n",
" |TensorConstant{1.0} [id F] <TensorType(float32, ())>\n",
"uniform_rv{0, (0, 0), floatX, False}.1 [id A] <TensorType(float64, (2,))> 'x'\n"
]
},
{
"data": {
"text/plain": [
"<ipykernel.iostream.OutStream at 0x7f6c308c3eb0>"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"next_rng.name = \"next_rng\"\n",
"x.name = \"x\"\n",
"pytensor.dprint([next_rng, x], print_type=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see the single node with `[id A]`, has two outputs, which we named `next_rng` and `x`. \n",
"By default only the second output `x` is given to the user directly, and the other is \"hidden\".\n",
"\n",
"We can compile a function that returns the `next_rng` explicitly,\n",
"so that we can use it as the input of the function in subsequent calls."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "s5Ynk8FCwuCh",
"outputId": "2dcbf788-123e-42d9-a30f-a6959b4327a7"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.68235186 0.05382102]\n",
"[0.22035987 0.18437181]\n",
"[0.1759059 0.81209451]\n"
]
}
],
"source": [
"f = pytensor.function([rng], [next_rng, x])\n",
"\n",
"rng_val = np.random.default_rng(123)\n",
"next_rng_val, x = f(rng_val)\n",
"print(x)\n",
"\n",
"next_rng_val, x = f(next_rng_val)\n",
"print(x)\n",
"\n",
"next_rng_val, x = f(next_rng_val)\n",
"print(x)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Shared variables"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"At this point we can make use of PyTensor shared variables.\n",
"Shared variables are global variables that don't need\n",
"(and can't) be passed as explicit inputs \n",
"to the functions where they are used."
]
},
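{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch of the general mechanism (using a plain integer instead of an RNG; \n",
"the names `count` and `g` are just for illustration), \n",
"a shared variable keeps its value between calls and can be changed with `set_value`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A shared variable is used by the function without being an explicit input\n",
"count = pytensor.shared(0, name=\"count\")\n",
"y = count * 2\n",
"g = pytensor.function([], y)\n",
"\n",
"print(g())          # uses the current value of `count`\n",
"count.set_value(5)  # changing the shared value affects subsequent calls\n",
"print(g())"
]
},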
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "65D2bmwCtzKT",
"outputId": "0d79fde2-9c4a-4d62-93e8-b2d677aa678d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.6823518632481435\n"
]
}
],
"source": [
"rng = pytensor.shared(np.random.default_rng(123))\n",
"next_rng, x = pt.random.uniform(rng=rng).owner.outputs\n",
"\n",
"f = pytensor.function([], [next_rng, x])\n",
"\n",
"next_rng_val, x = f()\n",
"print(x)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can update the value of shared variables across calls."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.053821018802222675\n",
"0.22035987277261138\n"
]
}
],
"source": [
"rng.set_value(next_rng_val)\n",
"next_rng_val, x = f()\n",
"print(x)\n",
"\n",
"rng.set_value(next_rng_val)\n",
"next_rng_val, x = f()\n",
"print(x)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The real benefit of using shared variables is that we can automate this updating\n",
"via the aptly named `updates` kwarg of PyTensor functions.\n",
"\n",
"In this case it makes sense to simply replace the original value by the `next_rng_val` (there is not really any other operation we can do with PyTensor RNGs)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "pB3O4MtsuK-l",
"outputId": "00972cda-8c79-4128-c1a7-52bb0729afc1"
},
"outputs": [
{
"data": {
"text/plain": [
"(array(0.68235186), array(0.05382102), array(0.22035987))"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rng = pytensor.shared(np.random.default_rng(123))\n",
"next_rng, x = pt.random.uniform(rng=rng).owner.outputs\n",
"\n",
"f = pytensor.function([], x, updates={rng: next_rng})\n",
"\n",
"f(), f(), f()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Another way of doing that is setting a default_update in the shared RNG variable"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Sm7-qpULufpO",
"outputId": "5fe41106-0345-409b-a842-6953710e7d26"
},
"outputs": [
{
"data": {
"text/plain": [
"(array(0.68235186), array(0.05382102), array(0.22035987))"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rng = pytensor.shared(np.random.default_rng(123))\n",
"next_rng, x = pt.random.uniform(rng=rng).owner.outputs\n",
"\n",
"rng.default_update = next_rng\n",
"f = pytensor.function([], x)\n",
"\n",
"f(), f(), f()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Which is exactly what RandomStream does behind the scenes"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "iuYLU83eusre",
"outputId": "ae8faa4d-e9f7-4d08-9113-3d841d7ac522"
},
"outputs": [
{
"data": {
"text/plain": [
"(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F6BEB5E79E0>),\n",
" uniform_rv{0, (0, 0), floatX, False}.0)"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"srng = pt.random.RandomStream(seed=123)\n",
"x = srng.uniform()\n",
"x.owner.inputs[0], x.owner.inputs[0].default_update"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "9mtwrC_zu6t0",
"outputId": "b2e53f4f-7a24-4601-82a3-1333b8ad97aa"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.19365083425294516 0.7541389670292019 0.2762903411491048\n"
]
}
],
"source": [
"f = pytensor.function([], x)\n",
"print(f(), f(), f())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Shared RNGs are created by default "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bQvIVgy41e7L"
},
"source": [
"If no rng is provided to a RandomVariable Op, a shared RandomGenerator is created automatically.\n",
"\n",
"This can give the appearance that PyTensor functions of random variables don't have any variable inputs, \n",
"but this is not true.\n",
"They are simply shared variables."
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "eYZv7lduygxL",
"outputId": "d693d35e-06aa-403a-cf4d-669d7e2ce486"
},
"outputs": [
{
"data": {
"text/plain": [
"RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F6BEB5C5C80>)"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = pt.random.normal()\n",
"x.owner.inputs[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Reseeding"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Shared RNG variables can be \"reseeded\" by setting them to the original RNG"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "JSt1xT-2xbUX",
"outputId": "58b28a7e-556b-442c-cba4-3198b5aff98e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-0.9891213503478509 -0.3677866514678832\n",
"-0.9891213503478509 -0.3677866514678832\n"
]
}
],
"source": [
"rng = pytensor.shared(np.random.default_rng(123))\n",
"next_rng, x = pt.random.normal(rng=rng).owner.outputs\n",
"\n",
"rng.default_update = next_rng\n",
"f = pytensor.function([], x)\n",
"\n",
"print(f(), f())\n",
"rng.set_value(np.random.default_rng(123))\n",
"print(f(), f())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"RandomStreams provide a helper method to achieve the same"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "TLP1QjsaxqcT",
"outputId": "41a57b97-1a7a-4bfa-8897-73db51a893ef"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-0.5812234917408711 -0.047499225218726786\n",
"-0.5812234917408711 -0.047499225218726786\n"
]
}
],
"source": [
"srng = pt.random.RandomStream(seed=123)\n",
"x = srng.normal()\n",
"f = pytensor.function([], x)\n",
"\n",
"print(f(), f())\n",
"srng.seed(123)\n",
"print(f(), f())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Inplace optimization"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As mentioned before, \n",
"by default RandomVariables return a copy of the next RNG state, \n",
"which can be quite slow."
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "JpWnykYj2ymu",
"outputId": "a31d91e1-b23e-4c73-92bb-f3241ad49f17"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"uniform_rv{0, (0, 0), floatX, False}.1 [id A] 'x' 0\n",
" |rng [id B]\n",
" |TensorConstant{[]} [id C]\n",
" |TensorConstant{11} [id D]\n",
" |TensorConstant{0.0} [id E]\n",
" |TensorConstant{1.0} [id F]\n"
]
},
{
"data": {
"text/plain": [
"<ipykernel.iostream.OutStream at 0x7f6c308c3eb0>"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rng = np.random.default_rng(123)\n",
"rng_shared = pytensor.shared(rng, name=\"rng\")\n",
"x = pt.random.uniform(rng=rng_shared, name=\"x\")\n",
"f = pytensor.function([], x)\n",
"pytensor.dprint(f, print_destroy_map=True)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"169 µs ± 24.6 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
]
}
],
"source": [
"%timeit f()"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.56 µs ± 106 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n"
]
}
],
"source": [
"%timeit rng.uniform()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Like other PyTensor operators, \n",
"RandomVariable's can be given permission to modify inputs inplace during their operation.\n",
"\n",
"In this case, there is a `inplace` flag that when `True` \n",
"tells the RandomVariable Op that it is safe to modify the RNG input inplace.\n",
"If the flag is set, the RNG will not be copied before taking random draws."
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x.owner.op.inplace"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This flag is printed as the last argument of the Op in the `dprint`"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"uniform_rv{0, (0, 0), floatX, False}.1 [id A] 'x'\n",
" |rng [id B]\n",
" |TensorConstant{[]} [id C]\n",
" |TensorConstant{11} [id D]\n",
" |TensorConstant{0.0} [id E]\n",
" |TensorConstant{1.0} [id F]\n"
]
},
{
"data": {
"text/plain": [
"<ipykernel.iostream.OutStream at 0x7f6c308c3eb0>"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pytensor.dprint(x)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For illustration purposes,\n",
"we will subclass the Uniform Op class and set inplace to True by default.\n",
"\n",
"Users should never do this directly!"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"class InplaceUniform(type(pt.random.uniform)):\n",
" inplace = True\n",
"\n",
"inplace_uniform = InplaceUniform()"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"x = inplace_uniform()"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x.owner.op.inplace"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"uniform_rv{0, (0, 0), floatX, True}.1 [id A] d={0: [0]} 0\n",
" |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F6BEB52BC80>) [id B]\n",
" |TensorConstant{[]} [id C]\n",
" |TensorConstant{11} [id D]\n",
" |TensorConstant{0.0} [id E]\n",
" |TensorConstant{1.0} [id F]\n"
]
},
{
"data": {
"text/plain": [
"<ipykernel.iostream.OutStream at 0x7f6c308c3eb0>"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inplace_f = pytensor.function([], x, accept_inplace=True)\n",
"pytensor.dprint(inplace_f, print_destroy_map=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The destroy map annotation tells us that \n",
"the first output of the x variable is allowed to alter the first input."
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"35.5 µs ± 1.87 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
]
}
],
"source": [
"%timeit inplace_f()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Performance is now much closer to calling numpy directly, \n",
"with only a small overhead introduced by the PyTensor function."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The [random_make_inplace](https://github.com/pymc-devs/pytensor/blob/3fcf6369d013c597a9c964b2400a3c5e20aa8dce/pytensor/tensor/random/rewriting/basic.py#L42-L52) \n",
"rewrite automatically replaces RandomVariable Ops by their inplace counterparts, \n",
"when such operation is deemed safe. \n",
"This happens when:\n",
"\n",
"1. An input RNG is flagged as `mutable` and is used in not used anywhere else.\n",
"2. A RNG is created intermediately and used in not used anywhere else.\n",
"\n",
"The first case is true when a users uses the mutable kwarg directly, \n",
"or much more commonly, when a shared RNG is used and a (default or manual) update expression is given.\n",
"In this case, a RandomVariable is allowed to modify the RNG\n",
"because the shared variable holding it will be rewritten anyway.\n",
"\n",
"The second case is not very common, \n",
"because RNGs are not usually chained across multiple RandomVariable Ops.\n",
"See more details in the next section."
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"from pytensor.compile.io import In"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"rewriting: rewrite random_make_inplace replaces uniform_rv{0, (0, 0), floatX, False}.out of uniform_rv{0, (0, 0), floatX, False}(rng, TensorConstant{[]}, TensorConstant{11}, TensorConstant{0.0}, TensorConstant{1.0}) with uniform_rv{0, (0, 0), floatX, True}.out of uniform_rv{0, (0, 0), floatX, True}(rng, TensorConstant{[]}, TensorConstant{11}, TensorConstant{0.0}, TensorConstant{1.0})\n",
"\n",
"uniform_rv{0, (0, 0), floatX, True}.1 [id A] d={0: [0]} 0\n",
" |rng [id B]\n",
" |TensorConstant{[]} [id C]\n",
" |TensorConstant{11} [id D]\n",
" |TensorConstant{0.0} [id E]\n",
" |TensorConstant{1.0} [id F]\n"
]
},
{
"data": {
"text/plain": [
"<ipykernel.iostream.OutStream at 0x7f6c308c3eb0>"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rng = pt.random.type.RandomGeneratorType()(\"rng\")\n",
"next_rng, x = pt.random.uniform(rng=rng).owner.outputs\n",
"with pytensor.config.change_flags(optimizer_verbose=True):\n",
" inplace_f = pytensor.function([In(rng, mutable=True)], [x])\n",
"print(\"\")\n",
"pytensor.dprint(inplace_f, print_destroy_map=True)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"uniform_rv{0, (0, 0), floatX, True}.1 [id A] d={0: [0]} 0\n",
" |rng [id B]\n",
" |TensorConstant{[]} [id C]\n",
" |TensorConstant{11} [id D]\n",
" |TensorConstant{0.0} [id E]\n",
" |TensorConstant{1.0} [id F]\n",
"uniform_rv{0, (0, 0), floatX, True}.0 [id A] d={0: [0]} 0\n"
]
},
{
"data": {
"text/plain": [
"<ipykernel.iostream.OutStream at 0x7f6c308c3eb0>"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rng = pytensor.shared(np.random.default_rng(), name=\"rng\")\n",
"next_rng, x = pt.random.uniform(rng=rng).owner.outputs\n",
"\n",
"inplace_f = pytensor.function([], [x], updates={rng: next_rng})\n",
"pytensor.dprint(inplace_f, print_destroy_map=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gI2yZWdOvMO3"
},
"source": [
"## Multiple random variables"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It's common practice to use separate RNG variables for each RandomVariable in PyTensor."
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "bSyXYZ5eu95F",
"outputId": "f33fc5f5-fe93-4516-8a5b-f8e20319a238"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"normal_rv{0, (0, 0), floatX, True}.1 [id A] <TensorType(float64, ())> 0\n",
" |rng_x [id B] <RandomGeneratorType>\n",
" |TensorConstant{[]} [id C] <TensorType(int64, (0,))>\n",
" |TensorConstant{11} [id D] <TensorType(int64, ())>\n",
" |TensorConstant{0} [id E] <TensorType(int8, ())>\n",
" |TensorConstant{10} [id F] <TensorType(int8, ())>\n",
"normal_rv{0, (0, 0), floatX, True}.1 [id G] <TensorType(float64, ())> 1\n",
" |rng_y [id H] <RandomGeneratorType>\n",
" |TensorConstant{[]} [id C] <TensorType(int64, (0,))>\n",
" |TensorConstant{11} [id I] <TensorType(int64, ())>\n",
" |normal_rv{0, (0, 0), floatX, True}.1 [id A] <TensorType(float64, ())> 0\n",
" |TensorConstant{0.1} [id J] <TensorType(float64, ())>\n",
"normal_rv{0, (0, 0), floatX, True}.0 [id A] <RandomGeneratorType> 'next_rng_x' 0\n",
"normal_rv{0, (0, 0), floatX, True}.0 [id G] <RandomGeneratorType> 'next_rng_y' 1\n"
]
},
{
"data": {
"text/plain": [
"<ipykernel.iostream.OutStream at 0x7f6c308c3eb0>"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rng_x = pytensor.shared(np.random.default_rng(123), name=\"rng_x\")\n",
"rng_y = pytensor.shared(np.random.default_rng(456), name=\"rng_y\")\n",
"\n",
"next_rng_x, x = pt.random.normal(loc=0, scale=10, rng=rng_x).owner.outputs\n",
"next_rng_y, y = pt.random.normal(loc=x, scale=0.1, rng=rng_y).owner.outputs\n",
"\n",
"next_rng_x.name = \"next_rng_x\"\n",
"next_rng_y.name = \"next_rng_y\"\n",
"rng_x.default_update = next_rng_x\n",
"rng_y.default_update = next_rng_y\n",
"\n",
"f = pytensor.function([], [x, y])\n",
"pytensor.dprint(f, print_type=True)"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1j786YJE455a",
"outputId": "3fe315f4-2162-486e-c875-dc223b0363d1"
},
"outputs": [
{
"data": {
"text/plain": [
"([array(-9.8912135), array(-9.80160951)],\n",
" [array(-3.67786651), array(-3.89026137)],\n",
" [array(12.87925261), array(13.04327299)])"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f(), f(), f()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is what RandomStream does as well"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8kQrk8Nmvk-m",
"outputId": "62405561-a7ee-40d5-994c-ae03896b7445"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"normal_rv{0, (0, 0), floatX, True}.1 [id A] <TensorType(float64, ())> 0\n",
" |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F6BEB52B660>) [id B] <RandomGeneratorType>\n",
" |TensorConstant{[]} [id C] <TensorType(int64, (0,))>\n",
" |TensorConstant{11} [id D] <TensorType(int64, ())>\n",
" |TensorConstant{0} [id E] <TensorType(int8, ())>\n",
" |TensorConstant{10} [id F] <TensorType(int8, ())>\n",
"normal_rv{0, (0, 0), floatX, True}.1 [id G] <TensorType(float64, ())> 1\n",
" |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F6BEB5E7900>) [id H] <RandomGeneratorType>\n",
" |TensorConstant{[]} [id C] <TensorType(int64, (0,))>\n",
" |TensorConstant{11} [id I] <TensorType(int64, ())>\n",
" |normal_rv{0, (0, 0), floatX, True}.1 [id A] <TensorType(float64, ())> 0\n",
" |TensorConstant{0.1} [id J] <TensorType(float64, ())>\n",
"normal_rv{0, (0, 0), floatX, True}.0 [id A] <RandomGeneratorType> 0\n",
"normal_rv{0, (0, 0), floatX, True}.0 [id G] <RandomGeneratorType> 1\n"
]
},
{
"data": {
"text/plain": [
"<ipykernel.iostream.OutStream at 0x7f6c308c3eb0>"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"srng = pt.random.RandomStream(seed=123)\n",
"x = srng.normal(loc=0, scale=10)\n",
"y = srng.normal(loc=x, scale=0.1)\n",
"\n",
"f = pytensor.function([], [x, y])\n",
"pytensor.dprint(f, print_type=True)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "eiqd9jdj2Mct",
"outputId": "19b9af1a-d9bb-4bfa-8a8d-eef04c16d122"
},
"outputs": [
{
"data": {
"text/plain": [
"([array(-5.81223492), array(-5.85081162)],\n",
" [array(-0.47499225), array(-0.64636099)],\n",
" [array(-1.11452059), array(-1.09642036)])"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f(), f(), f()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We could have used a single rng. "
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "QWcjOB64vvbX",
"outputId": "46a6665d-1efa-4cd2-ab5d-3be5f6eb8e25"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"normal_rv{0, (0, 0), floatX, True}.1 [id A] <TensorType(float64, ())> 0\n",
" |rng [id B] <RandomGeneratorType>\n",
" |TensorConstant{[]} [id C] <TensorType(int64, (0,))>\n",
" |TensorConstant{11} [id D] <TensorType(int64, ())>\n",
" |TensorConstant{0} [id E] <TensorType(int8, ())>\n",
" |TensorConstant{1} [id F] <TensorType(int8, ())>\n",
"normal_rv{0, (0, 0), floatX, True}.1 [id G] <TensorType(float64, ())> 1\n",
" |normal_rv{0, (0, 0), floatX, True}.0 [id A] <RandomGeneratorType> 'next_rng_x' 0\n",
" |TensorConstant{[]} [id C] <TensorType(int64, (0,))>\n",
" |TensorConstant{11} [id H] <TensorType(int64, ())>\n",
" |TensorConstant{100} [id I] <TensorType(int8, ())>\n",
" |TensorConstant{1} [id F] <TensorType(int8, ())>\n",
"normal_rv{0, (0, 0), floatX, True}.0 [id G] <RandomGeneratorType> 'next_rng_y' 1\n"
]
},
{
"data": {
"text/plain": [
"<ipykernel.iostream.OutStream at 0x7f6c308c3eb0>"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rng_x = pytensor.shared(np.random.default_rng(seed=123), name=\"rng_x\")\n",
"next_rng_x, x = pt.random.normal(loc=0, scale=1, rng=rng).owner.outputs\n",
"next_rng_x.name = \"next_rng_x\"\n",
"next_rng_y, y = pt.random.normal(loc=100, scale=1, rng=next_rng_x).owner.outputs\n",
"next_rng_y.name = \"next_rng_y\"\n",
"\n",
"f = pytensor.function([], [x, y], updates={rng: next_rng_y})\n",
"pytensor.dprint(f, print_type=True)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "TouDabd74ETU",
"outputId": "4f5935ec-1578-48b2-abb9-2a5e835ef275"
},
"outputs": [
{
"data": {
"text/plain": [
"([array(0.91110389), array(101.4795275)],\n",
" [array(0.0908175), array(100.59639646)])"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f(), f()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It works, but that graph is slightly unorthodox in Pytensor. \n",
"\n",
"One practical reason is that it is more difficult to \n",
"define the correct update expression for the shared RNG variable.\n",
"\n",
"One techincal reason is that it makes rewrites more challenging \n",
"in cases where RandomVariables could otherwise be manipulated independently."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IKyoVdjw2RtM"
},
"source": [
"### Creating multiple RNG variables"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ug3u0OZSx2lX"
},
"source": [
"RandomStreams generate high quality seeds for multiple variables, following the NumPy best practices: https://numpy.org/doc/stable/reference/random/parallel.html#parallel-random-number-generation \n",
"\n",
"Users who create their own RNGs should follow the same practice!"
]
},
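{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example (a sketch following the linked NumPy guide; the variable names are ours), \n",
"independent Generators for multiple variables can be spawned from a single `SeedSequence`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Spawn independent child seeds from one root seed, as recommended by NumPy\n",
"seed_seq = np.random.SeedSequence(123)\n",
"child_seed_x, child_seed_y = seed_seq.spawn(2)\n",
"\n",
"# Each child seed gets its own Generator, wrapped in a shared variable\n",
"rng_x = pytensor.shared(np.random.default_rng(child_seed_x), name=\"rng_x\")\n",
"rng_y = pytensor.shared(np.random.default_rng(child_seed_y), name=\"rng_y\")"
]
},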
{
"cell_type": "markdown",
"metadata": {
"id": "YiGShj54wd-4"
},
"source": [
"## Random variables in inner graphs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Scan"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Scan works very similar to a function (that is called repeatedly inside an outer scope).\n",
"\n",
"This means that random variables will always return the same output unless updates are specified."
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "dBc30fhlbzgl",
"outputId": "413abf61-9c36-4b4c-cd1f-29c0a415159c"
},
"outputs": [
{
"data": {
"text/plain": [
"(array([-0.98912135, -0.98912135, -0.98912135, -0.98912135, -0.98912135]),\n",
" array([-0.98912135, -0.98912135, -0.98912135, -0.98912135, -0.98912135]))"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rng = pytensor.shared(np.random.default_rng(123), name=\"rng\")\n",
"\n",
"def constant_step(rng):\n",
" return pt.random.normal(rng=rng)\n",
"\n",
"draws, updates = pytensor.scan(\n",
" fn=constant_step,\n",
" outputs_info=[None],\n",
" non_sequences=[rng],\n",
" n_steps=5,\n",
" strict=True,\n",
")\n",
"\n",
"f = pytensor.function([], draws, updates=updates)\n",
"f(), f()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Scan accepts an update dictionary as an output \n",
"to tell how shared variables should be updated after every iteration."
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "zoGAQK8qcYKa",
"outputId": "c14d300e-10b2-4ff4-a715-8ffca2e70789"
},
"outputs": [],
"source": [
"rng = pytensor.shared(np.random.default_rng(123))\n",
"\n",
"def random_step(rng):\n",
" next_rng, x = pt.random.normal(rng=rng).owner.outputs\n",
" scan_update = {rng: next_rng}\n",
" return x, scan_update\n",
"\n",
"draws, updates = pytensor.scan(\n",
" fn=random_step,\n",
" outputs_info=[None],\n",
" non_sequences=[rng],\n",
" n_steps=5,\n",
" strict=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([-0.98912135, -0.36778665, 1.28792526, 0.19397442, 0.9202309 ]),\n",
" array([-0.98912135, -0.36778665, 1.28792526, 0.19397442, 0.9202309 ]))"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f = pytensor.function([], draws)\n",
"f(), f()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"However, we still have to tell the outer function to update the shared RNG across calls,\n",
"using the last state returned by the Scan"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([-0.98912135, -0.36778665, 1.28792526, 0.19397442, 0.9202309 ]),\n",
" array([ 0.57710379, -0.63646365, 0.54195222, -0.31659545, -0.32238912]))"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f = pytensor.function([], draws, updates=updates)\n",
"f(), f()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Default updates"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Like function, scan also respects shared variables default updates"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"def random_step():\n",
" rng = pytensor.shared(np.random.default_rng(123), name=\"rng\")\n",
" next_rng, x = pt.random.normal(rng=rng).owner.outputs\n",
" rng.default_update = next_rng\n",
" return x\n",
"\n",
"draws, updates = pytensor.scan(\n",
" fn=random_step,\n",
" outputs_info=[None],\n",
" non_sequences=[],\n",
" n_steps=5,\n",
" strict=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([-0.98912135, -0.36778665, 1.28792526, 0.19397442, 0.9202309 ]),\n",
" array([-0.98912135, -0.36778665, 1.28792526, 0.19397442, 0.9202309 ]))"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f = pytensor.function([], draws)\n",
"f(), f()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The outer function still needs to be told the final update rule"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([-0.98912135, -0.36778665, 1.28792526, 0.19397442, 0.9202309 ]),\n",
" array([ 0.57710379, -0.63646365, 0.54195222, -0.31659545, -0.32238912]))"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f = pytensor.function([], draws, updates=updates)\n",
"f(), f()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As expected, Scan only looks at default updates\n",
"for shared variables created inside the user provided function."
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([-0.98912135, -0.98912135, -0.98912135, -0.98912135, -0.98912135]),\n",
" array([-0.36778665, -0.36778665, -0.36778665, -0.36778665, -0.36778665]))"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rng = pytensor.shared(np.random.default_rng(123), name=\"rng\")\n",
"next_rng, x = pt.random.normal(rng=rng).owner.outputs\n",
"rng.default_update = next_rng\n",
" \n",
"def random_step(rng, x): \n",
" return x\n",
"\n",
"draws, updates = pytensor.scan(\n",
" fn=random_step,\n",
" outputs_info=[None],\n",
" non_sequences=[rng, x],\n",
" n_steps=5,\n",
" strict=True,\n",
")\n",
"\n",
"f = pytensor.function([], draws)\n",
"f(), f()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gE1cLlZc07pH"
},
"source": [
"#### Limitations\n",
"RNGs in Scan are only supported via shared variables in non-sequences at the moment"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "p_utP4uQ2dAS",
"outputId": "2c7087fd-2a39-4d2f-c8c5-29903ae7a236"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tensor type field must be a TensorType; found <class 'pytensor.tensor.random.type.RandomGeneratorType'>.\n"
]
}
],
"source": [
"rng = pt.random.type.RandomGeneratorType()(\"rng\")\n",
"\n",
"def random_step(rng):\n",
" next_rng, x = pt.random.normal(rng=rng).owner.outputs\n",
" return next_rng, x\n",
"\n",
"try:\n",
" (next_rngs, draws), updates = pytensor.scan(\n",
" fn=random_step,\n",
" outputs_info=[rng, None],\n",
" n_steps=5,\n",
" strict=True\n",
" )\n",
"except TypeError as err:\n",
" print(err)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the future, TensorTypes may be allowed as explicit recurring states,\n",
"rendering the use of updates optional or unnecessary"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZNBb1g6XeP0-"
},
"source": [
"### OpFromGraph"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In contrast to Scan, non-shared RNG variables can be used directly in OpFromGraph"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [],
"source": [
"from pytensor.compile.builders import OpFromGraph\n",
"\n",
"rng = pt.random.type.RandomGeneratorType()(\"rng\")\n",
"\n",
"def lognormal(rng):\n",
" next_rng, x = pt.random.normal(rng=rng).owner.outputs\n",
" return [next_rng, pt.exp(x)]\n",
"\n",
"lognormal_ofg = OpFromGraph([rng], lognormal(rng))"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [],
"source": [
"rng_x = pytensor.shared(np.random.default_rng(1), name=\"rng_x\")\n",
"rng_y = pytensor.shared(np.random.default_rng(2), name=\"rng_y\")\n",
"\n",
"next_rng_x, x = lognormal_ofg(rng_x)\n",
"next_rng_y, y = lognormal_ofg(rng_y) \n",
"\n",
"f = pytensor.function([], [x, y], updates={rng_x: next_rng_x, rng_y: next_rng_y})"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"([array(1.41281503), array(1.20810544)],\n",
" [array(2.27417681), array(0.59288879)],\n",
" [array(1.39157622), array(0.66162024)])"
]
},
"execution_count": 57,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f(), f(), f()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Also in contrast to Scan, \n",
"there is no special treatment of updates for shared variables used in the inner graphs of OpFromGraph.\n",
"\n",
"Any \"updates\" must be modeled as explicit outputs \n",
"and used in the outer graph directly \n",
"as in the example above.\n",
"\n",
"This is arguably more clean."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Other backends (and their limitations)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Numba and RandomState"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The legacy NumPy RandomState can also be used with random variables,\n",
"and are actually the only supported RNG variables for the NUMBA backend at the moment."
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Numba does not support NumPy `Generator`s\n"
]
}
],
"source": [
"rng = pytensor.shared(np.random.default_rng(123), name=\"randomstate_rng\")\n",
"x = pt.random.normal(rng=rng)\n",
"try:\n",
" numba_fn = pytensor.function([], x, mode=\"NUMBA\")\n",
"except TypeError as exc:\n",
" print(exc)"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"normal_rv{0, (0, 0), floatX, False}.1 [id A] <TensorType(float64, ())> 0\n",
" |randomstate_rng [id B] <RandomStateType>\n",
" |TensorConstant{[]} [id C] <TensorType(int64, (0,))>\n",
" |TensorConstant{11} [id D] <TensorType(int64, ())>\n",
" |TensorConstant{0.0} [id E] <TensorType(float32, ())>\n",
" |TensorConstant{1.0} [id F] <TensorType(float32, ())>\n"
]
},
{
"data": {
"text/plain": [
"<ipykernel.iostream.OutStream at 0x7f6c308c3eb0>"
]
},
"execution_count": 59,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rng = pytensor.shared(np.random.RandomState(123), name=\"randomstate_rng\")\n",
"x = pt.random.normal(rng=rng)\n",
"numba_fn = pytensor.function([], x, mode=\"NUMBA\")\n",
"pytensor.dprint(numba_fn, print_type=True)"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-1.0856306033005612 0.9973454465835858\n"
]
}
],
"source": [
"print(numba_fn(), numba_fn())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The shared RNG variables are not actually used in the compiled Numba function,\n",
"other than for their initial value. \n",
"As such they cannot be reseeded after creation!"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.28297849805199204 -1.506294713918092\n"
]
}
],
"source": [
"rng.set_value(np.random.RandomState(123))\n",
"print(numba_fn(), numba_fn())"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-0.5786002519685364 1.651436537097151\n"
]
}
],
"source": [
"rng.set_value(np.random.RandomState(123))\n",
"print(numba_fn(), numba_fn())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Actually updates rules are not respected at all, \n",
"as can be seen by the function above giving different outputs, \n",
"even though updates were never specified."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Support of Numpy Random Generators by Numba may fix these limitations in the future.\n",
"This may allow us to stop supporting the legacy RandomStates."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### JAX"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"JAX uses a different type of PRNG than those of Numpy. \n",
"This means that the standard shared RNGs \n",
"cannot be used directly in graphs transpiled to JAX.\n",
"\n",
"Instead a copy of the Shared RNG variable is made, \n",
"and its bit generator state is given a `jax_state` entry\n",
"that is actually used by the JAX random variables.\n",
"\n",
"In general, update rules are still respected, \n",
"but they won't be used on the original shared variable,\n",
"only the copied one actually used in the transpiled function"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [],
"source": [
"import jax"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/evelin/Documents/Ricardo/Projects/pytensor/pytensor/link/jax/linker.py:28: UserWarning: The RandomType SharedVariables [rng] will not be used in the compiled JAX graph. Instead a copy will be used.\n",
" warnings.warn(\n",
"No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
]
}
],
"source": [
"rng = pytensor.shared(np.random.default_rng(123), name=\"rng\")\n",
"next_rng, x = pt.random.uniform(rng=rng).owner.outputs\n",
"jax_fn = pytensor.function([], [x], updates={rng: next_rng}, mode=\"JAX\")"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Array(0.07577298, dtype=float64)] [Array(0.09217023, dtype=float64)]\n"
]
}
],
"source": [
"print(jax_fn(), jax_fn())"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Array(0.13929162, dtype=float64)] [Array(0.45162648, dtype=float64)]\n"
]
}
],
"source": [
"# No effect on the jax evaluation\n",
"rng.set_value(np.random.default_rng(123))\n",
"print(jax_fn(), jax_fn())"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'bit_generator': Array(1, dtype=int64, weak_type=True),\n",
" 'has_uint32': Array(0, dtype=int64, weak_type=True),\n",
" 'jax_state': Array([2647707238, 2709433097], dtype=uint32),\n",
" 'state': {'inc': Array(-9061352147377205305, dtype=int64),\n",
" 'state': Array(-6044258077699604239, dtype=int64)},\n",
" 'uinteger': Array(0, dtype=int64, weak_type=True)}"
]
},
"execution_count": 67,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[jax_rng] = jax_fn.input_storage[0].storage\n",
"jax_rng"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Array(0.57655083, dtype=float64)] [Array(0.50347362, dtype=float64)]\n"
]
}
],
"source": [
"[jax_rng] = jax_fn.input_storage[0].storage\n",
"jax_rng[\"jax_state\"] = jax.random.PRNGKey(0)\n",
"print(jax_fn(), jax_fn())"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Array(0.57655083, dtype=float64)] [Array(0.50347362, dtype=float64)]\n"
]
}
],
"source": [
"[jax_rng] = jax_fn.input_storage[0].storage\n",
"jax_rng[\"jax_state\"] = jax.random.PRNGKey(0)\n",
"print(jax_fn(), jax_fn())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"PyTensor could provide shared JAX-like RNGs and allow RandomVariables to accept them,\n",
"but that would break the spirit of one graph -> multiple backends.\n",
"\n",
"Alternatively, PyTensor could try to use a more general type for RNGs that\n",
"can be used across different backends,\n",
"either directly or after some conversion operation \n",
"(if such operations can be implemented in the different backends)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"colab": {
"authorship_tag": "ABX9TyORRxkCASQeq139SWNaBZxI",
"include_colab_link": true,
"provenance": []
},
"hide_input": false,
"kernelspec": {
"display_name": "pytensor",
"language": "python",
"name": "pytensor"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": true
},
"varInspector": {
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"delete_cmd_postfix": "",
"delete_cmd_prefix": "del ",
"library": "var_list.py",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"delete_cmd_postfix": ") ",
"delete_cmd_prefix": "rm(",
"library": "var_list.r",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
],
"window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 1
}