-
-
Save kc611/821b41413d0f6cc0ad34d842a25aa7ea to your computer and use it in GitHub Desktop.
New implementation of `RandomState`'s lowering with proper state management
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from textwrap import dedent, indent | |
from typing import Any, Callable, Dict, Optional | |
import numba.np.unsafe.ndarray as numba_ndarray | |
import numpy as np | |
from numba import _helperlib, types | |
from numba.core import cgutils | |
from numba.extending import NativeValue, box, models, register_model, typeof_impl, unbox | |
from numba.extending import ( | |
NativeValue, | |
box, | |
make_attribute_wrapper, | |
models, | |
register_model, | |
typeof_impl, | |
overload_method, | |
unbox, | |
) | |
import aesara.tensor.random.basic as aer | |
from aesara.graph.basic import Apply | |
from aesara.graph.op import Op | |
from aesara.link.numba.dispatch import basic as numba_basic | |
from aesara.link.numba.dispatch.basic import numba_funcify, numba_typify | |
from aesara.link.utils import ( | |
compile_function_src, | |
get_name_for_object, | |
unique_name_generator, | |
) | |
from aesara.tensor.basic import get_vector_length | |
from aesara.tensor.random.type import RandomStateType | |
from aesara.tensor.random.var import RandomStateSharedVariable | |
from numba import _helperlib | |
from numpy.random import RandomState | |
import numba | |
from numba.cpython.randomimpl import get_state_ptr | |
from numba.core.imputils import (Registry, impl_ret_untracked, | |
impl_ret_new_ref) | |
from numba.core.extending import intrinsic, overload_method | |
from numba.core import types, utils, cgutils | |
from llvmlite import ir | |
from numba.np.arrayobj import make_array | |
class RandomStateNumbaType(types.Type):
    """Numba type used to represent ``numpy.random.RandomState`` objects."""

    def __init__(self):
        super().__init__(name="RandomStateNumba")


# The single instance handed out by the ``typeof`` hook for every RandomState.
random_state_numba_type = RandomStateNumbaType()
@typeof_impl.register(RandomState)
def typeof_index(val, c):
    """Type any ``RandomState`` value as ``random_state_numba_type``.

    The value itself is not inspected; all instances share one Numba type.
    """
    numba_type = random_state_numba_type
    return numba_type
@register_model(RandomStateNumbaType)
class RandomStateNumbaModel(models.StructModel):
    """Native struct layout holding the fields unboxed from ``RandomState.__getstate__()``.

    Member order matters: the low-level accessors in this module address
    ``state_key`` as member 0 and ``pos`` as member 1.
    """

    def __init__(self, dmm, fe_type):
        key_array_type = types.Array(types.int32, 1, 'C')
        fields = [
            ("state_key", key_array_type),
            ("pos", types.int32),
            ("has_gauss", types.int32),
            ("gauss", types.float64),
        ]
        super().__init__(dmm, fe_type, fields)
@unbox(RandomStateNumbaType)
def unbox_random_state(typ, obj, c):
    """Convert a ``RandomState`` object to a native ``RandomStateNumbaModel`` structure.

    The state is read from ``obj.__getstate__()``. The nested ``"state"`` dict
    supplies ``"key"`` and ``"pos"``. The top-level dict supplies ``"has_gauss"``
    and ``"gauss"``.

    Parameters
    ----------
    typ : RandomStateNumbaType
        The Numba type being unboxed.
    obj
        LLVM value holding the Python ``RandomState`` object pointer.
    c
        Numba unboxing context (``c.builder``, ``c.context``, ``c.pyapi``).

    Returns
    -------
    NativeValue
        The populated struct. It is flagged as an error if a Python exception
        is pending after unboxing.
    """
    # ``call_method`` returns a NEW reference that this function owns.
    state_dict = c.pyapi.call_method(obj, "__getstate__")
    # NOTE(review): if ``__getstate__`` raises, ``state_dict`` is NULL and the
    # lookups below are not guarded against it -- TODO add a NULL check.

    # ``dict_getitem_string`` wraps ``PyDict_GetItemString``, which returns
    # BORROWED references. They must not be decref'd, and they are only valid
    # while ``state_dict`` is alive.
    mt_state = c.pyapi.dict_getitem_string(state_dict, "state")
    state_pos = c.pyapi.dict_getitem_string(mt_state, "pos")
    state_key = c.pyapi.dict_getitem_string(mt_state, "key")
    has_gauss = c.pyapi.dict_getitem_string(state_dict, "has_gauss")
    gauss = c.pyapi.dict_getitem_string(state_dict, "gauss")

    rng_struct = cgutils.create_struct_proxy(typ)(c.context, c.builder)
    rng_struct.state_key = numba.core.boxing.unbox_array(
        types.Array(types.int32, 1, 'C'), state_key, c
    ).value
    rng_struct.pos = numba.core.boxing.unbox_integer(types.int32, state_pos, c).value
    rng_struct.has_gauss = numba.core.boxing.unbox_integer(
        types.int32, has_gauss, c
    ).value
    rng_struct.gauss = numba.core.boxing.unbox_float(types.float64, gauss, c).value

    is_error = cgutils.is_not_null(c.builder, c.pyapi.err_occurred())

    # Release the only reference we own. The borrowed values above are not
    # touched.
    # NOTE(review): this assumes NRT is enabled, so the unboxed ``state_key``
    # array holds its own reference to the key ndarray after ``state_dict`` is
    # released -- confirm.
    c.pyapi.decref(state_dict)

    return NativeValue(rng_struct._getvalue(), is_error=is_error)
# Make each struct member reachable as an attribute (e.g. ``x.pos``) inside
# nopython-compiled code, registered in the same order as before.
for _field_name in ("state_key", "has_gauss", "gauss", "pos"):
    make_attribute_wrapper(RandomStateNumbaType, _field_name, _field_name)
del _field_name
# LLVM IR types shared by the code generators in this module.
int32_t = ir.IntType(32)
double = ir.DoubleType()

# Number of 32-bit words in the state vector (``state_key``).
N = 624


def const_int(x):
    """Return an LLVM ``i32`` constant holding *x*."""
    return int32_t(x)


N_const = const_int(N)
# Accessors: pointers into an in-memory ``RandomStateNumbaModel`` struct.
# Member indices follow that model's ``members`` order: 0 is ``state_key``,
# 1 is ``pos``.
def get_index_ptr(builder, rng_object):
    """Return a pointer to the ``pos`` member of the struct at *rng_object*."""
    pos_member = 1
    return cgutils.gep_inbounds(builder, rng_object, 0, pos_member)


def get_array_ptr(builder, rng_object):
    """Return a pointer to the ``state_key`` member of the struct at *rng_object*."""
    key_member = 0
    return cgutils.gep_inbounds(builder, rng_object, 0, key_member)
# MT19937 recurrence parameters, used by the state regeneration below.
_MT_M = 397
_MT_MATRIX_A = 0x9908B0DF
_MT_UPPER_MASK = 0x80000000
_MT_LOWER_MASK = 0x7FFFFFFF


def _get_key_data_ptr(builder, state_ptr):
    """Return an ``i32*`` to element 0 of the ``state_key`` array data of the struct at *state_ptr*."""
    array_ptr = get_array_ptr(builder, state_ptr)
    # Member 4 of Numba's array struct is its ``data`` pointer.
    data_ptr = builder.load(cgutils.gep_inbounds(builder, array_ptr, 0, 4))
    return builder.bitcast(data_ptr, int32_t.as_pointer())


def _twist_state(builder, key_ptr):
    """Regenerate all ``N`` words of the MT19937 state vector at *key_ptr*, in place.

    This is the modular form of the recurrence. Because words are rewritten in
    increasing index order, it matches the reference three-loop form: indices
    ``(i + M) % N < i`` and ``(N - 1 + 1) % N == 0`` read already-updated words.
    """
    one = const_int(1)
    with cgutils.for_range(builder, N_const) as loop:
        i = loop.index
        i_next = builder.urem(builder.add(i, one), N_const)
        i_m = builder.urem(builder.add(i, const_int(_MT_M)), N_const)

        cur_ptr = builder.gep(key_ptr, [i], inbounds=True)
        next_word = builder.load(builder.gep(key_ptr, [i_next], inbounds=True))
        y = builder.or_(
            builder.and_(builder.load(cur_ptr), const_int(_MT_UPPER_MASK)),
            builder.and_(next_word, const_int(_MT_LOWER_MASK)),
        )
        # -(y & 1) is all ones when the low bit is set, which selects MATRIX_A.
        mag = builder.and_(builder.neg(builder.and_(y, one)), const_int(_MT_MATRIX_A))
        far_word = builder.load(builder.gep(key_ptr, [i_m], inbounds=True))
        builder.store(builder.xor(builder.xor(far_word, builder.lshr(y, one)), mag), cur_ptr)


def _next_int32_from_ptr(builder, state_ptr):
    """Draw one tempered 32-bit word from the struct at *state_ptr* and advance its ``pos``.

    When ``pos >= N`` the state vector is regenerated first and ``pos`` is
    reset to 0.
    NOTE(review): regeneration writes through the ``state_key`` data pointer,
    so it mutates the unboxed key buffer. That buffer is presumed to be a
    private copy produced by ``__getstate__`` -- verify.
    """
    idxptr = get_index_ptr(builder, state_ptr)
    key_ptr = _get_key_data_ptr(builder, state_ptr)

    need_twist = builder.icmp_unsigned('>=', builder.load(idxptr), N_const)
    with cgutils.if_unlikely(builder, need_twist):
        _twist_state(builder, key_ptr)
        builder.store(const_int(0), idxptr)

    idx = builder.load(idxptr)
    y = builder.load(builder.gep(key_ptr, [idx], inbounds=True))
    builder.store(builder.add(idx, const_int(1)), idxptr)

    # Tempering
    y = builder.xor(y, builder.lshr(y, const_int(11)))
    y = builder.xor(y, builder.and_(builder.shl(y, const_int(7)),
                                    const_int(0x9d2c5680)))
    y = builder.xor(y, builder.and_(builder.shl(y, const_int(15)),
                                    const_int(0xefc60000)))
    y = builder.xor(y, builder.lshr(y, const_int(18)))
    return y


def get_next_int32(context, builder, signature, args):
    """
    Get the next int32 generated by the PRNG whose struct value is ``args[0]``.

    NOTE(review): ``args[0]`` is a struct *value*. It is copied into a stack
    slot here, so the advanced ``pos`` is discarded when this returns, and the
    next call starts from the same ``pos``. Writes to the key array data do
    persist because the data pointer is shared. Persisting ``pos`` would
    require passing the RandomState by reference (e.g. a pointer-backed data
    model); that is not solved here.
    """
    state_ptr = cgutils.alloca_once_value(builder, value=args[0])
    return _next_int32_from_ptr(builder, state_ptr)


def get_next_double(context, builder, signature, args):
    """
    Get the next double in [0, 1) generated by the PRNG whose struct value is ``args[0]``.

    Both 32-bit words are drawn from ONE stack copy of the state. The second
    draw therefore sees the ``pos`` advanced by the first, so the two words are
    consecutive rather than identical.
    """
    state_ptr = cgutils.alloca_once_value(builder, value=args[0])
    # a = rk_random(state) >> 5, b = rk_random(state) >> 6;
    a = builder.lshr(_next_int32_from_ptr(builder, state_ptr), const_int(5))
    b = builder.lshr(_next_int32_from_ptr(builder, state_ptr), const_int(6))
    # return (a * 67108864.0 + b) / 9007199254740992.0;
    a = builder.uitofp(a, double)
    b = builder.uitofp(b, double)
    return builder.fdiv(
        builder.fadd(b, builder.fmul(a, ir.Constant(double, 67108864.0))),
        ir.Constant(double, 9007199254740992.0))
@intrinsic
def intr_random(typcontext, rng_type):
    """Intrinsic returning one ``float64`` produced by ``get_next_double`` for *rng_type*."""

    def codegen(context, builder, signature, args):
        drawn = get_next_double(context, builder, signature, args)
        return impl_ret_untracked(context, builder, signature.return_type, drawn)

    return types.float64(rng_type), codegen
@overload_method(RandomStateNumbaType, "random")
def numba_random_dist(rng):
    """Implement ``RandomState.random()`` in nopython mode by delegating to ``intr_random``."""

    def impl(rng):
        return intr_random(rng)

    return impl
@numba.njit
def return_val(x):
    """Draw and return one value via ``x.random()`` inside nopython-compiled code."""
    drawn = x.random()
    return drawn
# Smoke test: unbox a seeded RandomState and draw one value through the JIT path.
a = RandomState(seed=3)
res = return_val(a)
print(res)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment