Skip to content

Instantly share code, notes, and snippets.

Avatar

Nikola Živković NMZivkovic

View GitHub Profile
View training.py
# Reset the agent's train-step counter to 0 before entering the training loop.
agent.train_step_counter.assign(0)
# Evaluate agent.policy on evaluation_env over EVAL_EPISODES episodes, before
# any iteration of the training loop below has run.
avg_return = get_average_return(evaluation_env, agent.policy, EVAL_EPISODES)
# Evaluation history, seeded with the pre-loop value.
# NOTE(review): presumably appended to inside the training loop — confirm.
returns = [avg_return]
for _ in range(NUMBER_ITERATION):
for _ in range(COLLECTION_STEPS):
experience_replay.timestamp_data(train_env, agent.collect_policy)
View ExperienceReply.py
class ExperienceReply(object):
def __init__(self, agent, enviroment):
self._replay_buffer = TFUniformReplayBuffer(
data_spec=agent.collect_data_spec,
batch_size=enviroment.batch_size,
max_length=50000)
self._random_policy = RandomTFPolicy(train_env.time_step_spec(),
enviroment.action_spec())
View average_return.py
def get_average_return(environment, policy, episodes=10):
total_return = 0.0
for _ in range(episodes):
time_step = environment.reset()
episode_return = 0.0
while not time_step.is_last():
action_step = policy.action(time_step)
View DQNAgent.py
# Step counter handed to the agent; it is exposed afterwards as
# agent.train_step_counter.
counter = tf.Variable(0)

# DQN agent built from the training environment's specs and the q_network
# defined elsewhere, optimized with Adam (learning rate 1e-3) against an
# element-wise squared TD-error loss.
agent = DqnAgent(
    time_step_spec=train_env.time_step_spec(),
    action_spec=train_env.action_spec(),
    q_network=q_network,
    optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=1e-3),
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=counter,
)
View qnetwork.py
# Fully connected hidden-layer sizes: a single layer with 100 units.
hidden_layers = (100,)

# Q-network mapping the training environment's observations to one value per action.
q_network = QNetwork(
    input_tensor_spec=train_env.observation_spec(),
    action_spec=train_env.action_spec(),
    fc_layer_params=hidden_layers,
)
View qnet_constructor.py
class QNetwork(network.Network):
"""Feed Forward network."""
def __init__(self,
input_tensor_spec,
action_spec,
preprocessing_layers=None,
preprocessing_combiner=None,
conv_layer_params=None,
fc_layer_params=(75, 40),
View enviroments.py
# Two separate CartPole-v0 instances: one to train on, one to evaluate with.
train_env, evaluation_env = (suite_gym.load('CartPole-v0') for _ in range(2))

# Show the specs the agent and network are built from; each heading is
# printed on its own line, followed by the spec on the next line.
print('Observation Spec:', train_env.time_step_spec().observation, sep='\n')
print('Reward Spec:', train_env.time_step_spec().reward, sep='\n')
print('Action Spec:')
View imports_globals.py
import base64
import imageio
import matplotlib
import matplotlib.pyplot as plt
import tensorflow as tf
from tf_agents.agents.dqn.dqn_agent import DqnAgent
from tf_agents.networks.q_network import QNetwork
View usage.py
# NOTE(review): the meaning of the positional args (32, 300, 500) is defined by
# DataProcessor, which is not shown here. 32 is presumably the batch size
# (load_process batches by self.batch_size), and 300/500 may be image
# dimensions — confirm against DataProcessor.__init__.
dataProcessor = DataProcessor(32, 300, 500, list_dataset)
# Build the input pipeline (load_process maps, caches, shuffles, repeats and
# batches self.dataset into self.loaded_dataset).
dataProcessor.load_process()
# Unpack one (images, labels) pair from the processor.
# NOTE(review): presumably a single batch from loaded_dataset — confirm in get_batch.
image_batch, label_batch = dataProcessor.get_batch()
View load_process.py
def load_process(self, shuffle_size = 1000):
self.loaded_dataset = self.dataset.map(self._load_labeled_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
self.loaded_dataset = self.loaded_dataset.cache()
# Shuffle data and create batches
self.loaded_dataset = self.loaded_dataset.shuffle(buffer_size=shuffle_size)
self.loaded_dataset = self.loaded_dataset.repeat()
self.loaded_dataset = self.loaded_dataset.batch(self.batch_size)
You can’t perform that action at this time.