Skip to content

Instantly share code, notes, and snippets.

@pierrelux
Last active April 24, 2021 06:59
Show Gist options
  • Save pierrelux/5edf0fcb845e7d8213888a925c0d58e7 to your computer and use it in GitHub Desktop.
Save pierrelux/5edf0fcb845e7d8213888a925c0d58e7 to your computer and use it in GitHub Desktop.
Four ways to compute discounted returns
import numpy as onp
from scipy.signal import lfilter
import jax
import jax.numpy as jnp
def discount_lfilter(rewards, discount, axis=0):
    """Compute discounted returns with a first-order recursive (IIR) filter.

    The discounted return obeys ``G[t] = rewards[t] + discount * G[t + 1]``.
    Reading the rewards backwards in time turns that into the causal filter
    ``y[n] = x[n] + discount * y[n - 1]``, which ``scipy.signal.lfilter``
    evaluates with ``b=[1]`` and ``a=[1, -discount]``.

    Args:
        rewards: Array-like of rewards. Time runs along ``axis``.
        discount: Scalar discount factor.
        axis: Axis along which time runs. Defaults to 0, so 1-D inputs
            behave as before and ``(time, batch)`` arrays are handled
            column by column.

    Returns:
        NumPy array shaped like ``rewards`` holding the discounted return
        from every time step onwards.
    """
    rewards = onp.asarray(rewards)
    # Flip and filter along the *same* axis; otherwise multi-dimensional
    # inputs would be reversed along one axis and accumulated along another.
    backwards = onp.flip(rewards, axis=axis)
    accumulated = lfilter(b=[1], a=[1, -discount], x=backwards, axis=axis)
    # Undo the time reversal so index t again refers to time step t.
    return onp.flip(accumulated, axis=axis)
def discount_correlate(rewards, discount):
    """Compute discounted returns by cross-correlating with discount powers.

    With weights ``w[n] = discount ** n``, the non-negative lags of the full
    cross-correlation are ``sum_n rewards[n + k] * w[n]``, i.e. the
    discounted return starting at step ``k``.

    Args:
        rewards: 1-D array-like of rewards ordered in time.
        discount: Scalar discount factor.

    Returns:
        1-D NumPy array with the discounted return from each time step.
        An empty input yields an empty array.
    """
    rewards = onp.asarray(rewards)
    nsamples = rewards.shape[0]
    if nsamples == 0:
        # onp.correlate refuses empty operands, and the ``[-0:]`` slice below
        # would select the whole array rather than nothing.
        return onp.zeros(0, dtype=float)
    weights = onp.power(discount, onp.arange(nsamples))
    full = onp.correlate(rewards, weights, 'full')
    # 'full' mode has 2 * nsamples - 1 lags; the last nsamples are lags
    # 0 .. nsamples - 1, which are exactly the returns we want.
    return full[-nsamples:]
def discount_convolve(rewards, discount):
    """Compute discounted returns by convolving reversed rewards with discount powers.

    Convolving the time-reversed rewards with ``discount ** n`` and keeping
    the first ``nsamples`` outputs gives the returns in reversed order; a
    final reversal restores chronological order.

    Args:
        rewards: 1-D array-like of rewards ordered in time.
        discount: Scalar discount factor.

    Returns:
        1-D NumPy array with the discounted return from each time step.
        An empty input yields an empty array.
    """
    rewards = onp.asarray(rewards)
    nsamples = rewards.shape[0]
    if nsamples == 0:
        # onp.convolve raises on empty operands; there is nothing to discount.
        return onp.zeros(0, dtype=float)
    weights = onp.power(discount, onp.arange(nsamples))
    full = onp.convolve(rewards[::-1], weights, 'full')
    # Only the leading nsamples outputs involve complete partial sums of the
    # reversed sequence; flip them back into forward time order.
    return full[:nsamples][::-1]
def convolve1D(x, y):
    """Return the full 1-D linear convolution of ``x`` and ``y`` via XLA.

    Args:
        x: 1-D array-like signal.
        y: 1-D array-like kernel.

    Returns:
        1-D JAX array of length ``len(x) + len(y) - 1``.
    """
    # Based on https://github.com/google/jax/issues/1561
    signal = jnp.asarray(x)
    kernel = jnp.asarray(y)
    # lax operations do not promote dtypes implicitly, so bring both operands
    # to a common dtype before convolving.
    common = jnp.result_type(signal, kernel)
    # conv_general_dilated expects (batch, channel, width) for the input and
    # (out_channel, in_channel, width) for the kernel.
    signal = jnp.reshape(signal.astype(common), (1, 1, len(x)))
    kernel = jnp.reshape(kernel.astype(common), (1, 1, len(y)))
    # The primitive computes a cross-correlation; flipping the kernel along
    # its spatial axis turns it into a true convolution.
    kernel = jnp.flip(kernel, 2)
    # A full convolution needs len(y) - 1 zeros on each side of the signal.
    # Padding by len(x) - 1 is only equivalent when both lengths match.
    pad = len(y) - 1
    return jnp.ravel(jax.lax.conv_general_dilated(signal, kernel, [1], [(pad, pad)]))
def discount_convolve_jax(rewards, discount):
    """Compute discounted returns with a JAX convolution.

    Mirrors the NumPy convolution approach: the time-reversed rewards are
    convolved with ``discount ** n``, the first ``nsamples`` outputs are
    kept, and the result is flipped back into chronological order.

    Args:
        rewards: 1-D array-like of rewards ordered in time.
        discount: Scalar discount factor.

    Returns:
        1-D JAX array with the discounted return from each time step.
        An empty input yields an empty array.
    """
    rewards = jnp.asarray(rewards)
    nsamples = rewards.shape[0]
    if nsamples == 0:
        # Nothing to discount; skip the convolution, which has no sensible
        # padding for zero-length operands.
        return jnp.zeros((0,))
    discount_sequence = jnp.power(discount, jnp.arange(nsamples))
    # Align dtypes up front (e.g. integer rewards vs. float discounts) since
    # the underlying lax convolution does not promote its operands.
    common = jnp.result_type(rewards, discount_sequence)
    reversed_rewards = rewards[::-1].astype(common)
    filtered_rewards = convolve1D(reversed_rewards, discount_sequence.astype(common))
    return filtered_rewards[:nsamples][::-1]
if __name__ == "__main__":
    # Demo: run every implementation on the same toy episode and print the
    # resulting discounted returns, one line per implementation.
    discount = 0.9
    rewards = onp.array([1, 2, 3, 4], dtype=float)
    implementations = (
        discount_lfilter,
        discount_correlate,
        discount_convolve,
        discount_convolve_jax,
    )
    for compute_returns in implementations:
        print(compute_returns(rewards, discount))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment