Skip to content

Instantly share code, notes, and snippets.

@michaelosthege
Created March 4, 2019 16:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save michaelosthege/6953a2af7417da6ebdd41771a9e7e7a8 to your computer and use it in GitHub Desktop.
Custom Theano Op for wrapping around an ODE-integrator.
import base64
import hashlib
import theano
import theano.tensor as tt
def make_hashable(obj):
    """Recursively convert *obj* into a hashable, deterministic equivalent.

    Tuples and lists become tuples of converted elements; dicts become
    sorted tuples of ``(key, converted_value)`` pairs; sets and frozensets
    become sorted tuples.  Any other object is returned unchanged (the
    caller is expected to only pass leaves that repr() deterministically).

    Parameters
    ----------
    obj : object
        The (possibly nested) object to convert.

    Returns
    -------
    object
        A hashable representation with deterministic ordering.
    """
    if isinstance(obj, (tuple, list)):
        return tuple(make_hashable(e) for e in obj)
    if isinstance(obj, dict):
        # Sort by key so that two dicts with the same contents produce the
        # same tuple regardless of insertion order.
        return tuple(sorted((k, make_hashable(v)) for k, v in obj.items()))
    if isinstance(obj, (set, frozenset)):
        # Sets are unordered; sort the converted elements for determinism.
        return tuple(sorted(make_hashable(e) for e in obj))
    return obj
def make_hash_sha256(obj):
    """Return a base64-encoded SHA-256 digest of *obj*.

    The object is first converted to a canonical hashable form via
    ``make_hashable`` and then hashed by its ``repr``, so structurally
    equal containers (e.g. a list and the corresponding tuple) yield the
    same digest.
    """
    canonical = repr(make_hashable(obj)).encode()
    digest = hashlib.sha256(canonical).digest()
    return base64.b64encode(digest).decode()
class IntegrationOp(theano.Op):
    """This is a theano Op that becomes a node in the computation graph.
    It is not differentiable, because it uses a 'solver' function that is provided by the user.
    """
    # Theano uses __props__ to auto-generate __eq__/__str__ from the listed
    # attributes; here the Op's identity is its solver callable.
    __props__ = ("solver",)

    def __init__(self, solver):
        # solver: user-supplied callable taking numeric (y0, x, theta) and
        # returning a matrix — presumably shaped (len(y0), len(x)), judging
        # from infer_shape below; TODO confirm against the solver's contract.
        self.solver = solver
        return super().__init__()

    def __hash__(self):
        # Combine the Op's type with a content hash of the solver so that
        # Ops wrapping the same solver hash alike (theano relies on Op
        # hashing for graph merging/caching).
        # NOTE(review): __eq__ comes from __props__ while __hash__ is
        # overridden here — assumes make_hash_sha256(repr(solver)) is stable
        # enough to keep the hash/eq contract consistent; verify if graphs
        # are ever cached across processes.
        subhashes = (
            hash(type(self)),
            make_hash_sha256(self.solver)
        )
        return hash(subhashes)

    def make_node(self, y0:list, x, theta:list):
        # Build the symbolic Apply node from initial values y0, the
        # integration coordinates x, and the parameter list theta.
        # NOTE: theano does not allow a list of tensors to be one of the inputs
        # that's why they have to be tt.stack()ed which also merges them into one dtype!
        # TODO: check dtypes and raise warnings
        y0 = tt.stack([tt.as_tensor_variable(y) for y in y0])
        theta = tt.stack([tt.as_tensor_variable(t) for t in theta])
        x = tt.as_tensor_variable(x)
        apply_node = theano.Apply(self,
            [y0, x, theta],  # symbolic inputs: y0 and theta
            [tt.dmatrix()])  # symbolic outputs: Y_hat
        # NOTE: to support multiple different dtypes as transient variables, the
        # output type would have to be a list of dvector/svectors.
        return apply_node

    def perform(self, node, inputs, output_storage):
        # this performs the actual simulation using the provided solver
        # which takes actual y0/x/theta values and returns a matrix
        y0, x, theta = inputs
        Y_hat = self.solver(y0, x, theta)  # solve for all x
        # Theano convention: store the computed value into the first (and
        # only) output slot.
        output_storage[0][0] = Y_hat
        return

    def grad(self, inputs, outputs):
        # The solver is an opaque Python callable, so no symbolic gradient
        # can be derived; declare the gradient undefined for every input.
        return [theano.gradient.grad_undefined(self, k, inp,
            'No gradient defined through Python-wrapping IntegrationOp.')
            for k, inp in enumerate(inputs)]

    def infer_shape(self, node, input_shapes):
        # Output matrix shape: (number of state variables, number of x
        # points) — i.e. one row per y0 entry, one column per x value.
        s_y0, s_x, s_theta = input_shapes
        output_shapes = [(s_y0[0],s_x[0])]
        return output_shapes
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment