@pierrelux
pierrelux / discount.py
Last active April 24, 2021 06:59
Four ways to compute discounted returns
import numpy as onp
from scipy.signal import lfilter
import jax
import jax.numpy as jnp
def discount_lfilter(rewards, discount):
    return lfilter(b=[1], a=[1, -discount], x=rewards[::-1])[::-1]

def discount_correlate(rewards, discount):
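    # The gist preview cuts off here. One plausible completion (not necessarily
    # the gist's), equivalent to discount_lfilter: correlate the rewards with
    # the vector of discount powers and keep the trailing window.
    kernel = discount ** onp.arange(len(rewards))
    return onp.correlate(rewards, kernel, mode='full')[len(rewards) - 1:]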
@pierrelux
pierrelux / forest.py
Created October 10, 2019 19:53
Forest management example
import numpy as onp
def forest_management(forest_stages=3, r1=4, r2=2, p=0.1):
    """Forest management example from the MDPToolbox package.

    Chadès, I., Chapron, G., Cros, M.-J., Garcia, F. and Sabbadin, R. 2014.
    MDPtoolbox: a multi-platform toolbox to solve stochastic dynamic
    programming problems. Ecography 37: 916-920 (ver. 0).
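    """
    # The gist preview is truncated above. The rest of this body is a sketch
    # following the MDPToolbox "forest" conventions (action 0 = wait,
    # action 1 = cut, fire resets the forest with probability p); the gist's
    # actual code may differ.
    S = forest_stages
    wait = onp.zeros((S, S))
    wait[:, 0] = p                   # fire sends any stage back to stage 0
    for s in range(S - 1):
        wait[s, s + 1] = 1. - p      # otherwise the forest ages one stage
    wait[S - 1, S - 1] = 1. - p      # the oldest stage stays oldest
    cut = onp.zeros((S, S))
    cut[:, 0] = 1.                   # cutting always resets to stage 0
    P = onp.stack([wait, cut])       # (A x S x S) transition tensor
    R = onp.zeros((S, 2))
    R[S - 1, 0] = r1                 # r1 for waiting in the oldest stage
    R[1:, 1] = 1.                    # unit reward for cutting a grown forest
    R[S - 1, 1] = r2                 # r2 for cutting the oldest stage
    return P, R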
from jax import jvp, grad

def f(x, y):
    return x + y**2

def freeze(f, argnum, val):
    def _f(arg):
        args = [val, arg] if argnum == 0 else [arg, val]
        return f(*args)
    return _f
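A minimal usage sketch (not part of the gist preview): freezing one argument of f yields a single-input function, so jvp and grad can extract a partial derivative.
f_y = freeze(f, argnum=0, val=2.0)            # y -> f(2.0, y)
primal, tangent = jvp(f_y, (3.0,), (1.0,))    # tangent = d/dy (2.0 + y**2) at y = 3.0
print(primal, tangent)                        # 11.0 6.0
print(grad(f_y)(3.0))                         # 6.0, the same partial derivative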
def dadashi_fig2d():
    """Figure 2 d) of
    ''The Value Function Polytope in Reinforcement Learning''
    by Dadashi et al. (2019) https://arxiv.org/abs/1901.11524

    Returns:
        tuple (P, R, gamma) where the first element is a tensor of shape
        (A x S x S), the second element 'R' has shape (S x A) and the
        last element is the scalar (float) discount factor.
    """
    P = np.array([[[0.7, 0.3], [0.2, 0.8]],
@pierrelux
pierrelux / accumulate_discounted.py
Created July 19, 2019 17:59
VJP through a specific lfilter performing discounting on an array of scalar elements (rewards).
import autograd.numpy as np
from scipy.signal import lfilter
from autograd.extend import primitive, defvjp

@primitive
def accumulate_discounted(rewards, discount=1.):
    """Behaves like `accumulate` but where each array element gets discounted.

    Args:
        rewards (np.ndarray): 1D array of rewards
        discount (float): Scalar discount factor
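The preview cuts off inside the docstring. A self-contained sketch of how such an autograd primitive is typically wired up (hypothetical names; the gist's actual body and VJP may differ): the forward pass produces y[t] = sum over k >= t of discount**(k - t) * rewards[k], so the Jacobian is J[t, k] = discount**(k - t) for k >= t and the VJP applies the same discounting filter to the cotangent, without the time reversal.
from scipy.signal import lfilter
from autograd.extend import primitive, defvjp

@primitive
def _discounted_cumsum(rewards, discount=1.):
    # Forward pass: y[t] = sum_{k >= t} discount**(k - t) * rewards[k],
    # computed with the same reversed-lfilter trick as above.
    return lfilter([1.], [1., -discount], rewards[::-1])[::-1]

def _discounted_cumsum_vjp(ans, rewards, discount=1.):
    # (J^T g)[k] = sum_{t <= k} discount**(k - t) * g[t]: the same filter
    # applied to the cotangent g in forward time order.
    return lambda g: lfilter([1.], [1., -discount], g)

defvjp(_discounted_cumsum, _discounted_cumsum_vjp)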
@pierrelux
pierrelux / exact_pg.py
Created July 17, 2019 21:28
Exact Policy Gradient in jax, demonstrated in figure 2d of Dadashi et al. (2019)
import jax
import jax.numpy as np
from jax import grad, jit
from jax.scipy.special import logsumexp
def dadashi_fig2d():
    """ Figure 2 d) of
    ''The Value Function Polytope in Reinforcement Learning''
    by Dadashi et al. (2019) https://arxiv.org/abs/1901.11524

def induced_chain(transition, policy):
    """Marginalize the choice of actions under the given policy.

    Args:
        transition (numpy.ndarray): Transition kernel as a (A x S x S) tensor
        policy (numpy.ndarray): Policy as a (S x A) matrix

    Returns:
        numpy.ndarray: Marginalized transition matrix of shape (S x S), where
        the first dimension indexes "source" states and the second indexes
        "destination" states (from i to j).
    """
@pierrelux
pierrelux / texrepeat.sh
Last active July 19, 2019 18:43
texrepeat.sh
texexpand thesis.tex | grep -E '(\b.+) \1\b'
@pierrelux
pierrelux / sshtunnel.txt
Last active June 15, 2018 15:04
Chrome SSH Tunneling
ssh -ND 1080 user@hostname
google-chrome --proxy-server="socks5://localhost:1080"
@pierrelux
pierrelux / live.sh
Created February 14, 2017 22:01
Stream to Youtube Live using ffmpeg and save a local copy
ffmpeg -f video4linux2 -thread_queue_size 512 -input_format mjpeg -video_size hd720 -i /dev/video0 \
-f alsa -thread_queue_size 512 -i hw:1,0,0 \
-acodec mp3 -async 1 -vcodec libx264 -preset veryfast output.mp4 \
-acodec mp3 -ar 44100 -vcodec libx264 -preset ultrafast -maxrate 1984k -bufsize 3968k -g 60 -f flv rtmp://x.rtmp.youtube.com/live2/yourid