Skip to content

Instantly share code, notes, and snippets.

@bbbales2
Last active January 10, 2023 16:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bbbales2/be4e43dbd82161750757105be527358e to your computer and use it in GitHub Desktop.
Timeseries model
from functools import reduce
from sortedcontainers import SortedSet, SortedDict
from typing import List, Callable, TypeVar, Any, Generic, Optional, Tuple, Dict
from collections import namedtuple, defaultdict
import haiku
import jax
import jax.scipy
import numpy.random
# Generic type parameters shared by the classes below:
# X is an input element, Y a mapped/scanned output, G a group key.
X, Y, G = TypeVar("X"), TypeVar("Y"), TypeVar("G")
class Partitioner(Generic[X, G]):
    """Lay a flat batch out as ``num_partitions`` zero-padded rows, by group.

    Elements whose ``partition_by`` keys are equal always land in the same
    partition, and elements keep their original relative order inside a
    partition. ``partition`` scatters an array whose leading axis is aligned
    with ``xs`` into shape ``base_partitioned_shape + x.shape[1:]``;
    ``unpartition`` gathers it back into the original leading-axis order.
    """

    # (num_partitions, number of elements in the largest partition)
    base_partitioned_shape: Tuple[int, int]
    # Partition id of each element of ``xs``.
    _partitions: jax.numpy.ndarray
    # Slot of each element of ``xs`` inside its partition.
    _partition_indices: jax.numpy.ndarray
    # Number of elements assigned to each partition id.
    _partition_sizes: Dict[int, int]

    def __init__(
        self,
        xs: List[X],
        partition_by: Callable[[X], G],
        num_partitions: int = 1
    ):
        """Compute the partition layout for ``xs``.

        Args:
            xs: elements to lay out.
            partition_by: key function; elements with equal (hashable) keys
                share a partition.
            num_partitions: number of partitions (rows); must be >= 1.

        Raises:
            ValueError: if ``num_partitions`` is less than 1.
        """
        if num_partitions < 1:
            raise ValueError(f"num_partitions must be at least 1, got {num_partitions}")

        # Distinct groups are dealt out to partitions round-robin in
        # first-seen order. Unlike ``hash(group) % num_partitions`` this is
        # reproducible across interpreter runs (str hashes are randomised per
        # process), and no partition receives a second group while another
        # partition is still empty.
        group_partitions = {}
        self._partition_sizes = defaultdict(lambda: 0)
        partitions_builder = []
        partition_indices_builder = []
        for x in xs:
            group = partition_by(x)
            if group not in group_partitions:
                group_partitions[group] = len(group_partitions) % num_partitions
            partition = group_partitions[group]
            partitions_builder.append(partition)
            # The element takes the next free slot of its partition.
            partition_indices_builder.append(self._partition_sizes[partition])
            self._partition_sizes[partition] += 1

        # dtype=int keeps the index arrays integral even when xs is empty.
        self._partitions = jax.numpy.array(partitions_builder, dtype=int)
        self._partition_indices = jax.numpy.array(partition_indices_builder, dtype=int)
        max_partition_size = max(self._partition_sizes.values(), default=0)
        self.base_partitioned_shape = (num_partitions, max_partition_size)

    def partition(self, x: jax.numpy.ndarray):
        """Scatter ``x`` (leading axis aligned with ``xs``) into the padded layout.

        Unused slots are zero-filled. Returns an array of shape
        ``base_partitioned_shape + x.shape[1:]`` with the dtype of ``x``.
        """
        assert x.shape[0] == len(self._partitions)
        gathered_shape = self.base_partitioned_shape + x.shape[1:]
        gathered = jax.numpy.zeros(gathered_shape, dtype=x.dtype)
        return gathered.at[self._partitions, self._partition_indices].set(x)

    def unpartition(self, x: jax.numpy.ndarray):
        """Inverse of ``partition``: gather the padded layout back into ``xs`` order."""
        assert x.shape[:2] == self.base_partitioned_shape
        return x[self._partitions, self._partition_indices]
class Mapper(Generic[X, Y]):
    """Vectorised map over a fixed list of inputs, with per-input parameter lookup.

    At construction, every function in ``prepare_constants`` is evaluated on
    each input and the results are stacked into an array. Every function in
    ``prepare_subscripts`` is evaluated on each input to give a subscript
    value. The distinct values of each subscript function form a sorted
    domain (``subscript_domains``), and each input remembers the position of
    its value in that domain. ``__call__`` indexes parameter arrays with those
    positions and ``jax.vmap``s the user's function over the inputs.
    """

    # Pytree (structure of prepare_subscripts) of SortedSets of distinct subscript values.
    subscript_domains: Any
    # Pytree (structure of prepare_constants) of per-input constants, stacked on axis 0.
    _constants: Any
    # Pytree (structure of prepare_subscripts) of per-input positions into subscript_domains.
    _subscript_indices: Any

    def __init__(
        self,
        xs: List[X],
        prepare_constants: Optional[Any] = None,
        prepare_subscripts: Optional[Any] = None
    ):
        """Precompute constants and subscript positions for ``xs``.

        Args:
            xs: inputs to map over.
            prepare_constants: pytree of functions ``x -> array-like``, or
                ``None`` for no constants.
            prepare_subscripts: pytree of functions ``x -> value``; values must
                be hashable and mutually orderable (they go into a SortedSet).
                ``None`` for no subscripts.
        """
        # Compute constants and save as pytree of ndarray. Uses
        # jax.tree_util.tree_map rather than the deprecated jax.tree_map alias,
        # which newer JAX releases no longer provide.
        self._constants = jax.tree_util.tree_map(
            lambda f: jax.numpy.array([f(x) for x in xs]), prepare_constants
        )

        # JAX looks pytree node types up by exact type, so this unregistered
        # list subclass is kept as a single leaf instead of being flattened
        # element by element.
        class MyList(list):
            pass

        # Save subscript domains (pytree of SortedSets) and subscript indices (pytree of ndarray[int])
        subscript_values = jax.tree_util.tree_map(lambda f: MyList(f(x) for x in xs), prepare_subscripts)
        subscript_domains = jax.tree_util.tree_map(SortedSet, subscript_values)
        flat_subscript_domains, subscript_domain_treedef = jax.tree_util.tree_flatten(subscript_domains)
        flat_subscript_values, subscript_value_treedef = jax.tree_util.tree_flatten(subscript_values)
        flat_subscript_indices = [
            jax.numpy.array([domain.index(value) for value in values])
            for domain, values in zip(flat_subscript_domains, flat_subscript_values)
        ]
        assert subscript_domain_treedef == subscript_value_treedef
        subscript_indices = jax.tree_util.tree_unflatten(subscript_domain_treedef, flat_subscript_indices)
        self.subscript_domains = subscript_domains
        self._subscript_indices = subscript_indices

    def __call__(
        self,
        map_function: Callable[[Any, Any], Y],
        parameters: Optional[Any] = None
    ):
        """Apply ``map_function(constants, local_parameters)`` to every input.

        Args:
            map_function: called once per input (under ``jax.vmap``) with that
                input's constants and ``local_parameters``, a pytree mirroring
                ``parameters`` whose leaves are indexed at the input's subscripts.
            parameters: pytree of arrays with the same structure as
                ``prepare_subscripts``; each leaf is indexed by positions in the
                matching ``subscript_domains`` entry.

        Returns:
            The ``jax.vmap``-stacked outputs, one per input in ``xs`` order.
        """
        def mapper(args):
            constants, subscript_indices = args
            flat_parameters, parameters_treedef = jax.tree_util.tree_flatten(parameters)
            flat_subscript_indices, subscript_indices_treedef = jax.tree_util.tree_flatten(subscript_indices)
            assert parameters_treedef == subscript_indices_treedef
            local_parameters = jax.tree_util.tree_unflatten(
                parameters_treedef,
                (parameter[index] for parameter, index in zip(flat_parameters, flat_subscript_indices))
            )
            return map_function(constants, local_parameters)

        return jax.vmap(mapper)((self._constants, self._subscript_indices))
class Scanner(Generic[X, Y, G]):
    """Run an independent sequential recurrence (``jax.lax.scan``) per group.

    Inputs are grouped by ``group_by``. Groups are visited in sorted key order,
    and elements keep their original relative order within a group. Each
    group's scan starts from that group's entry of ``initial_values`` and
    then applies ``update`` element by element. Groups are spread over
    ``num_partitions`` partitions that are scanned in parallel with
    ``jax.vmap``. Outputs are returned in the order of the original ``xs``.
    """

    # Sorted distinct group keys; initial_values[i] belongs to groups[i].
    groups: SortedSet[G]
    _mapper: Mapper
    _partitioner: Partitioner
    # Per element (grouped order): +k on a group's first element (k = 1-based
    # group number), -k on the rest of that group.
    _initial_value_indices: jax.numpy.ndarray
    # Original position in xs of each element, in grouped order.
    _original_indices: jax.numpy.ndarray
    # Inverse permutation of _original_indices (grouped order -> xs order).
    _restore_indices: jax.numpy.ndarray

    def __init__(
        self,
        xs: List[X],
        group_by: Callable[[X], G],
        prepare_constants: Optional[Any] = None,
        prepare_subscripts: Optional[Any] = None,
        num_partitions: int = 1
    ):
        """Group ``xs`` and precompute the scan layout.

        Args:
            xs: inputs; one output is produced per input.
            group_by: key function; each distinct key is an independent series.
            prepare_constants: forwarded to ``Mapper`` (per-element constants).
            prepare_subscripts: forwarded to ``Mapper`` (per-element subscripts
                used to index ``parameters`` in ``__call__``).
            num_partitions: number of partitions scanned in parallel.
        """
        # Separate the inputs out into groups, remembering where each came from.
        grouped_xs = SortedDict()
        original_indices = SortedDict()
        for i, x in enumerate(xs):
            group = group_by(x)
            if group not in grouped_xs:
                grouped_xs[group] = []
                original_indices[group] = []
            grouped_xs[group].append(x)
            original_indices[group].append(i)

        self.groups = SortedSet(grouped_xs.keys())
        self._original_indices = jax.numpy.concatenate(
            [jax.numpy.array(indices) for indices in original_indices.values()]
        )
        # Grouped position k holds xs[_original_indices[k]], so mapping results
        # back to xs order needs the INVERSE permutation (argsort of a
        # permutation), not _original_indices itself.
        self._restore_indices = jax.numpy.argsort(self._original_indices)

        # Concatenate the groups in linear time (repeated list + is quadratic).
        ordered_xs = [x for group_xs in grouped_xs.values() for x in group_xs]
        self._mapper = Mapper(xs=ordered_xs, prepare_constants=prepare_constants, prepare_subscripts=prepare_subscripts)
        self._partitioner = Partitioner(xs=ordered_xs, partition_by=group_by, num_partitions=num_partitions)

        # Mark where each scan must reset. The value is positive on a group's
        # first element and is then a 1-based index into initial_values; it is
        # negative on every other element of the group. Indices start at 1
        # because there's no difference in a negative zero and a positive zero.
        initial_value_index_builder = []
        for group_number, group_xs in enumerate(grouped_xs.values(), start=1):
            initial_value_index_builder.append(group_number)
            initial_value_index_builder.extend((len(group_xs) - 1) * [-group_number])
        self._initial_value_indices = jax.numpy.array(initial_value_index_builder)

    @property
    def subscript_domains(self):
        """Sorted subscript domains (see ``Mapper.subscript_domains``)."""
        return self._mapper.subscript_domains

    def __call__(
        self,
        update: Callable[[Y, Any, Any], Y],
        initial_values: jax.numpy.ndarray,
        parameters: Optional[Any] = None
    ):
        """Run the per-group recurrence and return one state per input.

        Args:
            update: ``update(previous, constants, local_parameters)`` returning
                the next state; ``local_parameters`` mirrors ``parameters`` with
                each leaf indexed at the current element's subscript.
            initial_values: state of the first element of each group, indexed in
                the order of ``self.groups``.
            parameters: pytree of arrays with the same structure as
                ``prepare_subscripts``; ``None`` when there are no subscripts.

        Returns:
            Array with one state per element of the original ``xs``, in the
            original order.
        """
        def scanner(previous, args):
            initial_value_index, constants, subscript_indices = args
            flat_parameters, parameters_treedef = jax.tree_util.tree_flatten(parameters)
            flat_subscript_indices, subscript_indices_treedef = jax.tree_util.tree_flatten(subscript_indices)
            assert parameters_treedef == subscript_indices_treedef
            local_parameters = jax.tree_util.tree_unflatten(
                parameters_treedef,
                (parameter[index] for parameter, index in zip(flat_parameters, flat_subscript_indices))
            )
            # Both branches are evaluated; on a group's first element the update
            # result is discarded in favour of the group's initial value.
            current = jax.numpy.where(
                initial_value_index > 0,
                initial_values[initial_value_index - 1],
                update(previous, constants, local_parameters)
            )
            return current, current

        partitioned_initial_value_indices = jax.tree_util.tree_map(self._partitioner.partition, self._initial_value_indices)
        partitioned_constants = jax.tree_util.tree_map(self._partitioner.partition, self._mapper._constants)
        partitioned_subscript_indices = jax.tree_util.tree_map(self._partitioner.partition, self._mapper._subscript_indices)

        def mapper(initial_value_indices, constants, subscript_indices):
            # The carry starts as the scalar 0.0, so the scanned state is
            # expected to be a scalar. Padding slots at the end of a partition
            # (index 0) are scanned too, but unpartition drops their outputs.
            return jax.lax.scan(
                f=scanner,
                init=0.0,
                xs=(initial_value_indices, constants, subscript_indices)
            )

        mapped_scan = jax.vmap(mapper, in_axes=(0, 0, 0))
        _, partitioned_ys = mapped_scan(
            partitioned_initial_value_indices,
            partitioned_constants,
            partitioned_subscript_indices
        )
        # unpartition yields grouped order; restore the caller's xs order.
        return self._partitioner.unpartition(partitioned_ys)[self._restore_indices]
# Synthetic data: for each series letter, a Gaussian random walk of random
# length (1-9 steps) observed with additive measurement noise.
Position = namedtuple("Position", ["letter", "t", "x"])
positions: List[Position] = []
rng = numpy.random.default_rng(seed=0)
for letter in ("A", "B", "C"):
    length = rng.integers(1, 10)
    x = rng.normal(0.0, 10.0)  # starting point of the hidden walk
    for t in range(length):
        x += rng.normal(0.0, 0.1)  # one small random-walk step
        x_meas = x + rng.normal(0.0, 0.2)  # noisy observation of x
        positions.append(Position(letter=letter, t=t, x=x_meas))
def update(previous, _, parameters):
    """Advance the walk by one step: ``previous + scale_walk * innovation``.

    The constants argument is ignored; ``parameters`` is the pair
    ``(innovation, scale_walk)``.
    """
    (innovation, scale_walk) = parameters
    step = scale_walk * innovation
    return previous + step
def group_by(position: Position):
    """Return the series key of a measurement: its letter."""
    series_key = position.letter
    return series_key
def prepare_innovation_subscripts(position: Position):
    """Return the ``(letter, t)`` key selecting this measurement's innovation."""
    letter, t = position.letter, position.t
    return letter, t
# One random walk per letter, split over two parallel partitions. Two
# subscripts are prepared: innovations are looked up per (letter, t), the walk
# scale per letter.
rw = Scanner(
    xs=positions,
    group_by=group_by,
    prepare_subscripts=(prepare_innovation_subscripts, lambda position: position.letter),
    num_partitions=2,
)
def random_walk_likelihood():
    """Return the Gaussian log-likelihood of the measured positions.

    Declares the model's haiku parameters (all initialised to zeros), runs the
    random walk through ``rw`` and scores each measurement under a normal
    distribution centred on the walk with a shared measurement scale.
    """
    innovation_domain, scale_walk_domain = rw.subscript_domains
    observed = jax.numpy.array([position.x for position in positions])

    innovations = haiku.get_parameter("innovations", shape=[len(innovation_domain)], init=jax.numpy.zeros)
    # Scales are parameterised on the log scale to keep them positive.
    scale_walk = jax.numpy.exp(
        haiku.get_parameter("log_scale_walk", shape=[len(scale_walk_domain)], init=jax.numpy.zeros)
    )
    scale_measure = jax.numpy.exp(
        haiku.get_parameter("log_scale_measure", shape=(), init=jax.numpy.zeros)
    )
    # NOTE: this parameter holds the log of the initial-value scale despite its name.
    initial_value_scale = jax.numpy.exp(
        haiku.get_parameter("initial_value_scale", shape=[], init=jax.numpy.zeros)
    )
    initial_values_z = haiku.get_parameter("initial_values_z", shape=[len(rw.groups)], init=jax.numpy.zeros)

    walk = rw(
        update=update,
        initial_values=initial_values_z * initial_value_scale,
        parameters=(innovations, scale_walk)
    )
    log_densities = jax.scipy.stats.norm.logpdf(observed, walk, scale_measure)
    return jax.numpy.sum(log_densities)
# Build and evaluate the haiku model once at its zero initialisation.
transformed = haiku.without_apply_rng(haiku.transform(random_walk_likelihood))
rng_key = jax.random.PRNGKey(42)
params = transformed.init(rng=rng_key)
# Despite its name, mu_out is the scalar log-likelihood returned by the model.
mu_out = transformed.apply(params)
print(mu_out)

# Run the random walk directly with hand-picked parameters. Array sizes follow
# rw's domains instead of assuming exactly three series: JAX clamps
# out-of-bounds indices, so a size mismatch would silently give wrong values.
# For the three demo series this gives [1, 2, 3] and [0.1, 0.2, 0.3].
initial_values = jax.numpy.arange(1, len(rw.groups) + 1)
innovations = jax.numpy.ones(len(rw.subscript_domains[0]))
scale_walk = 0.1 * jax.numpy.arange(1, len(rw.subscript_domains[1]) + 1)
mu2 = rw(update, initial_values, (innovations, scale_walk))
print(mu2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment