Skip to content

Instantly share code, notes, and snippets.

@jskDr
Last active November 23, 2019 09:32
Show Gist options
  • Save jskDr/c1edf3d7475ec7bda61fc8ffabaf5064 to your computer and use it in GitHub Desktop.
Save jskDr/c1edf3d7475ec7bda61fc8ffabaf5064 to your computer and use it in GitHub Desktop.
High-level implementation of ActorCritic with minimal typing (General implementation)
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Based on AC code from https://github.com/kimmyungsup/Reinforcement-Learning-with-Tensorflow-2.0/blob/master/ActorCritic_tf20/a2c_tf20.py\n",
"import gym\n",
"from rl.ac_tf2 import ActorModel, CriticModel, ActorCriticTrain, ReplayBuff\n",
"from rl.ac_tf2 import ac_step, ac_train, ac_report"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Instantiate the CartPole-v0 task and record the size (.n) of its action space;\n",
"# num_action is what the training cell hands to ActorCriticTrain.\n",
"env = gym.make('CartPole-v0')\n",
"num_action = env.action_space.n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"e : 0 reward : 19.0 step : 20\n",
"e : 100 reward : 27.0 step : 28\n",
"e : 200 reward : 100.0 step : 101\n",
"e : 300 reward : 48.0 step : 49\n",
"e : 400 reward : 138.0 step : 139\n"
]
}
],
"source": [
"actor_critic = ActorCriticTrain(num_action)\n",
"\n",
"t_end = 500       # upper bound on steps per episode\n",
"epi = 500         # number of episodes to run\n",
"train_size = 20   # forwarded to ac_train (presumably the buffer size that triggers an update - confirm in rl.ac_tf2)\n",
"\n",
"buff = ReplayBuff()\n",
"\n",
"for e in range(epi):\n",
"    # Start every episode from a freshly reset environment and keep the\n",
"    # observation that reset() returns. Previously the reset performed after\n",
"    # an episode ended threw its observation away, so the next episode began\n",
"    # acting on the stale terminal state of the previous one.\n",
"    state = env.reset()\n",
"    total_reward = 0\n",
"    for t in range(t_end):\n",
"        # ac_step hands back the updated state, the running reward total and the done flag.\n",
"        state, total_reward, done = ac_step(env, actor_critic, buff, state, total_reward, t_end, t)\n",
"\n",
"        ac_train(actor_critic, buff, train_size, done)\n",
"\n",
"        if done:\n",
"            ac_report(actor_critic, total_reward, e, t)\n",
"            break"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "tf2",
"language": "python",
"name": "tf2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment