Skip to content

Instantly share code, notes, and snippets.

Created April 18, 2017 15:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save unixpickle/55647f3f987fc172b2ec51a52cc56763 to your computer and use it in GitHub Desktop.
Save unixpickle/55647f3f987fc172b2ec51a52cc56763 to your computer and use it in GitHub Desktop.
CartPole TRPO
package main
import (
gym ""
// Configuration for the CartPole TRPO training run.
const (
	// BaseURL is the address of the gym-http-api server passed to
	// gym.NewClient.
	BaseURL = "http://localhost:5000"

	// RolloutsPerBatch is the number of episode rollouts gathered before
	// each TRPO update.
	RolloutsPerBatch = 30

	// NumBatches is the total number of gather-then-train iterations.
	NumBatches = 50

	// RenderEnv is forwarded to anyrl.GymEnv; presumably it toggles
	// rendering of the environment while rollouts are collected.
	RenderEnv = false
)
func main() {
// Connect to gym server.
client, err := gym.NewClient(BaseURL)
// Create environment instance.
id, err := client.Create("CartPole-v0")
defer client.Close(id)
// Start monitoring to "./cartpole-monitor".
workingDir, err := os.Getwd()
monitorFile := filepath.Join(workingDir, "gym-monitor")
must(client.StartMonitor(id, monitorFile, false, false, false))
defer client.CloseMonitor(id)
// Create a neural network policy.
creator := anyvec32.CurrentCreator()
policy := &anyrnn.LayerBlock{
Layer: anynet.Net{
anynet.NewFC(creator, 4, 32),
anynet.NewFC(creator, 32, 16),
anynet.NewFC(creator, 16, 2),
actionSampler := anyrl.Softmax{}
// Create an anyrl.Env from our gym environment.
env, err := anyrl.GymEnv(creator, client, id, RenderEnv)
// Setup Trust Region Policy Optimization for training.
trpo := &anyrl.TRPO{
NaturalPG: anyrl.NaturalPG{
Policy: policy,
Params: policy.Parameters(),
ActionSpace: actionSampler,
// This is akin to the learning rate.
TargetKL: 0.005,
for batchIdx := 0; batchIdx < NumBatches; batchIdx++ {
// Gather episode rollouts.
var rollouts []*anyrl.RolloutSet
for i := 0; i < RolloutsPerBatch; i++ {
rollout, err := anyrl.RolloutRNN(creator, policy, actionSampler, env)
rollouts = append(rollouts, rollout)
// Join the rollouts into one set.
r := anyrl.PackRolloutSets(rollouts)
// Print the rewards.
log.Printf("batch %d: mean_reward=%v", batchIdx, r.MeanReward(creator))
// Train on the rollouts.
grad := trpo.Run(r)
// Uncomment the code below to upload to the Gym website.
// Note: you must set the OPENAI_GYM_API_KEY environment
// variable or set the second argument of Upload() to a
// non-empty string.
// must(client.Upload(monitorFile, "", ""))
// must aborts the program by panicking with err when err is non-nil, and
// does nothing otherwise. Panicking (instead of log.Fatal/os.Exit) unwinds
// the stack, so deferred cleanup in callers such as main still runs.
func must(err error) {
	if err != nil {
		panic(err)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment