Skip to content

Instantly share code, notes, and snippets.

Last active April 24, 2017 02:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save unixpickle/fc9229b28d79cd1a8e64ad29c98cd50f to your computer and use it in GitHub Desktop.
Save unixpickle/fc9229b28d79cd1a8e64ad29c98cd50f to your computer and use it in GitHub Desktop.
CartPole Decision Tree
import gym
def main():
    """Run the hand-written decision-tree policy on CartPole-v1 indefinitely.

    After every finished episode, prints that episode's total (undiscounted)
    reward, then resets the environment and the running total so the next
    episode starts cleanly.
    """
    env = gym.make('CartPole-v1')
    obs = env.reset()
    reward = 0
    while True:
        obs, rew, done, _ = env.step(policy(obs))
        reward += rew
        if done:
            print('Total reward of %f' % reward)
            # Gym documents step() results after done=True as undefined, so
            # start a fresh episode instead of stepping the finished one.
            obs = env.reset()
            reward = 0
def policy(obs):
    """
    Deterministic policy that wins CartPole almost
    every time.

    Returns action 0 when obs[2] < -0.012419 or obs[3] < -0.091290,
    and action 1 otherwise.
    """
    # NOTE(review): per gym's CartPole docs, obs[2] is the pole angle and
    # obs[3] the pole angular velocity, and action 0 pushes the cart left --
    # confirm for the gym version in use.
    if obs[2] < -0.012419 or obs[3] < -0.091290:
        return 0
    return 1
// Slow and crappy hill-climbing algorithm for finding a
// decision-tree policy for CartPole.
package main
import (
// Settings for the hill-climbing search.
const (
	// Host is the server address handed to gym.Make when creating each
	// environment. NOTE(review): "" presumably selects the binding's
	// default server -- confirm.
	Host = ""

	// Population is the number of candidate trees evaluated per batch; one
	// environment is created per candidate.
	Population = 20

	// PolicyDepth is the depth of the initial random decision tree.
	PolicyDepth = 2

	// MutateProb is the probability argument intended for PolicyNode.Mutate.
	MutateProb = 0.05
)
func main() {
policy := NewPolicy(PolicyDepth)
envs := make([]gym.Env, Population)
log.Printf("Creating %d environments...", Population)
for i := range envs {
var err error
envs[i], err = gym.Make(Host, "CartPole-v0")
r := rip.NewRIP()
var batch int
for !r.Done() {
policies := make([]PolicyNode, Population)
policies[0] = policy
for i := 1; i < Population; i++ {
policies[i] = policy.Copy()
rewards := Rollouts(policies, envs)
var maxReward float64
policy, maxReward = BestPolicy(policies, rewards)
log.Printf("batch=%d max_reward=%f", batch, maxReward)
func Rollouts(policies []PolicyNode, envs []gym.Env) []float64 {
res := make([]float64, len(envs))
var wg sync.WaitGroup
for i, p := range policies {
go func(p PolicyNode, e gym.Env, reward *float64) {
defer wg.Done()
obs, err := e.Reset()
var done bool
for !done {
var obsVec []float64
var rew float64
obs, rew, done, _, err = e.Step(p.Decide(obsVec))
*reward += rew
}(p, envs[i], &res[i])
return res
func BestPolicy(policies []PolicyNode, rewards []float64) (policy PolicyNode, maxReward float64) {
maxReward = math.Inf(-1)
for i, r := range rewards {
if r > maxReward {
maxReward = r
policy = policies[i]
func NewPolicy(depth int) PolicyNode {
if depth == 0 {
return &LeafNode{Decision: rand.Intn(2)}
res := &BranchNode{
Left: NewPolicy(depth - 1),
Right: NewPolicy(depth - 1),
return res
// PolicyNode is a node of a CartPole decision-tree policy. In this file it is
// implemented by *BranchNode (a split on one observation component) and
// *LeafNode (a fixed action).
type PolicyNode interface {
	// Mutate randomly perturbs the node in place; prob is the chance of
	// each individual perturbation.
	Mutate(prob float64)
	// Decide maps an observation vector to an action.
	Decide(in []float64) int
	// Copy returns an independent copy of the node's subtree.
	Copy() PolicyNode
	// String renders the subtree as source-like text. BranchNode.String
	// calls it on its children through this interface, so it must be part of
	// the contract.
	String() string
}
type BranchNode struct {
Param int
Thresh float64
Left PolicyNode
Right PolicyNode
func (b *BranchNode) String() string {
return fmt.Sprintf("if obs[%d] < %f {\n%s\n} else {\n%s\n}",
b.Param, b.Thresh, b.Left.String(), b.Right.String())
func (b *BranchNode) Mutate(prob float64) {
if rand.Float64() < prob {
b.Param = rand.Intn(4)
if rand.Float64() < prob {
b.Thresh = randomThreshold(b.Param)
func (b *BranchNode) Decide(in []float64) int {
if in[b.Param] > b.Thresh {
return b.Right.Decide(in)
} else {
return b.Left.Decide(in)
func (b *BranchNode) Copy() PolicyNode {
return &BranchNode{
Param: b.Param,
Thresh: b.Thresh,
Left: b.Left.Copy(),
Right: b.Right.Copy(),
// LeafNode is a terminal decision-tree node holding a fixed action.
type LeafNode struct {
	// Decision is the action this leaf always returns from Decide.
	Decision int
}
func (l *LeafNode) String() string {
return fmt.Sprintf("return %d", l.Decision)
func (l *LeafNode) Mutate(prob float64) {
if rand.Float64() < prob {
l.Decision = rand.Intn(2)
func (l *LeafNode) Decide(in []float64) int {
return l.Decision
func (l *LeafNode) Copy() PolicyNode {
return &LeafNode{Decision: l.Decision}
// randomThreshold draws a split point from a zero-mean normal distribution
// with standard deviation 10. The param argument is currently unused, so
// every observation component gets the same distribution.
func randomThreshold(param int) float64 {
	const stdDev = 10
	return stdDev * rand.NormFloat64()
}
// must panics with err if it is non-nil. It guards setup and environment
// calls whose failure leaves the search nothing sensible to do.
func must(err error) {
	if err != nil {
		panic(err)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment