Skip to content

Instantly share code, notes, and snippets.

@denisyarats
Created January 15, 2017 05:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save denisyarats/f80ac232d2dadc6c04abd6dd64389bfe to your computer and use it in GitHub Desktop.
Save denisyarats/f80ac232d2dadc6c04abd6dd64389bfe to your computer and use it in GitHub Desktop.
on policy mc
#!/usr/local/bin/python
import argparse
import numpy as np
from collections import defaultdict
import gym
from gym import wrappers
import pdb
# Prefix of the experiment name / monitor output directory; the selected
# --env key is appended to it at runtime.
EXP_NAME_PREFIX = 'exp/on_policy_mc'

# NOTE(review): this looks like a real credential committed to source --
# confirm, and consider rotating it and loading it from the environment.
API_KEY = 'sk_ARsYZ2eRsGoeANVhUgrQ'

# Maps the short names accepted on the command line to gym environment ids.
ENVS = dict(
    copy='Copy-v0',
    repeatcopy='RepeatCopy-v0',
    duplicatedinput='DuplicatedInput-v0',
    reversedaddition='ReversedAddition-v0',
    reversedaddition3='ReversedAddition3-v0',
    reverse='Reverse-v0',
)
def decode(a, dims):
    """Split a flat index into one digit per entry of ``dims``.

    ``a`` is read as a mixed-radix number whose digit sizes are given by
    ``dims``; the last entry of ``dims`` is the least-significant digit.

    Args:
        a: Non-negative integer index to decompose.
        dims: Sequence of positive integer radices, one per output digit.

    Returns:
        A list with ``len(dims)`` integers, where element ``i`` is the digit
        belonging to ``dims[i]``. Any part of ``a`` beyond the product of
        ``dims`` is discarded.
    """
    digits = []
    for d in reversed(dims):
        # divmod floors the quotient, so the running value stays integral
        # under true division as well (Python 3 or
        # ``from __future__ import division``), where ``a /= d`` would
        # produce floats and corrupt the remaining digits.
        a, digit = divmod(a, d)
        digits.append(digit)
    # Digits were collected least-significant first; restore dims order.
    digits.reverse()
    return digits
def main():
    """Train a tabular epsilon-greedy Monte Carlo agent on a gym task.

    Command-line flags:
        --env: key of ``ENVS`` selecting the environment (required).
        --max_episodes: upper bound on the number of training episodes.
        --gamma: discount factor applied when accumulating returns.
        --eps: initial exploration rate of the epsilon-greedy policy.
        --goal: training stops early once the mean total reward of the
            last 100 episodes exceeds this value (checked every 1000
            episodes).
        --upload: wrap the environment in a gym ``Monitor`` and upload the
            recorded experiment when training finishes.

    After each episode the action-value table is updated with the
    incremental mean of the observed returns; every occurrence of a
    (state, action) pair in the episode contributes one sample.
    """
    parser = argparse.ArgumentParser(description='on policy mc')
    # Required: without it args.env is None and the ENVS lookup below fails
    # with a KeyError instead of a usage message.
    parser.add_argument('--env', choices=sorted(ENVS), required=True)
    parser.add_argument('--max_episodes', type=int, default=10000)
    parser.add_argument('--gamma', type=float, default=1.0)
    parser.add_argument('--eps', type=float, default=0.1)
    parser.add_argument('--goal', type=float, default=1.0)
    parser.add_argument('--upload', action='store_true', default=False)
    args = parser.parse_args()

    exp_name = '%s_%s' % (EXP_NAME_PREFIX, args.env)
    env = gym.make(ENVS[args.env])
    if args.upload:
        env = wrappers.Monitor(env, exp_name, force=True)

    # Close the environment even if training raises or is interrupted.
    try:
        nS = env.observation_space.n
        # The action space is a tuple of discrete sub-spaces; flatten it to
        # a single index in [0, nA) and expand it again with decode().
        dims = [d.n for d in env.action_space.spaces]
        nA = int(np.prod(dims))

        Q = np.zeros((nS, nA), np.float32)   # action-value estimates
        N = np.zeros((nS, nA), np.int32)     # visit counts per (s, a)
        P = np.zeros((nS, nA), np.float32)   # epsilon-greedy policy rows
        tR = np.zeros(100, np.float32)       # ring buffer of episode totals

        eps = args.eps
        for e in range(args.max_episodes):
            if e > 0 and e % 1000 == 0:
                mean_reward = np.mean(tR)
                print('episode %d, average reward: %.3f' % (e, mean_reward))
                if mean_reward > args.goal:
                    break
            if e > 0 and e % 10000 == 0:
                # Halve the exploration rate every 10000 episodes.
                eps /= 2

            s = env.reset()
            S, R, A = [], [], []
            done = False
            slot = e % tR.size
            tR[slot] = 0.
            while not done:
                S.append(s)
                # Rebuild this state's policy row: eps spread uniformly,
                # the remaining mass on the current greedy action.
                P[s].fill(eps / nA)
                P[s][np.argmax(Q[s])] += 1 - eps
                a = np.random.choice(nA, p=P[s])
                A.append(a)
                s, r, done, _ = env.step(decode(a, dims))
                R.append(r)
                tR[slot] += r

            # Walk the episode backwards so the return G can be accumulated
            # incrementally; list() keeps this valid where zip is lazy.
            G = 0.
            for s_t, r_t, a_t in reversed(list(zip(S, R, A))):
                G = args.gamma * G + r_t
                N[s_t][a_t] += 1
                Q[s_t][a_t] += (G - Q[s_t][a_t]) / N[s_t][a_t]
    finally:
        env.close()

    if args.upload:
        gym.upload(exp_name, api_key=API_KEY)
# Start training only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment