Skip to content

Instantly share code, notes, and snippets.

@sisp
Last active December 26, 2015 12:49
Show Gist options
  • Save sisp/7154041 to your computer and use it in GitHub Desktop.
import numpy as np
import theano
import theano.tensor as T
# Theano's configured float dtype (theano.config.floatX); the shared input
# data built in the __main__ block is cast to it with .astype(floatX).
floatX = theano.config.floatX
def scan1(a, y, x=None):
    """Decay ``a`` by a factor 0.9 per step, for at most 10 steps.

    Works. The stopping expression ``y`` is handed to scan explicitly via
    ``non_sequences``, so the step function receives it as an argument and
    stops once it is ``<= 0``.

    Parameters
    ----------
    a : initial value of the recurrence (scan ``outputs_info``).
    y : symbolic stopping expression, passed through ``non_sequences``.
    x : unused; accepted so every ``scan*`` builder shares one signature.

    Returns
    -------
    The ``(outputs, updates)`` pair returned by ``theano.scan``.
    """
    def decay(prev, stop_expr):
        next_value = 0.9 * prev
        halt = theano.scan_module.until(stop_expr <= 0)
        return next_value, halt

    return theano.scan(
        decay,
        outputs_info=[a],
        non_sequences=[y],
        n_steps=10,
    )
def scan11(a, y, x=None):
    """Like scan1, but rewrite the stopping condition inside the step.

    Does not work. This case is somewhat close to what I am doing in line
    search: on every step the non-sequence ``y_`` is cloned with ``x``
    replaced by ``2*x``, and the clone is used as the stopping condition.

    Parameters
    ----------
    a : initial value of the recurrence (scan ``outputs_info``).
    y : symbolic stopping expression, passed through ``non_sequences``.
    x : symbolic variable inside ``y`` that gets substituted by ``2*x``.
        Required; the ``None`` default only keeps the signature uniform
        with ``scan1``/``scan2`` (``run`` calls ``scan_fn(a, y, x)``).

    Returns
    -------
    The ``(outputs, updates)`` pair returned by ``theano.scan``.

    Raises
    ------
    ValueError
        If ``x`` is None. Previously this surfaced as an opaque
        ``TypeError`` from evaluating ``2*None`` inside ``theano.scan``.
    """
    if x is None:
        raise ValueError("scan11 requires the variable `x` to substitute in `y`")

    def step(a_, y_):
        # NOTE(review): the cloned graph references `x` directly, and `x` is
        # not among scan's inputs (only `y` is passed as a non-sequence).
        # This is presumably why the variant fails, like scan2; confirm
        # against Theano.
        y_doubled = theano.clone(y_, replace={x: 2 * x})
        return 0.9 * a_, theano.scan_module.until(y_doubled <= 0)

    return theano.scan(step, outputs_info=[a], non_sequences=[y], n_steps=10)
def scan2(a, y, x=None):
    """Variant of scan1 whose stopping condition closes over ``y``.

    Does not work. No ``non_sequences`` are given to scan, so the step
    function receives only the previous value. The condition ``y <= 0`` is
    built from the ``y`` captured from the enclosing scope rather than from
    a scan input.

    Parameters
    ----------
    a : initial value of the recurrence (scan ``outputs_info``).
    y : symbolic stopping expression, captured by the closure.
    x : unused; accepted so every ``scan*`` builder shares one signature.

    Returns
    -------
    The ``(outputs, updates)`` pair returned by ``theano.scan``.
    """
    def decay(prev):
        next_value = 0.9 * prev
        halt = theano.scan_module.until(y <= 0)
        return next_value, halt

    return theano.scan(decay, outputs_info=[a], n_steps=10)
def run(a, y, givens, scan_fn, x=None):
    """Build a scan graph with ``scan_fn``, compile it, print and return the result.

    Parameters
    ----------
    a : initial value forwarded to ``scan_fn``.
    y : stopping expression forwarded to ``scan_fn``.
    givens : substitutions applied at compile time, e.g. a symbolic input
        mapped to shared data.
    scan_fn : one of the ``scan*`` builders. It is called as
        ``scan_fn(a, y, x)`` and must return ``(outputs, updates)``.
    x : optional variable forwarded to ``scan_fn``.

    Returns
    -------
    The value computed by the compiled function, which is also printed.
    The original returned None, so callers that ignore it are unaffected.
    """
    outputs, updates = scan_fn(a, y, x)
    # No explicit inputs: any non-shared symbolic variable in the graph must
    # be substituted through `givens`. Scan's updates are applied on each call.
    f = theano.function([], outputs, givens=givens, updates=updates)
    result = f()
    # Parenthesized print is valid on both Python 2 and Python 3; the former
    # `print f()` statement is a SyntaxError on Python 3.
    print(result)
    return result
if __name__ == '__main__':
    import traceback

    a = T.constant(1.0, name='a')
    x = T.vector('x')
    y = x.sum()
    # Uniform samples in [0, 1) make y = x.sum() non-negative for the data.
    X = theano.shared(np.random.uniform(size=10).astype(floatX), borrow=True)
    givens = {x: X}

    # Each case runs independently. Previously the calls ran back to back,
    # so if scan11 raised, scan2 was never exercised at all.
    cases = [
        (scan1, None),   # works
        (scan11, x),     # does not work
        (scan2, None),   # does not work
    ]
    for scan_fn, extra in cases:
        print('--- %s ---' % scan_fn.__name__)
        try:
            run(a, y, givens, scan_fn, extra)
        except Exception:
            # Top-level reporting boundary: show the failure and move on to
            # the next reproduction case.
            traceback.print_exc()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment