'''Trains a multi-output deep NN on the MNIST dataset using crossentropy and
policy gradients (REINFORCE).
The goal of this example is twofold:
* Show how to use policy gradients for training
* Show how to use generators with multi-output models
# Policy gradients
This is a Reinforcement Learning technique [1] that trains the model by
following the gradient of the log-probability of the action taken, scaled
by the advantage (reward - baseline) of that action.
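
A minimal NumPy sketch of the REINFORCE update described above, assuming a linear softmax policy in place of this example's deep network and a synthetic three-cluster dataset in place of MNIST. Each predicted class is treated as an action, the reward is 1 for a correct guess and 0 otherwise, and the baseline is a moving average of past rewards. `softmax` and `reinforce_grad` are hypothetical helpers for illustration, not Keras APIs.

```python
import numpy as np

# Hypothetical helpers for illustration; not part of Keras.
def softmax(z):
    # Subtract the row-wise max for numerical stability.
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def reinforce_grad(W, x, actions, rewards, baseline):
    # Assumed linear softmax policy: pi(a | x) = softmax(x @ W)[a].
    probs = softmax(x @ W)                     # (batch, n_actions)
    onehot = np.eye(W.shape[1])[actions]       # one-hot of the sampled actions
    advantage = (rewards - baseline)[:, None]  # (reward - baseline), shape (batch, 1)
    # d log pi(a|x) / d logits = onehot - probs; chain rule through x @ W,
    # then average over the batch.
    return x.T @ ((onehot - probs) * advantage) / len(x)

# Synthetic stand-in for MNIST (assumption): 3 Gaussian clusters in 5-D.
rng = np.random.default_rng(0)
n_classes, dim, n = 3, 5, 600
labels = rng.integers(0, n_classes, size=n)
x = 3.0 * np.eye(n_classes, dim)[labels] + rng.normal(size=(n, dim))

W = np.zeros((dim, n_classes))
baseline, lr = 0.0, 0.2
for step in range(300):
    probs = softmax(x @ W)
    # Sample one action (a class guess) per example by inverse-CDF sampling.
    u = rng.random((n, 1))
    actions = np.minimum((probs.cumsum(axis=1) < u).sum(axis=1), n_classes - 1)
    rewards = (actions == labels).astype(float)  # 1 for a correct guess, else 0
    # Gradient ascent on the expected reward.
    W += lr * reinforce_grad(W, x, actions, rewards, baseline)
    # Moving-average baseline, updated after the step so it does not
    # depend on the current batch's actions.
    baseline = 0.9 * baseline + 0.1 * rewards.mean()

accuracy = (softmax(x @ W).argmax(axis=1) == labels).mean()
print(f"greedy accuracy: {accuracy:.2f}")
```

Subtracting a baseline that does not depend on the sampled action leaves the expected gradient unchanged but reduces its variance, which is why the update uses (reward - baseline) rather than the raw reward.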
# Generators
from keras.optimizers import Adam
from keras import backend as K
from keras.datasets import mnist
from keras.utils.np_utils import to_categorical
from keras.metrics import categorical_accuracy
from keras.initializations import glorot_uniform, zero
import numpy as np
# inputs and targets are placeholders
input_dim = 28*28