Skip to content

Instantly share code, notes, and snippets.

@kc611
Last active March 2, 2022 09:39
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kc611/821b41413d0f6cc0ad34d842a25aa7ea to your computer and use it in GitHub Desktop.
Save kc611/821b41413d0f6cc0ad34d842a25aa7ea to your computer and use it in GitHub Desktop.
New implementation of `RandomState`'s lowering with proper state management
from textwrap import dedent, indent
from typing import Any, Callable, Dict, Optional
import numba.np.unsafe.ndarray as numba_ndarray
import numpy as np
from numba import _helperlib, types
from numba.core import cgutils
from numba.extending import NativeValue, box, models, register_model, typeof_impl, unbox
from numba.extending import (
NativeValue,
box,
make_attribute_wrapper,
models,
register_model,
typeof_impl,
overload_method,
unbox,
)
import aesara.tensor.random.basic as aer
from aesara.graph.basic import Apply
from aesara.graph.op import Op
from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.dispatch.basic import numba_funcify, numba_typify
from aesara.link.utils import (
compile_function_src,
get_name_for_object,
unique_name_generator,
)
from aesara.tensor.basic import get_vector_length
from aesara.tensor.random.type import RandomStateType
from aesara.tensor.random.var import RandomStateSharedVariable
from numba import _helperlib
from numpy.random import RandomState
import numba
from numba.cpython.randomimpl import get_state_ptr
from numba.core.imputils import (Registry, impl_ret_untracked,
impl_ret_new_ref)
from numba.core.extending import intrinsic, overload_method
from numba.core import types, utils, cgutils
from llvmlite import ir
from numba.np.arrayobj import make_array
class RandomStateNumbaType(types.Type):
    """Numba type used to represent `numpy.random.RandomState` objects."""

    def __init__(self):
        # The type takes no parameters, so a fixed name identifies it.
        super().__init__(name="RandomStateNumba")
# Module-level instance of the type; this is what the `typeof` hook for
# `RandomState` hands back.
random_state_numba_type = RandomStateNumbaType()
@typeof_impl.register(RandomState)
def typeof_index(val, c):
    """Type-inference hook: every `RandomState` value maps to the same Numba type.

    Both arguments are ignored; the module-level `random_state_numba_type`
    instance is returned unconditionally.
    """
    return random_state_numba_type
@register_model(RandomStateNumbaType)
class RandomStateNumbaModel(models.StructModel):
    """Native struct layout backing `RandomStateNumbaType`.

    Field order matters: the accessor helpers in this file address the
    struct members by position (``state_key`` is field 0, ``pos`` is field 1).
    """

    def __init__(self, dmm, fe_type):
        key_array_type = types.Array(types.int32, 1, 'C')
        fields = [
            ("state_key", key_array_type),
            ("pos", types.int32),
            ("has_gauss", types.int32),
            ("gauss", types.float64),
        ]
        super().__init__(dmm, fe_type, fields)
@unbox(RandomStateNumbaType)
def unbox_random_state(typ, obj, c):
    """Convert a `RandomState` object to a native `RandomStateNumbaModel` structure.

    The generator state is obtained by calling ``obj.__getstate__()``.  The
    nested ``"state"`` dict supplies ``"key"`` and ``"pos"``; the top-level
    dict supplies ``"has_gauss"`` and ``"gauss"``.

    Parameters
    ----------
    typ
        The Numba type being unboxed; used to build the struct proxy.
    obj
        The Python-level `RandomState` object handle.
    c
        The unboxing context (provides ``pyapi``, ``context`` and ``builder``).

    Returns
    -------
    NativeValue
        The populated native struct, with ``is_error`` set when a Python
        exception is pending after the conversion.
    """
    # `call_method` yields a NEW reference: this function owns it and must
    # release it before returning.
    state_dict = c.pyapi.call_method(obj, "__getstate__")

    # `dict_getitem_string` wraps `PyDict_GetItemString`, which returns
    # BORROWED references.  They are kept alive by `state_dict` and must not
    # be decref'd here (doing so steals references owned by the dicts and,
    # for small cached ints such as `pos`/`has_gauss`, by the interpreter).
    mt_state = c.pyapi.dict_getitem_string(state_dict, "state")
    state_pos = c.pyapi.dict_getitem_string(mt_state, "pos")
    state_key = c.pyapi.dict_getitem_string(mt_state, "key")
    has_gauss = c.pyapi.dict_getitem_string(state_dict, "has_gauss")
    gauss = c.pyapi.dict_getitem_string(state_dict, "gauss")

    rng_struct = cgutils.create_struct_proxy(typ)(c.context, c.builder)
    key_array_type = types.Array(types.int32, 1, 'C')
    rng_struct.state_key = numba.core.boxing.unbox_array(
        key_array_type, state_key, c
    ).value
    rng_struct.pos = numba.core.boxing.unbox_integer(
        types.int32, state_pos, c
    ).value
    rng_struct.has_gauss = numba.core.boxing.unbox_integer(
        types.int32, has_gauss, c
    ).value
    rng_struct.gauss = numba.core.boxing.unbox_float(
        types.float64, gauss, c
    ).value

    # Sample the error indicator before any reference is released.
    is_error = cgutils.is_not_null(c.builder, c.pyapi.err_occurred())

    # Release the only reference this function owns.
    # NOTE(review): this relies on the array unboxer having taken its own
    # reference to the key ndarray (the NRT-enabled path) -- confirm if this
    # is ever compiled with NRT disabled.
    c.pyapi.decref(state_dict)

    return NativeValue(rng_struct._getvalue(), is_error=is_error)
# Expose each struct member as a read-only attribute of the same name on
# `RandomStateNumbaType` values inside jitted code.
make_attribute_wrapper(RandomStateNumbaType, "state_key", "state_key")
make_attribute_wrapper(RandomStateNumbaType, "has_gauss", "has_gauss")
make_attribute_wrapper(RandomStateNumbaType, "gauss", "gauss")
make_attribute_wrapper(RandomStateNumbaType, "pos", "pos")
# LLVM scalar types used by the code generators in this file.
int32_t = ir.IntType(32)
double = ir.DoubleType()


def const_int(x):
    """Return *x* as a 32-bit LLVM integer constant."""
    return ir.Constant(int32_t, x)


# Length of the state-key array; the position index is reset once it
# reaches this value.
N = 624
N_const = const_int(N)
# Accessors
def get_index_ptr(builder, rng_object):
    """Return a pointer to struct field 1 (``pos``) of the state at *rng_object*.

    *rng_object* must be a pointer to the state struct.
    """
    pos_field = 1
    return cgutils.gep_inbounds(builder, rng_object, 0, pos_field)
def get_array_ptr(builder, rng_object):
    """Return a pointer to struct field 0 (``state_key``) of the state at *rng_object*.

    *rng_object* must be a pointer to the state struct.
    """
    key_field = 0
    return cgutils.gep_inbounds(builder, rng_object, 0, key_field)
def get_next_int32(context, builder, signature, args):
    """
    Emit IR producing the next tempered 32-bit word from the PRNG state.

    ``args[0]`` is the state struct *value*.  The word at the current
    position of the key array is loaded, the position is advanced by one,
    and the word is tempered before being returned.

    Known limitations (unchanged from the original implementation):

    * The struct value is spilled to a freshly allocated stack slot, and all
      reads and writes go through that slot.  The advanced position is thus
      stored only in the local copy; the caller's state is not updated in
      place, so a later call starting from the same ``args[0]`` sees the
      same position again.
    * When the position has reached ``N`` it is merely reset to 0; the key
      array is not regenerated.
    """
    state_value = args[0]
    # FIXME: this should be a pointer to the caller's state rather than a
    # pointer to a new copy, so that updates persist between calls.
    state_slot = cgutils.alloca_once_value(builder, value=state_value)

    pos_ptr = get_index_ptr(builder, state_slot)
    exhausted = builder.icmp_unsigned('>=', builder.load(pos_ptr), N_const)
    # TODO: call a C helper that reshuffles the key array here instead of
    # only wrapping the index around.
    with cgutils.if_unlikely(builder, exhausted):
        builder.store(const_int(0), pos_ptr)
    pos = builder.load(pos_ptr)

    # Field 4 of the array struct is read as the data pointer, which is then
    # viewed as a fixed-size array of N 32-bit words.
    key_ptr = get_array_ptr(builder, state_slot)
    data_field_ptr = cgutils.gep_inbounds(builder, key_ptr, 0, 4)
    raw_data = builder.load(data_field_ptr)
    words_type = ir.ArrayType(int32_t, N)
    words_ptr = builder.bitcast(raw_data, words_type.as_pointer())
    word_ptr = cgutils.gep_inbounds(builder, words_ptr, 0, pos)
    y = builder.load(word_ptr)

    # Advance the position (in the local copy only; see the docstring).
    builder.store(builder.add(pos, const_int(1)), pos_ptr)

    # Tempering
    y = builder.xor(y, builder.lshr(y, const_int(11)))
    for shift, mask in ((7, 0x9d2c5680), (15, 0xefc60000)):
        shifted = builder.shl(y, const_int(shift))
        y = builder.xor(y, builder.and_(shifted, const_int(mask)))
    y = builder.xor(y, builder.lshr(y, const_int(18)))
    return y
def get_next_double(context, builder, signature, args):
    """
    Emit IR producing the next double from the PRNG state in ``args[0]``.

    Two 32-bit draws are combined as
    ``(a * 67108864.0 + b) / 9007199254740992.0`` with ``a = draw >> 5`` and
    ``b = draw >> 6``.
    """
    # a = rk_random(state) >> 5, b = rk_random(state) >> 6;
    first_word = get_next_int32(context, builder, signature, args)
    hi = builder.lshr(first_word, const_int(5))
    second_word = get_next_int32(context, builder, signature, args)
    lo = builder.lshr(second_word, const_int(6))

    hi_fp = builder.uitofp(hi, double)
    lo_fp = builder.uitofp(lo, double)

    # return (a * 67108864.0 + b) / 9007199254740992.0;
    scale = ir.Constant(double, 67108864.0)
    denominator = ir.Constant(double, 9007199254740992.0)
    scaled_hi = builder.fmul(hi_fp, scale)
    numerator = builder.fadd(lo_fp, scaled_hi)
    return builder.fdiv(numerator, denominator)
@intrinsic
def intr_random(typcontext, rng_type):
    """Intrinsic returning a ``float64`` drawn from the state argument.

    The signature is ``float64(rng_type)``; code generation is delegated to
    `get_next_double`.
    """

    def _lower(context, builder, signature, args):
        value = get_next_double(context, builder, signature, args)
        return impl_ret_untracked(
            context, builder, signature.return_type, value
        )

    return types.float64(rng_type), _lower
@overload_method(RandomStateNumbaType, "random")
def numba_random_dist(rng):
    """Provide ``rng.random()`` for `RandomStateNumbaType` values in jitted code.

    The returned implementation forwards to the `intr_random` intrinsic.
    """

    def _draw(rng):
        return intr_random(rng)

    return _draw
@numba.njit
def return_val(x):
    # Jitted smoke test: draw one value via ``x.random()`` and return it.
    #
    # Printing ``x.state_key``, ``x.gauss``, ``x.has_gauss`` and ``x.pos``
    # here gave the proper output, so unboxing works.
    return x.random()
# Demo: pass a `RandomState` seeded with 3 through the jitted function and
# print the value it returns.
a = RandomState(3)
res = return_val(a)
print(res)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment