Last active
January 10, 2023 16:08
-
-
Save bbbales2/be4e43dbd82161750757105be527358e to your computer and use it in GitHub Desktop.
Timeseries model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import reduce | |
from sortedcontainers import SortedSet, SortedDict | |
from typing import List, Callable, TypeVar, Any, Generic, Optional, Tuple, Dict | |
from collections import namedtuple, defaultdict | |
import haiku | |
import jax | |
import jax.scipy | |
import numpy.random | |
# Generic type parameters shared by the helpers below.
X = TypeVar("X")  # element type of the input list `xs`
Y = TypeVar("Y")  # value produced per element (map result / scan state)
G = TypeVar("G")  # group / partition key returned by `group_by` / `partition_by`
class Partitioner(Generic[X, G]):
    """Scatter a flat per-element array into a padded ``(num_partitions, size, ...)`` grid.

    Each element of ``xs`` is assigned to a partition through its group
    (``partition_by(x)``). Every element of a group lands in the same
    partition, and elements keep their relative order inside a partition.
    ``partition`` lays out an array that has one leading entry per element of
    ``xs`` on that grid, zero-padding partitions shorter than the longest one.
    ``unpartition`` gathers the entries back into the flat order.

    Attributes:
        base_partitioned_shape: ``(num_partitions, max_partition_size)``, the
            two leading dimensions of every partitioned array.
    """

    base_partitioned_shape: Tuple[int, int]
    _partitions: jax.numpy.ndarray  # partition number of each element of xs
    _partition_indices: jax.numpy.ndarray  # slot of each element inside its partition
    _partition_sizes: Dict[int, int]  # element count of each non-empty partition

    def __init__(
        self,
        xs: List[X],
        partition_by: Callable[[X], G],
        num_partitions: int = 1,
    ):
        """Assign every element of ``xs`` to a partition and a slot within it.

        Groups are dealt to partitions round-robin, in order of first
        appearance in ``xs``. The layout is therefore deterministic, and groups
        are spread as evenly as their count allows. Hashing with ``hash()``
        would depend on the per-process salt Python applies to ``str``
        hashes, and could pile several groups into one partition.

        Args:
            xs: the elements to lay out.
            partition_by: maps an element to its (hashable) group key.
            num_partitions: number of rows of the partitioned grid.

        Raises:
            ValueError: if ``num_partitions`` is less than 1.
        """
        if num_partitions < 1:
            raise ValueError(f"num_partitions must be at least 1, got {num_partitions}")

        group_partitions: Dict[G, int] = {}
        self._partition_sizes = defaultdict(lambda: 0)
        partitions_builder = []
        partition_indices_builder = []
        for x in xs:
            group = partition_by(x)
            if group not in group_partitions:
                # The k-th distinct group (0-based) goes to partition k % num_partitions.
                group_partitions[group] = len(group_partitions) % num_partitions
            partition = group_partitions[group]
            partitions_builder.append(partition)
            # Next free slot in this partition, so elements keep their order.
            partition_indices_builder.append(self._partition_sizes[partition])
            self._partition_sizes[partition] += 1

        # An explicit integer dtype keeps these usable as indices even when xs is empty.
        self._partitions = jax.numpy.array(partitions_builder, dtype=int)
        self._partition_indices = jax.numpy.array(partition_indices_builder, dtype=int)
        max_partition_size = max(self._partition_sizes.values(), default=0)
        self.base_partitioned_shape = (num_partitions, max_partition_size)

    def partition(self, x: jax.numpy.array):
        """Return ``x`` scattered onto the padded grid.

        ``x`` must have one leading entry per element of ``xs``. The result has
        shape ``base_partitioned_shape + x.shape[1:]`` and the dtype of ``x``.
        Unused slots are zero.

        Raises:
            ValueError: if ``x.shape[0]`` differs from the number of elements.
        """
        # Checked explicitly rather than with `assert`: JAX indexing does not
        # raise on out-of-bounds indices, so a mismatch would silently misplace data.
        if x.shape[0] != len(self._partitions):
            raise ValueError(
                f"expected {len(self._partitions)} leading entries, got shape {x.shape}"
            )
        grid_shape = self.base_partitioned_shape + x.shape[1:]
        grid = jax.numpy.zeros(grid_shape, dtype=x.dtype)
        return grid.at[self._partitions, self._partition_indices].set(x)

    def unpartition(self, x: jax.numpy.array):
        """Gather a partitioned array back into the flat element order (inverse of ``partition``).

        Raises:
            ValueError: if the leading two dimensions of ``x`` are not
                ``base_partitioned_shape``.
        """
        if x.shape[:2] != self.base_partitioned_shape:
            raise ValueError(
                f"expected leading shape {self.base_partitioned_shape}, got {x.shape}"
            )
        return x[self._partitions, self._partition_indices]
class Mapper(Generic[X, Y]):
    """Map a function over ``xs`` with per-element constants and subscripted parameters.

    ``prepare_constants`` is a pytree of functions. Each one is evaluated on
    every element of ``xs``, and its results are stacked into one array (the
    leading axis indexes ``xs``).

    ``prepare_subscripts`` is a pytree of functions returning each element's
    subscript value, which must be hashable and mutually comparable. For each
    function, the distinct values form a sorted domain. Each element records
    the position of its value in that domain.

    Attributes:
        subscript_domains: pytree, shaped like ``prepare_subscripts``, of
            ``SortedSet``s holding the distinct subscript values.
    """

    subscript_domains: Any
    _constants: Any
    _subscript_indices: Any

    def __init__(
        self,
        xs: List[X],
        prepare_constants: Optional[Any] = None,
        prepare_subscripts: Optional[Any] = None,
    ):
        # `jax.tree_util.tree_map` rather than the deprecated `jax.tree_map`
        # alias, which recent JAX releases no longer provide.
        tree_map = jax.tree_util.tree_map

        # Constants: one stacked array per function in prepare_constants.
        self._constants = tree_map(
            lambda f: jax.numpy.array([f(x) for x in xs]), prepare_constants
        )

        # A plain `list` is a pytree node and would be flattened element by
        # element. An unregistered list subclass is a pytree leaf, which keeps
        # each function's subscript values together as a single leaf.
        class _SubscriptValues(list):
            pass

        subscript_values = tree_map(
            lambda f: _SubscriptValues(f(x) for x in xs), prepare_subscripts
        )
        subscript_domains = tree_map(lambda values: SortedSet(values), subscript_values)

        flat_subscript_domains, subscript_domain_treedef = jax.tree_util.tree_flatten(subscript_domains)
        flat_subscript_values, subscript_value_treedef = jax.tree_util.tree_flatten(subscript_values)
        assert subscript_domain_treedef == subscript_value_treedef

        # Replace every subscript value by its position in the sorted domain.
        flat_subscript_indices = [
            jax.numpy.array([domain.index(value) for value in values])
            for domain, values in zip(flat_subscript_domains, flat_subscript_values)
        ]

        self.subscript_domains = subscript_domains
        self._subscript_indices = jax.tree_util.tree_unflatten(
            subscript_domain_treedef, flat_subscript_indices
        )

    def __call__(
        self,
        map_function: Callable[[Any, Any], Y],
        parameters: Optional[Any] = None,
    ):
        """Evaluate ``map_function(constants, local_parameters)`` for every element, vectorized with ``jax.vmap``.

        Args:
            map_function: called once per element with that element's slice of
                the constants pytree and its local parameters.
            parameters: pytree with the same structure as ``prepare_subscripts``.
                Each leaf is indexed by the element's subscript position to form
                ``local_parameters``.

        Returns:
            The results of ``map_function`` stacked along a new leading axis
            (one entry per element of ``xs``).
        """
        flat_parameters, parameters_treedef = jax.tree_util.tree_flatten(parameters)

        def mapper(args):
            constants, subscript_indices = args
            flat_subscript_indices, subscript_indices_treedef = jax.tree_util.tree_flatten(subscript_indices)
            assert parameters_treedef == subscript_indices_treedef
            local_parameters = jax.tree_util.tree_unflatten(
                parameters_treedef,
                [parameter[index] for parameter, index in zip(flat_parameters, flat_subscript_indices)],
            )
            return map_function(constants, local_parameters)

        return jax.vmap(mapper)((self._constants, self._subscript_indices))
class Scanner(Generic[X, Y, G]):
    """Run an independent sequential recurrence over each group of inputs.

    ``xs`` is split into groups with ``group_by``. Within a group, elements are
    visited in their original relative order and a state is threaded through
    ``update``. The state at a group's first element is taken from
    ``initial_values`` (groups are numbered in sorted order, matching
    ``self.groups``). Groups are packed into ``num_partitions`` padded rows by a
    ``Partitioner``, and the rows are scanned in parallel with ``jax.vmap``
    over ``jax.lax.scan``. Per-element constants and parameter subscripts come
    from a ``Mapper`` built on the grouped ordering of ``xs``.

    Attributes:
        groups: ``SortedSet`` of the distinct group keys.
    """

    groups: SortedSet[G]
    _mapper: Mapper
    _initial_value_indices: jax.numpy.ndarray
    _original_indices: jax.numpy.ndarray

    def __init__(
        self,
        xs: List[X],
        group_by: Callable[[X], G],
        prepare_constants: Optional[Any] = None,
        prepare_subscripts: Optional[Any] = None,
        num_partitions: int = 1,
    ):
        # Bucket the inputs by group, remembering each one's position in xs.
        grouped_xs = SortedDict()
        original_indices = SortedDict()
        for i, x in enumerate(xs):
            group = group_by(x)
            if group not in grouped_xs:
                grouped_xs[group] = list()
                original_indices[group] = list()
            grouped_xs[group].append(x)
            original_indices[group].append(i)
        self.groups = SortedSet(grouped_xs.keys())

        # ordered_xs lists the inputs group by group: groups in sorted order,
        # elements in their original relative order. A flat comprehension avoids
        # the quadratic re-copying of folding the lists together with `+`.
        ordered_xs = [x for group_xs in grouped_xs.values() for x in group_xs]
        # _original_indices[k] is the position in xs of ordered_xs[k].
        self._original_indices = jax.numpy.array(
            [i for indices in original_indices.values() for i in indices], dtype=int
        )

        self._mapper = Mapper(xs=ordered_xs, prepare_constants=prepare_constants, prepare_subscripts=prepare_subscripts)
        self._partitioner = Partitioner(xs=ordered_xs, partition_by=group_by, num_partitions=num_partitions)

        # One reset marker per element of ordered_xs.
        # - A group's first element holds +k; its other elements hold -k. Here k
        #   is the group's 1-based position in sorted order (1-based because +0
        #   and -0 cannot be told apart).
        # - Positive means "start from initial_values[k - 1]"; anything else
        #   means "apply update to the previous state".
        # - Padding added by the partitioner is 0 and so falls in the latter case.
        initial_value_index_builder = []
        for k, group_xs in enumerate(grouped_xs.values(), start=1):
            initial_value_index_builder.append(k)
            initial_value_index_builder.extend((len(group_xs) - 1) * [-k])
        self._initial_value_indices = jax.numpy.array(initial_value_index_builder, dtype=int)

    @property
    def subscript_domains(self):
        """The ``Mapper``'s subscript domains (pytree of ``SortedSet``s)."""
        return self._mapper.subscript_domains

    def __call__(
        self,
        update: Callable[[Y, Any, Any], Y],
        initial_values: jax.numpy.ndarray,
        parameters: Optional[Any] = None,
    ):
        """Evaluate the recurrence and return one state per element of ``xs``.

        Args:
            update: ``update(previous, constants, local_parameters)`` returning
                the next state. ``local_parameters`` mirrors ``parameters`` with
                every leaf indexed by the element's subscript position.
            initial_values: starting state of each group, indexed in the order
                of ``self.groups``.
            parameters: pytree with the same structure as ``prepare_subscripts``.

        Returns:
            The states, in the original order of ``xs``.
        """
        # Loop-invariant: flatten the parameters once, outside the scan body.
        flat_parameters, parameters_treedef = jax.tree_util.tree_flatten(parameters)

        def scanner(previous, args):
            initial_value_index, constants, subscript_indices = args
            flat_subscript_indices, subscript_indices_treedef = jax.tree_util.tree_flatten(subscript_indices)
            assert parameters_treedef == subscript_indices_treedef
            local_parameters = jax.tree_util.tree_unflatten(
                parameters_treedef,
                [parameter[index] for parameter, index in zip(flat_parameters, flat_subscript_indices)],
            )
            current = jax.numpy.where(
                initial_value_index > 0,
                initial_values[initial_value_index - 1],
                update(previous, constants, local_parameters),
            )
            return current, current

        partition = self._partitioner.partition
        partitioned_initial_value_indices = partition(self._initial_value_indices)
        partitioned_constants = jax.tree_util.tree_map(partition, self._mapper._constants)
        partitioned_subscript_indices = jax.tree_util.tree_map(partition, self._mapper._subscript_indices)

        def scan_partition(initial_value_indices, constants, subscript_indices):
            # A non-empty row starts with some group's first element, which
            # discards the carry, so the 0.0 seed never reaches a real output.
            return jax.lax.scan(
                f=scanner,
                init=0.0,
                xs=(initial_value_indices, constants, subscript_indices),
            )

        _, partitioned_ys = jax.vmap(scan_partition, in_axes=(0, 0, 0))(
            partitioned_initial_value_indices,
            partitioned_constants,
            partitioned_subscript_indices,
        )
        ordered_ys = self._partitioner.unpartition(partitioned_ys)
        # ordered_ys[k] belongs to xs[_original_indices[k]], so scatter it back.
        # Gathering with `ordered_ys[_original_indices]` would apply the inverse
        # permutation and mis-order results unless xs was already grouped and sorted.
        return jax.numpy.zeros_like(ordered_ys).at[self._original_indices].set(ordered_ys)
# One noisy measurement `x` of walker `letter` at integer time step `t`.
Position = namedtuple("Position", ["letter", "t", "x"])

# Simulate a short random walk per letter; the seed makes the data reproducible.
rng = numpy.random.default_rng(seed=0)
positions: List[Position] = []
for letter in ("A", "B", "C"):
    length = rng.integers(1, 10)  # number of steps, 1 through 9
    x = rng.normal(0.0, 10.0)  # latent starting position
    for t in range(length):
        x += rng.normal(0.0, 0.1)  # latent random-walk step
        x_meas = x + rng.normal(0.0, 0.2)  # observation with measurement noise
        positions.append(Position(letter=letter, t=t, x=x_meas))
def update(previous, _, parameters):
    """Advance the walk by one step: add the innovation scaled by the walk scale.

    The middle argument (per-element constants) is ignored; ``parameters`` is
    the pair ``(innovation, scale_walk)``.
    """
    step, step_scale = parameters
    return previous + step * step_scale
def group_by(position: Position):
    """Group key of a position: its letter (each letter is a separate walk)."""
    letter = position.letter
    return letter
def prepare_innovation_subscripts(position: Position):
    """Innovation subscript of a position: the ``(letter, t)`` pair, one innovation per step."""
    letter, step = position.letter, position.t
    return letter, step
# Scan each letter's positions as its own random walk. The innovations are
# subscripted by (letter, t) and the walk scale by letter alone.
rw = Scanner(
    xs=positions,
    group_by=group_by,
    prepare_subscripts=(prepare_innovation_subscripts, group_by),
    num_partitions=2,
)
def random_walk_likelihood():
    """Return the summed Gaussian log-density of the measured positions under the random walk.

    Intended to run inside ``haiku.transform``. It registers, in this order:
    ``innovations`` (one per (letter, t) subscript), ``log_scale_walk`` (one per
    letter), ``log_scale_measure`` (scalar), ``initial_value_scale`` (scalar;
    despite its name it holds the log of the scale) and ``initial_values_z``
    (one per group). All are initialized to zeros.
    """
    observed = jax.numpy.array([position.x for position in positions])

    innovation_domain, scale_walk_domain = rw.subscript_domains
    zeros = jax.numpy.zeros
    # Keep the get_parameter calls in this order: it fixes the key order of the params dict.
    innovations = haiku.get_parameter("innovations", shape=[len(innovation_domain)], init=zeros)
    scale_walk = jax.numpy.exp(
        haiku.get_parameter("log_scale_walk", shape=[len(scale_walk_domain)], init=zeros)
    )
    scale_measure = jax.numpy.exp(
        haiku.get_parameter("log_scale_measure", shape=(), init=zeros)
    )
    initial_value_scale = jax.numpy.exp(
        haiku.get_parameter("initial_value_scale", shape=[], init=zeros)
    )
    initial_values = (
        haiku.get_parameter("initial_values_z", shape=[len(rw.groups)], init=zeros)
        * initial_value_scale
    )

    mu = rw(update=update, initial_values=initial_values, parameters=(innovations, scale_walk))
    log_densities = jax.scipy.stats.norm.logpdf(observed, mu, scale_measure)
    return jax.numpy.sum(log_densities)
# Build the Haiku model and evaluate it at its freshly initialized (all-zero) parameters.
transformed = haiku.without_apply_rng(haiku.transform(random_walk_likelihood))
rng_key = jax.random.PRNGKey(42)
params = transformed.init(rng=rng_key)
mu_out = transformed.apply(params)  # NOTE: the summed log-likelihood, not mu itself
print(mu_out)

# Run the scanner directly on hand-picked inputs. With every innovation equal to
# 1, each step adds that letter's walk scale to the previous state.
scale_walk = jax.numpy.array([0.1, 0.2, 0.3])
innovations = jax.numpy.ones(len(rw.subscript_domains[0]))
initial_values = jax.numpy.array([1, 2, 3])
mu2 = rw(update=update, initial_values=initial_values, parameters=(innovations, scale_walk))
print(mu2)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment