Skip to content

Instantly share code, notes, and snippets.

@d0znpp
Created December 12, 2017 00:47
Show Gist options
  • Save d0znpp/527c65a8b0d432643e476f7e955ec711 to your computer and use it in GitHub Desktop.
def train(mnist, max_layers, max_episodes=250):
    """Run the REINFORCE-driven architecture search on MNIST.

    Each episode asks the controller (``Reinforce``) for an action. When every
    entry of ``action[0][0]`` is positive, the action is scored with
    ``NetManager.get_reward``; otherwise it gets a reward of -1.0. The action
    becomes the next state, the rollout is stored, a controller training step
    runs, and a progress line is printed.

    Args:
        mnist: Dataset object passed straight through to ``NetManager``
            (``main`` obtains it from ``input_data.read_data_sets``).
        max_layers: Number of layers the controller describes. It is used both
            to build ``Reinforce`` and to size the initial state (4 values per
            layer), so the two always agree.
        max_episodes: Number of search episodes to run. The default of 250 is
            the previously hard-coded value.
    """
    sess = tf.Session()
    try:
        global_step = tf.Variable(0, trainable=False)
        # The decay schedule has always started from 0.99. The old
        # ``starter_learning_rate = 0.1`` local was never used, so the named
        # constant now holds the value actually in effect.
        starter_learning_rate = 0.99
        learning_rate = tf.train.exponential_decay(
            starter_learning_rate, global_step, 500, 0.96, staircase=True)
        optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate)
        # Use the max_layers argument (not a global ``args``) so the policy
        # network and the initial state below are sized consistently.
        reinforce = Reinforce(sess, optimizer, policy_network, max_layers,
                              global_step)
        net_manager = NetManager(num_input=784,
                                 num_classes=10,
                                 learning_rate=0.001,
                                 mnist=mnist)

        # Initial state: the same 4-value layer description repeated once per
        # layer, wrapped in an outer list (shape (1, 4 * max_layers)).
        state = np.array([[10.0, 128.0, 1.0, 1.0] * max_layers],
                         dtype=np.float32)
        pre_acc = 0.0
        for i_episode in range(max_episodes):
            action = reinforce.get_action(state)
            print("current action:", action)
            # Only architectures whose parameters are all positive are
            # trained and scored; anything else is penalised directly.
            if all(ai > 0 for ai in action[0][0]):
                # i_episode replaces the old separate ``step`` counter, which
                # always held the same value at this point.
                reward, pre_acc = net_manager.get_reward(action, i_episode,
                                                         pre_acc)
            else:
                reward = -1.0
            # In our sample the action is equal to the next state.
            state = action[0]
            reinforce.store_rollout(state, reward)
            # NOTE(review): MAX_STEPS is not defined in this snippet;
            # presumably a module-level constant elsewhere - TODO confirm.
            ls = reinforce.train_step(MAX_STEPS)
            log_str = ("current time: {!s} episode: {!s} loss: {!s} "
                       "last_state: {!s} last_reward: {!s}").format(
                           datetime.datetime.now().time(), i_episode, ls,
                           state, reward)
            print(log_str)
    finally:
        # Release the TensorFlow session even if an episode raises.
        sess.close()
def main():
    """Script entry point: load MNIST with one-hot labels and run a 3-layer search."""
    dataset = input_data.read_data_sets("MNIST_data/", one_hot=True)
    train(dataset, max_layers=3)
# Start the search only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment