Skip to content

Instantly share code, notes, and snippets.

@pierrelux
Last active August 7, 2019 17:30
Show Gist options
  • Save pierrelux/d7ee58fdcb5dd75f736ba8e2450385bb to your computer and use it in GitHub Desktop.
def dadashi_fig2d():
    """Build the two-state, two-action MDP of Figure 2 d) in
    ''The Value Function Polytope in Reinforcement Learning''
    by Dadashi et al. (2019) https://arxiv.org/abs/1901.11524

    Returns:
        tuple (P, R, gamma): 'P' is a transition tensor of shape
        (A x S x S), 'R' is a reward matrix of shape (S x A), and
        'gamma' is the scalar (float) discount factor.
    """
    # One (S x S) transition matrix per action.
    transitions = np.array([
        [[0.7, 0.3],
         [0.2, 0.8]],
        [[0.99, 0.01],
         [0.99, 0.01]],
    ])
    # Rewards indexed by (state, action).
    rewards = np.array([
        [-0.45, -0.1],
        [0.5, 0.5],
    ])
    discount = 0.9
    return transitions, rewards, discount
def mdp_to_dot(P, R, discount):
    """Serialize an MDP's transition structure to Graphviz dot source.

    Args:
        P: transition tensor of shape (A x S x S).
        R: reward matrix of shape (S x A).
        discount: unused; accepted so the signature matches (P, R, gamma)
            tuples produced elsewhere.

    Returns:
        str: dot source drawing one labeled edge per non-negligible
        transition, laid out left-to-right.
    """
    del discount  # the diagram does not depend on the discount factor
    graph = gv.Digraph(
        body=['d2tdocpreamble = "\\usetikzlibrary{automata}"'],
        node_attr={'style': 'state'},
        edge_attr={'lblstyle': 'auto'})  # , 'topath': 'bend left'})
    graph.graph_attr['rankdir'] = 'LR'
    n_actions, n_src, n_dst = P.shape
    for action in range(n_actions):
        for src in range(n_src):
            for dst in range(n_dst):
                # Skip (near-)zero-probability transitions to keep the graph readable.
                if not P[action, src, dst] > 1e-5:
                    continue
                label = f"({action}, {R[src,action]:.3f}, {P[action,src,dst]:.3f})"
                graph.edge(str(src), str(dst), label=label)
    return graph.source
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment