Skip to content

Instantly share code, notes, and snippets.

@maitchison
Created November 22, 2022 18:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save maitchison/2f8f294c93eca3dceb2956de8cbfd5be to your computer and use it in GitHub Desktop.
Save maitchison/2f8f294c93eca3dceb2956de8cbfd5be to your computer and use it in GitHub Desktop.
import matplotlib.pyplot as plt
import numpy as np
# Discount factor. The return functions below look this module-level name up
# at call time, and the experiment code further down rebinds it to 0.99.
gamma = 0.9
def true_return(rewards, terminals, discount=None):
    """Compute the discounted return of every step, honouring episode ends.

    The sequences are walked backwards: a step's return is its own reward
    plus the discounted return of the step after it. Where the matching
    entry of ``terminals`` is 1/True, nothing is carried back from later
    steps, so that step's return is just its own reward.

    Args:
        rewards: Sequence of per-step rewards.
        terminals: Sequence of 0/1 (or boolean) flags, expected to be the
            same length as ``rewards``.
        discount: Discount factor. When ``None`` (the default) the
            module-level ``gamma`` is read at call time, which is the
            original behaviour.

    Returns:
        A list holding one return per step, in forward (input) order.
    """
    # Only touch the module global when the caller did not supply a value.
    rate = gamma if discount is None else discount

    backwards = []
    carried = 0
    for reward, done in zip(reversed(rewards), reversed(terminals)):
        # (1 - done) wipes the contribution of later steps at a terminal.
        carried = carried * rate * (1 - done) + reward
        backwards.append(carried)

    # Values were collected last-step-first; flip back to input order.
    backwards.reverse()
    return backwards
def forward_return_v1(rewards, terminals, discount=None):
    """Forward running discounted reward sum, reset BEFORE recording a terminal.

    Steps are visited in input order while a running sum is decayed by the
    discount and increased by each reward. On a step whose ``terminals``
    entry is 1/True the carried sum is dropped first, so the value recorded
    for that step is only its own reward.

    Args:
        rewards: Sequence of per-step rewards.
        terminals: Sequence of 0/1 (or boolean) flags, expected to be the
            same length as ``rewards``.
        discount: Discount factor. When ``None`` (the default) the
            module-level ``gamma`` is read at call time, which is the
            original behaviour.

    Returns:
        A list with the running sum recorded at each step, in input order.
    """
    # Only touch the module global when the caller did not supply a value.
    rate = gamma if discount is None else discount

    recorded = []
    running = 0
    for reward, done in zip(rewards, terminals):
        # The reset is folded into the update, so it lands before the append.
        running = running * rate * (1 - done) + reward
        recorded.append(running)
    return recorded
def forward_return_v2(rewards, terminals, discount=None):
    """Forward running discounted reward sum, reset AFTER recording a terminal.

    Steps are visited in input order while a running sum is decayed by the
    discount and increased by each reward. The value for a step is recorded
    first; only then, if the step's ``terminals`` entry is 1/True, is the
    running sum cleared so that the next step starts from zero.

    Args:
        rewards: Sequence of per-step rewards.
        terminals: Sequence of 0/1 (or boolean) flags, expected to be the
            same length as ``rewards``.
        discount: Discount factor. When ``None`` (the default) the
            module-level ``gamma`` is read at call time, which is the
            original behaviour.

    Returns:
        A list with the running sum recorded at each step, in input order.
    """
    # Only touch the module global when the caller did not supply a value.
    rate = gamma if discount is None else discount

    recorded = []
    running = 0
    for reward, done in zip(rewards, terminals):
        running = running * rate + reward
        recorded.append(running)
        # Clear after the append: the terminal step keeps its accumulated value.
        running *= (1 - done)
    return recorded
# --- Worked example -------------------------------------------------------
# One short trajectory with a terminal flag at index 2, evaluated with the
# module-level gamma as it stands at this point in the file.
rew = [0,0,1,0,2,0]
term = [0, 0, 1, 0, 0, 0]
print("Returns on simple example.")
print("True: ", true_return(rew,term))
print("Before:", forward_return_v1(rew,term))
print("After: ", forward_return_v2(rew,term))

# --- Variance comparison --------------------------------------------------
# For 100 random trajectories, compare the variance of each forward
# estimate against the variance of the true return.

# The functions read the module-level gamma, so rebinding it here changes
# the discount they use. The value is the same for every trial, so it is
# set once instead of on each loop iteration.
gamma = 0.99

xs = []  # variance of true_return, one entry per trial
ys = []  # variance of forward_return_v1, one entry per trial
zs = []  # variance of forward_return_v2, one entry per trial
for _ in range(100):
    rew = np.random.randint(1,5, size=1000)   # integer rewards drawn from 1..4
    term = np.random.rand(1000) < 0.01        # each step is terminal with p = 0.01
    rew *= term                               # zero the reward on non-terminal steps
    a = np.asarray(true_return(rew,term))
    b = np.asarray(forward_return_v1(rew,term))
    c = np.asarray(forward_return_v2(rew,term))
    xs.append(a.var())
    ys.append(b.var())
    zs.append(c.var())

# Scatter each estimator's variance against the true-return variance. The
# dashed line is y = x over [0, 2]: a point on it has exactly the same
# variance as the true return.
plt.scatter(xs, ys, marker='x', label='zero_before_append')
plt.scatter(xs, zs, marker='o', label='zero_after_append')
plt.plot(range(3), range(3), label='true return', color='black', ls='--')
plt.legend()
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment