Skip to content

Instantly share code, notes, and snippets.

@JannesKlaas
Last active August 6, 2019 21:43
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save JannesKlaas/6c0ebb9c99a53b4c92e30548d3674886 to your computer and use it in GitHub Desktop.
Save JannesKlaas/6c0ebb9c99a53b4c92e30548d3674886 to your computer and use it in GitHub Desktop.
def train(model, epochs):
    """Train ``model`` by playing ``epochs`` games with epsilon-greedy actions.

    Each game step picks an action (random with probability ``epsilon``,
    otherwise the arg-max of ``model.predict``), applies it to the
    environment, stores the resulting experience, and fits the model on one
    batch drawn from the replay memory.

    The function relies on names that are not parameters and must exist in
    the enclosing module scope: ``env``, ``exp_replay``, ``epsilon``,
    ``num_actions``, ``batch_size`` and ``np`` (NumPy).
    NOTE(review): their exact types are not visible here -- ``env`` is used
    via ``reset()``, ``observe()`` and ``act(action)``; ``exp_replay`` via
    ``remember(...)`` and ``get_batch(...)``. Confirm against their
    definitions.

    Args:
        model: Object exposing ``predict(state)`` and
            ``train_on_batch(inputs, targets)``.
        epochs: Number of games to play.

    Returns:
        A list with one entry per game holding the cumulative win count
        after that game (the counter is never reset between games).
    """
    # Running total of wins over all games played so far.
    win_cnt = 0
    # Win count history, so progress over time can be inspected afterwards.
    win_hist = []

    # One epoch == one full game.
    for _ in range(epochs):
        env.reset()
        game_over = False
        # Initial observation of the freshly reset game.
        input_t = env.observe()

        while not game_over:
            # The learner acts on the last observed game state.
            input_tm1 = input_t

            if np.random.rand() <= epsilon:
                # Explore: draw a scalar action index. Without ``size`` the
                # result is a plain integer, matching the scalar produced by
                # ``np.argmax`` in the greedy branch below.
                action = np.random.randint(0, num_actions)
            else:
                # Exploit: q holds the model's predicted values per action;
                # pick the action with the highest one.
                q = model.predict(input_tm1)
                action = np.argmax(q[0])

            # Apply the action; get the new state, the reward and whether
            # the game has ended.
            input_t, reward, game_over = env.act(action)
            # A reward of exactly 1 is counted as a win.
            if reward == 1:
                win_cnt += 1

            # Uncomment this to render the game here
            # display_screen(action,3000,inputs[0])

            # The experiences <s, a, r, s'> made during gameplay are the
            # training data: first store the latest experience, then load a
            # batch of experiences and train the model on it.
            exp_replay.remember([input_tm1, action, reward, input_t], game_over)
            inputs, targets = exp_replay.get_batch(model, batch_size=batch_size)
            # The returned batch loss is intentionally not accumulated: the
            # function neither reports nor returns it.
            model.train_on_batch(inputs, targets)

        win_hist.append(win_cnt)

    return win_hist
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment