@NMZivkovic
Created December 23, 2019 11:46
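The snippet below defines a small experience-replay helper for TF-Agents: it wraps a TFUniformReplayBuffer, seeds the buffer with transitions collected by a random policy, and exposes the buffer as a batched tf.data iterator for training.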
# Imports required by the classes used below (TF-Agents).
from tf_agents.policies.random_tf_policy import RandomTFPolicy
from tf_agents.replay_buffers.tf_uniform_replay_buffer import TFUniformReplayBuffer
from tf_agents.trajectories import trajectory


class ExperienceReplay(object):
    def __init__(self, agent, environment):
        self._replay_buffer = TFUniformReplayBuffer(
            data_spec=agent.collect_data_spec,
            batch_size=environment.batch_size,
            max_length=50000)

        # Seed the buffer with transitions collected by a random policy.
        self._random_policy = RandomTFPolicy(environment.time_step_spec(),
                                             environment.action_spec())
        self._fill_buffer(environment, self._random_policy, steps=100)

        # BATCH_SIZE is assumed to be defined elsewhere in the training script.
        self.dataset = self._replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=BATCH_SIZE,
            num_steps=2).prefetch(3)

        self.iterator = iter(self.dataset)

    def _fill_buffer(self, environment, policy, steps):
        for _ in range(steps):
            self.timestamp_data(environment, policy)

    def timestamp_data(self, environment, policy):
        # Record a single (time_step, action, next_time_step) transition.
        time_step = environment.current_time_step()
        action_step = policy.action(time_step)
        next_time_step = environment.step(action_step.action)
        timestamp_trajectory = trajectory.from_transition(time_step,
                                                          action_step,
                                                          next_time_step)
        self._replay_buffer.add_batch(timestamp_trajectory)


experience_replay = ExperienceReplay(agent, train_env)
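
For context, a minimal sketch of how this class is typically consumed in a TF-Agents training loop. The iteration count (NUM_ITERATIONS), the agent object, and the single collect step per iteration are assumptions for illustration, not part of the original gist:

# Hypothetical training loop; NUM_ITERATIONS and agent are assumed to exist.
for _ in range(NUM_ITERATIONS):
    # Collect one fresh transition with the agent's own collect policy.
    experience_replay.timestamp_data(train_env, agent.collect_policy)

    # Sample a mini-batch of trajectories from the buffer and train on it.
    experience, _ = next(experience_replay.iterator)
    train_loss = agent.train(experience).loss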