Skip to content

Instantly share code, notes, and snippets.

@aabbas90
Last active June 30, 2023 11:40
Show Gist options
  • Save aabbas90/9e29b464e20971258ab7f424904193d9 to your computer and use it in GitHub Desktop.
Save aabbas90/9e29b464e20971258ab7f424904193d9 to your computer and use it in GitHub Desktop.
import numpy as np
class ReplayBuffer:
    """Buffer of (key, value) pairs sampled with a Beta-distributed bias over the key range.

    A target key is drawn as ``min_key + Beta(a, b) * (max_key - min_key)`` and the
    value whose stored key lies closest to that target is returned. The draw is over
    the *range* of keys, not over the stored items, so densely clustered keys share
    the probability mass of their neighbourhood.
    """

    def __init__(self, a: float = 1.0, b: float = 0.7):
        """Create an empty buffer.

        Args:
            a: First shape parameter of the Beta distribution.
            b: Second shape parameter. With ``a == 1.0``, ``b < 1.0`` samples near
                the max key more frequently (trajectories where optimization is
                difficult); ``b == 1.0`` gives a uniform draw over the key range.
        """
        self.keys = []
        self.values = []
        self.a = a
        self.b = b
        # Sentinels so that the first added key becomes both the min and the max.
        self.min_key = float("inf")
        self.max_key = -float("inf")

    def add(self, key, value) -> None:
        """Store ``value`` under ``key`` and widen the tracked key range.

        Args:
            key: Objective value (float) of the current iterate.
            value: Arbitrary data to be stored for that iterate.
        """
        self.keys.append(key)
        self.values.append(value)
        # A NaN key leaves min_key/max_key untouched because every comparison
        # against NaN is false; sample() skips such entries.
        self.min_key = min(self.min_key, key)
        self.max_key = max(self.max_key, key)

    def sample(self):
        """Return the stored value whose key is nearest to a Beta-sampled target key.

        Returns:
            One of the values previously passed to ``add``. Ties are resolved in
            favour of the entry that was added first.

        Raises:
            ValueError: If the buffer is empty.
        """
        # Explicit check instead of ``assert`` so it is not stripped under ``python -O``.
        if not self.keys:
            raise ValueError("cannot sample from an empty ReplayBuffer")

        # Sample from beta distribution to allow non-uniform sampling.
        fraction = np.random.beta(self.a, self.b)
        sampled_key = fraction * (self.max_key - self.min_key) + self.min_key

        distances = np.abs(np.asarray(self.keys, dtype=float) - sampled_key)
        # np.argmin treats NaN as the minimum, so a single NaN key would be returned
        # on every call. Push NaN distances to +inf so they lose to any real distance;
        # if every distance is NaN the first entry is returned, as before.
        distances[np.isnan(distances)] = np.inf
        best_index = int(np.argmin(distances))
        return self.values[best_index]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment