Skip to content

Instantly share code, notes, and snippets.

@FLamparski
Last active July 20, 2018 16:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save FLamparski/b3540528828aa4570c4c90fbb30e6490 to your computer and use it in GitHub Desktop.
Save FLamparski/b3540528828aa4570c4c90fbb30e6490 to your computer and use it in GitHub Desktop.
Further adventures in reinforcement learning
Display the source blob
Display the rendered blob
Raw
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
import random
import numpy as np
from collections import deque
class QAgent:
    """Experience-replay Q-learning agent wrapping a Keras-style model.

    The model must provide ``predict``, ``train_on_batch`` and ``evaluate``.
    It maps a batch of states to one Q-value per action. Training targets
    follow the one-step Bellman update:
    ``Q(s, a) <- r`` for terminal transitions, else ``r + γ * max_a' Q(s', a')``.
    Action selection (and therefore exploration) is not handled here.
    """

    def __init__(self, model, γ=0.98, batch_size=64):
        """
        Args:
            model: object exposing ``predict``, ``train_on_batch`` and
                ``evaluate(xs, ys, verbose=0)``.
            γ: discount factor applied to the bootstrapped next-state value.
            batch_size: maximum number of transitions sampled per update.
        """
        self.model = model
        # Bounded replay buffer: the oldest transitions are dropped after 10000.
        self.memory = deque(maxlen=10000)
        self.γ = γ
        self.batch_size = batch_size
        # Whatever ``model.evaluate`` returns, one entry per ``test()`` call.
        self.metrics_log = []

    def get_Q(self, state):
        """Return the model's Q-value vector for a single 1-D ``state``."""
        xs = state.reshape(1, state.shape[0])  # add a batch axis of size 1
        ys = self.model.predict(xs)
        return ys.reshape(ys.shape[1])

    def remember(self, state, action, reward, next_state, done):
        """Store one transition in the replay buffer."""
        self.memory.append((state, action, reward, next_state, done))

    def learn(self, rounds=1):
        """Run ``rounds`` training steps, each on a freshly sampled minibatch.

        Does nothing while the replay buffer is empty, because there is no
        data to fit yet.
        """
        sample_size = min(len(self.memory), self.batch_size)
        if sample_size == 0:
            return
        for _ in range(rounds):
            xs, ys = self.experiences(sample_size)
            self.model.train_on_batch(np.array(xs), np.array(ys))

    def test(self):
        """Evaluate the model on a sampled minibatch and log the result.

        Returns:
            The value returned by ``model.evaluate``. Returns ``None``
            (and logs nothing) when the replay buffer is empty.
        """
        sample_size = min(len(self.memory), self.batch_size)
        if sample_size == 0:
            return None
        xs, ys = self.experiences(sample_size)
        metrics = self.model.evaluate(np.array(xs), np.array(ys), verbose=0)
        self.metrics_log.append(metrics)
        return metrics

    def experiences(self, sample_size):
        """Sample transitions and build (inputs, Q-targets) for training.

        Args:
            sample_size: number of transitions to draw without replacement;
                must not exceed ``len(self.memory)``.

        Returns:
            ``(xs, ys)``: a list of sampled states and a list of target
            Q-vectors. Each target equals the model's current prediction,
            except at the taken action, where the Bellman target is used.
        """
        sample = random.sample(self.memory, sample_size)
        if not sample:
            return [], []
        states, actions, rewards, next_states, dones = zip(*sample)
        # Two batched forward passes instead of two per transition.
        # np.array copies, so editing targets never aliases model output.
        targets = np.array(self.model.predict(np.array(states)))
        best_next = np.max(self.model.predict(np.array(next_states)), axis=1)
        for i, (action, reward, done) in enumerate(zip(actions, rewards, dones)):
            # Terminal transitions have no future value to bootstrap from.
            targets[i, action] = reward if done else reward + self.γ * best_next[i]
        return list(states), list(targets)
@FLamparski
Copy link
Author

See also: log_progress

Problem: how can I stop the model from converging on always picking the same action, and get it to actually learn the objective?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment