q_learning.py (gist by @denisyarats, created January 19, 2017)
#!/usr/local/bin/python
"""
Q-learning - off policy TD(0) learning.
Q(S, A) <- Q(S, A) + alpha * ((R + gamma * max(Q(S', A'))) - Q(S, A))
A ~ e-greedy from pi(A|S)
"""
import argparse

import numpy as np
import matplotlib
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers the '3d' projection)

import gym
from gym import wrappers
EXP_NAME_PREFIX = 'exp/q_learning'
API_KEY = '???'

ENVS = {
    # --env copy --alpha 0.4 --gamma 0.69 --eps 0.3 --eps_schedule 100 --goal 25
    'copy': 'Copy-v0',
    # --env frozenlake --alpha 0.3 --gamma 0.95 --eps 0.4 --eps_schedule 100
    # --max_episodes 5000 --goal 0.78 --correction -0.0001
    'frozenlake': 'FrozenLake-v0',
}


def decode(a, dims):
    # Decode a flat action index into one sub-action index per dimension
    # (mixed-radix decoding; used for composite action spaces like Copy-v0).
    if len(dims) == 1:
        return a
    res = []
    for d in reversed(dims):
        res.append(a % d)
        a //= d
    res.reverse()
    return res
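# Quick sanity check (hypothetical usage, not part of the original gist):
# with dims = [2, 4], decode(7, [2, 4]) returns [1, 3], since 1 * 4 + 3 == 7.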


def plot_surface(z, title):
    # Render a 2-D array as a 3-D surface (helper; not invoked in this script).
    x_range = np.arange(0, z.shape[0])
    y_range = np.arange(0, z.shape[1])
    x, y = np.meshgrid(x_range, y_range)
    # Evaluate z at every (x, y) grid point so its shape matches the meshgrid.
    z = np.apply_along_axis(lambda t: z[t[0]][t[1]], 2, np.dstack([x, y]))

    def plot(x, y, z, title):
        fig = plt.figure(figsize=(20, 10))
        ax = fig.add_subplot(111, projection='3d')
        surf = ax.plot_surface(x, y, z, rstride=1, cstride=1,
                               cmap=matplotlib.cm.coolwarm, vmin=0.0, vmax=1.0)
        ax.set_xlabel('episode')
        ax.set_ylabel('action')
        ax.set_zlabel('prob')
        ax.set_title(title)
        ax.view_init(ax.elev, -120)
        fig.colorbar(surf)
        plt.show()

    plot(x, y, z, title)
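# Intended usage is presumably something like (an assumption; no call site
# appears in the gist):
#   plot_surface(probs, 'policy')
# where probs is an (episodes x actions) array of action probabilities,
# matching the 'episode'/'action'/'prob' axis labels above.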


def q_learning(env, max_episodes, alpha, gamma, eps,
               eps_schedule, goal, correction):
    # Composite action spaces (e.g. Copy-v0's Tuple space) are flattened into
    # a single discrete space of size prod(dims); decode() maps indices back.
    if hasattr(env.action_space, 'spaces'):
        dims = [d.n for d in env.action_space.spaces]
    else:
        dims = [env.action_space.n]
    nA = np.prod(dims)
    nS = env.observation_space.n
    Q = np.zeros((nS, nA), np.float32)
    P = np.zeros(nA, np.float32)

    def policy(s):
        # Epsilon-greedy distribution: eps spread uniformly, the rest on argmax.
        P.fill(eps / nA)
        P[np.argmax(Q[s])] += 1 - eps
        return P

    # Rolling window of the last 100 episode returns.
    tR = np.zeros(100, np.float32)
    for e in xrange(max_episodes):
        if e % 50 == 0 and e > 0:
            print 'episode %d, average reward: %.3f' % (e, np.mean(tR))
            if np.mean(tR) > goal:
                return e
        if e % eps_schedule == 0 and e > 0:
            eps /= 2  # anneal exploration every eps_schedule episodes
        s = env.reset()
        done = False
        tR[e % tR.size] = 0.
        while not done:
            P = policy(s)
            a = np.random.choice(nA, p=P)
            ns, r, done, _ = env.step(decode(a, dims))
            # Off-policy TD(0) update toward the greedy bootstrap target.
            Q[s][a] += alpha * (r - correction + gamma * np.max(Q[ns]) - Q[s][a])
            s = ns
            tR[e % tR.size] += r
    return max_episodes
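# A minimal direct invocation, as a sketch (hyperparameters taken from the
# frozenlake comment in ENVS above; not re-verified here):
#   env = gym.make('FrozenLake-v0')
#   q_learning(env, max_episodes=5000, alpha=0.3, gamma=0.95, eps=0.4,
#              eps_schedule=100, goal=0.78, correction=-0.0001)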


def main():
    parser = argparse.ArgumentParser(description='Q-learning')
    parser.add_argument('--env', choices=ENVS.keys(), required=True)
    parser.add_argument('--max_episodes', type=int, default=10000)
    parser.add_argument('--alpha', type=float, default=1.0)
    parser.add_argument('--gamma', type=float, default=1.0)
    parser.add_argument('--eps', type=float, default=0.1)
    parser.add_argument('--eps_schedule', type=int, default=10000)
    parser.add_argument('--goal', type=float, default=1.0)
    parser.add_argument('--correction', type=float, default=0.0)
    parser.add_argument('--upload', action='store_true', default=False)
    args = parser.parse_args()

    exp_name = '%s_%s' % (EXP_NAME_PREFIX, args.env)
    env = gym.make(ENVS[args.env])
    env.seed(0)
    np.random.seed(0)
    if args.upload:
        env = wrappers.Monitor(env, exp_name, force=True)
    res = q_learning(env, args.max_episodes, args.alpha,
                     args.gamma, args.eps, args.eps_schedule, args.goal,
                     args.correction)
    print 'result -> %d' % res
    env.close()
    if args.upload:
        gym.upload(exp_name, api_key=API_KEY)


if __name__ == '__main__':
    main()