@kadirpekel
CartPole RL Study
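# DQN on CartPole-v1 with TF-Agents: build a Q-network and DqnAgent, seed a
# uniform replay buffer with an initial collection run, train while collecting
# one step per iteration, export the greedy policy with PolicySaver, and
# finally render a few evaluation episodes with the resulting policy.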
import time
import tensorflow as tf
import matplotlib.pyplot as plt
from tf_agents.environments import suite_gym
from tf_agents.networks.q_network import QNetwork
from tf_agents.agents.dqn import dqn_agent
from tf_agents.environments import TFPyEnvironment
from tf_agents.replay_buffers import TFUniformReplayBuffer
from tf_agents.policies.policy_saver import PolicySaver
from tf_agents.policies import random_tf_policy
from tf_agents.utils import common
from tf_agents.metrics import tf_metrics
from tf_agents.drivers.dynamic_step_driver import DynamicStepDriver
from keras.optimizers import Adam
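# Number of training iterations, export location for the policy, and how often
# to log the training loss.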
num_iterations = 10000
save_dir = 'saved_states'
log_interval = 200
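# Wrap the CartPole-v1 Gym environment so observations, actions and steps are
# exposed as TF tensors and specs.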
train_env = TFPyEnvironment(suite_gym.load('CartPole-v1'))
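# Q-network: a single fully connected hidden layer of 100 units mapping
# observations to per-action Q-values.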
q_net = QNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    fc_layer_params=(100,)
)
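# Standard DQN agent with an element-wise squared TD-error loss; the variable
# below counts how many gradient updates have been applied.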
train_step_counter = tf.Variable(0)
agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=Adam(learning_rate=1e-3),
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step_counter
)
agent.initialize()
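# Uniform replay buffer holding up to 10000 transitions in the agent's
# collect data spec, batched to match the training environment.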
replay_buffer = TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=10000)
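# Uniform-random policy over the environment's action spec, used here for the
# initial buffer fill (see the note on the driver below).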
random_policy = random_tf_policy.RandomTFPolicy(
    train_env.time_step_spec(),
    train_env.action_spec())
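# Initial collection: run 1000 environment steps and push every transition
# into the replay buffer before training starts.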
collect_driver = DynamicStepDriver(
    train_env,
    # agent.collect_policy,
    random_policy,  # <---- This makes the model work as expected
    observers=[replay_buffer.add_batch],
    num_steps=1000)
collect_driver.run()
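# Sample windows of adjacent transitions from the buffer as a tf.data pipeline;
# for the default 1-step DQN update, _n_step_update + 1 == 2 (note that
# _n_step_update is a private attribute of DqnAgent).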
dataset = replay_buffer.as_dataset(
    num_steps=agent._n_step_update + 1,
    sample_batch_size=train_env.batch_size
).prefetch(3)
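# Training-time collection: the agent's collect policy gathers one step per
# training iteration while AverageReturnMetric tracks episode returns.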
avg_return_metric = tf_metrics.AverageReturnMetric()
observers = [avg_return_metric, replay_buffer.add_batch]
collect_step_driver = DynamicStepDriver(
    train_env,
    agent.collect_policy,  # <---- This remains as the agent's default
    observers=observers,
    num_steps=1
)
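# Build the dataset iterator, compile agent.train with common.function for
# speed, and reset the train step counter.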
iterator = iter(dataset)
agent.train = common.function(agent.train)
agent.train_step_counter.assign(0)
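# Try to restore a previously exported policy from save_dir.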
policy = None
try:
    policy = tf.saved_model.load(save_dir)
except OSError:
    print('No saved state found, training...')
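# Train only when no saved policy was restored above; a restored policy is
# reused directly for evaluation below.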
if policy is None:
    losses = []
    steps = []
    for _ in range(num_iterations):
        collect_step_driver.run()
        experience, unused_info = next(iterator)
        train_loss = agent.train(experience)
        step = agent.train_step_counter.numpy()
        if step % log_interval == 0:
            print('step = {0}: loss = {1}'.format(step, train_loss.loss))
            losses.append(train_loss.loss)
            steps.append(step)
            print('Average return: {}'.format(avg_return_metric.result().numpy()))
    policy = agent.policy
    policy_saver = PolicySaver(agent.policy)
    policy_saver.save(save_dir)

    plt.plot(steps, losses)
    plt.xlabel('Steps')
    plt.ylabel('Training loss')
    plt.show()
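# Evaluation: roll out the trained (or restored) policy in a fresh environment
# and render each step.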
test_env = TFPyEnvironment(suite_gym.load('CartPole-v1'))
num_episodes = 20
for _ in range(num_episodes):
    time_step = test_env.reset()
    while not time_step.is_last():
        action_step = policy.action(time_step)
        time_step = test_env.step(action_step.action)
        test_env.render(mode='human')
        time.sleep(0.01)
train_env.close()
test_env.close()