Created
November 4, 2016 09:49
-
-
Save denis-bz/4ade13b5c96812a7b61b1a2822410445 to your computer and use it in GitHub Desktop.
theano-example-how-to-monitor-gradients.py 2016-11-04 Nov
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# A tiny example of how to monitor gradients in Theano.
# From http://www.marekrei.com/blog/theano-tutorial/  9. Minimal Training Example
# denis-bz 2016-11-04 nov
import theano
import theano.tensor as TT
import numpy as np

# Theano's configured default float dtype name; used for the sample data below.
floatx = theano.config.floatX

# Print float array elements as "%.2g" so each per-step row stays short.
np.set_printoptions(
    threshold=20, edgeitems=10, linewidth=100,
    formatter=dict(float=lambda x: "%.2g" % x))
#...............................................................................
# Theano dataflow graph: inputs x, target + state W -> y -> cost -> gradients.
#
# All three tensors use the configured floatX.  Hard-coding float32 inputs
# (fvector / fscalar) while casting the data to floatX makes the compiled
# functions reject x0 with a TypeError whenever floatX is float64.
x = TT.vector('x', dtype=floatx)
target = TT.scalar('target', dtype=floatx)
W = theano.shared(np.asarray([0.2, 0.7], dtype=floatx), name='W')  # state

y = (x * W).sum()
cost = TT.sqr(target - y)
gradients = TT.grad(cost, [W])  # dcost/dW, a one-element list

# One gradient-descent step with learning rate 0.1.
W_updated = W - (0.1 * gradients[0])
updates = [(W, W_updated)]

# An unusual "function" --
# each call to updater() returns y and, as a side effect, takes a gradient
# step that updates the state W.
updater = theano.function([x, target], y, updates=updates)

# Side-effect-free probes of the same graph, used to monitor training.
costf = theano.function([x, target], cost, name='costf')
gradf = theano.function([x, target], gradients, name='gradf')
    # not [x, W, target] -- function args may not be shared vars
    # flow: x, target, W too -> cost -> gradients

# Sample input and target fed to every call below.
x0 = np.array([1.0, 1.0]).astype(floatx)
target0 = 20.0
#...............................................................................
# Take 10 gradient steps, printing one row per step.
# Every value in a row describes the state *before* that step's update, so
# cost, y, W and the gradient are mutually consistent: cost == (target - y)**2.
# print(...) with a single argument and range() run on both Python 2 and 3.
print("# cost, y = x . W, state W, 0.1 gradient dcost/dW")
for i in range(10):
    Wold = W.get_value()
    cost_before = costf(x0, target0)  # must be read before updater() moves W
    gradinc = 0.1 * gradf(x0, target0)[0]
    yout = updater(x0, target0)  # y for Wold; then W is updated
    print("cost %-6.3g y %-6.3g W %-12s 0.1 grad %s" % (
        cost_before, yout, Wold, gradinc))

# cost 365    y 0.9    W [0.2 0.7]    0.1 grad [-3.8 -3.8]
# cost 131    y 8.54   W [4 4.5]      0.1 grad [-2.3 -2.3]
# cost 47.3   y 13.1   W [6.3 6.8]    0.1 grad [-1.4 -1.4]
# ...
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment