Skip to content

Instantly share code, notes, and snippets.

@denisyarats
Created January 17, 2017 07:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save denisyarats/4981579b42c8a08e49206347f7d28c6c to your computer and use it in GitHub Desktop.
Save denisyarats/4981579b42c8a08e49206347f7d28c6c to your computer and use it in GitHub Desktop.
./sarsa.py --max_episodes 10000 --alpha 0.3 --gamma 0.9 --eps 0.2 --eps_schedule 200 --goal 25 --env copy --upload
#!/usr/local/bin/python
"""
SARSA - on policy TD(0) learning.
Q(S, A) <- Q(S, A) + alpha * ((R + gamma * Q(S', A')) - Q(S, A))
A, A' ~ e-greedy from pi(A|S)
"""
import argparse
import numpy as np
from collections import defaultdict
import gym
from gym import wrappers
import pdb
# Prefix of the experiment directory; main() appends "_<env key>" to it.
EXP_NAME_PREFIX = 'exp/sarsa'
# Credential handed to gym.upload() when --upload is given.
# NOTE(review): '???' is a placeholder and must be replaced before uploading.
API_KEY = '???'
# Command-line environment key -> gym environment id passed to gym.make().
ENVS = dict(copy='Copy-v0')
def decode(a, dims):
    """Split a flat action index into one index per discrete sub-space.

    The flat index is read as a mixed-radix number whose digit radices are
    ``dims``, with the last dimension varying fastest. This matches the flat
    enumeration of ``prod(dims)`` actions used by the SARSA table.

    Args:
        a: flat action index, expected in ``[0, prod(dims))``.
        dims: sizes of the discrete sub-spaces, in order.

    Returns:
        ``a`` unchanged when there is a single dimension, otherwise a list
        with one index per entry of ``dims``.
    """
    if len(dims) == 1:
        return a
    digits = []
    for d in reversed(dims):
        # Floor division via divmod: `a /= d` would produce floats under
        # Python 3 true division and corrupt every subsequent digit.
        a, digit = divmod(a, d)
        digits.append(digit)
    digits.reverse()
    return digits
def sarsa(env, max_episodes, alpha, gamma, eps, eps_schedule, goal):
    """Learn a tabular Q function with on-policy SARSA (TD(0)).

    Args:
        env: gym environment whose observation space exposes ``.n``. Its
            action space is either a single space with ``.n`` or a composite
            space whose ``.spaces`` each expose ``.n``; composite actions are
            flattened to one index and split back with ``decode``.
        max_episodes: maximum number of episodes to run.
        alpha: learning rate of the TD update.
        gamma: discount factor.
        eps: initial exploration rate of the epsilon-greedy policy; halved
            every ``eps_schedule`` episodes.
        eps_schedule: episode interval between ``eps`` halvings.
        goal: training stops as soon as the mean return stored in the
            100-slot window exceeds this value (checked every 50 episodes).

    Returns:
        The episode index at which ``goal`` was exceeded, or
        ``max_episodes`` if it never was.
    """
    if hasattr(env.action_space, 'spaces'):
        # Composite action space: enumerate the cartesian product flatly.
        dims = [d.n for d in env.action_space.spaces]
    else:
        dims = [env.action_space.n]
    nA = int(np.prod(dims))
    nS = env.observation_space.n
    Q = np.zeros((nS, nA), np.float32)
    P = np.zeros(nA, np.float32)  # reused action-probability buffer

    def exec_policy(s):
        # Epsilon-greedy: eps spread uniformly, the remaining mass on the
        # greedy action. Reads the current `eps`, which is decayed below.
        P.fill(eps / nA)
        P[np.argmax(Q[s])] += 1 - eps
        return np.random.choice(nA, p=P)

    # Ring buffer of the last 100 episode returns. Slots for episodes not
    # yet played stay 0, so averages before episode 100 are biased low.
    tR = np.zeros(100, np.float32)
    for e in range(max_episodes):
        if e % 50 == 0 and e > 0:
            avg = np.mean(tR)
            print('episode %d, average reward: %.3f' % (e, avg))
            if avg > goal:
                return e
        if e % eps_schedule == 0 and e > 0:
            eps /= 2
        s = env.reset()
        a = exec_policy(s)
        done = False
        tR[e % tR.size] = 0.
        while not done:
            ns, r, done, _ = env.step(decode(a, dims))
            na = exec_policy(ns)
            # A terminal transition must not bootstrap: Q(terminal, .) = 0.
            # Observation ids can repeat in non-terminal states, so reading
            # Q[ns, na] here would leak unrelated value into the target.
            target = r if done else r + gamma * Q[ns, na]
            Q[s, a] += alpha * (target - Q[s, a])
            s, a = ns, na
            tR[e % tR.size] += r
    return max_episodes
def main():
    """Parse command-line flags, run SARSA on the chosen env, optionally upload.

    With ``--upload`` the environment is wrapped in ``wrappers.Monitor``
    writing to ``<EXP_NAME_PREFIX>_<env>``, and that directory is passed to
    ``gym.upload`` after training.
    """
    parser = argparse.ArgumentParser(description='SARSA')
    # required: without it a missing --env crashed later with KeyError: None.
    parser.add_argument('--env', choices=sorted(ENVS), required=True)
    parser.add_argument('--max_episodes', type=int, default=10000)
    parser.add_argument('--alpha', type=float, default=1.0)
    parser.add_argument('--gamma', type=float, default=1.0)
    parser.add_argument('--eps', type=float, default=0.1)
    parser.add_argument('--eps_schedule', type=int, default=10000)
    parser.add_argument('--goal', type=float, default=1.0)
    parser.add_argument('--upload', action='store_true', default=False)
    args = parser.parse_args()

    exp_name = '%s_%s' % (EXP_NAME_PREFIX, args.env)
    env = gym.make(ENVS[args.env])
    # Seed both the environment and numpy's global RNG, which the
    # epsilon-greedy policy samples from, for reproducible runs.
    env.seed(0)
    np.random.seed(0)
    if args.upload:
        env = wrappers.Monitor(env, exp_name, force=True)
    try:
        res = sarsa(env, args.max_episodes, args.alpha,
                    args.gamma, args.eps, args.eps_schedule, args.goal)
    finally:
        # Release the environment even if training raises; the upload
        # below only runs after the env has been closed.
        env.close()
    print('result -> %d' % res)
    if args.upload:
        gym.upload(exp_name, api_key=API_KEY)
if __name__ == '__main__':
    # Script entry point: run the experiment configured on the command line.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment