
@vwxyzjn
Last active March 19, 2019 05:00
How OpenAI's Baselines (via the stable-baselines fork) handles different types of observation spaces and action spaces
# https://github.com/hill-a/stable-baselines/blob/06f5843a3254ab7c2f6c927792e00365a778009e/stable_baselines/common/input.py#L6
import numpy as np
import tensorflow as tf
from gym.spaces import Discrete, Box, MultiBinary, MultiDiscrete


def observation_input(ob_space, batch_size=None, name='Ob', scale=False):
    """
    Build observation input with encoding depending on the observation space type

    When using a Box ob_space, the input will be normalized to [0, 1] using the bounds ob_space.low and ob_space.high.

    :param ob_space: (Gym Space) The observation space
    :param batch_size: (int) batch size for input
        (default is None, so that the resulting input placeholder can take tensors with any batch size)
    :param name: (str) tensorflow variable name for input placeholder
    :param scale: (bool) whether or not to scale the input
    :return: (TensorFlow Tensor, TensorFlow Tensor) input_placeholder, processed_input_tensor
    """
    if isinstance(ob_space, Discrete):
        # A Discrete observation is a single integer; one-hot encode it to width ob_space.n.
        observation_ph = tf.placeholder(shape=(batch_size,), dtype=tf.int32, name=name)
        processed_observations = tf.to_float(tf.one_hot(observation_ph, ob_space.n))
        return observation_ph, processed_observations
    elif isinstance(ob_space, Box):
        # A Box observation keeps its shape and dtype; optionally rescale it.
        observation_ph = tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=ob_space.dtype, name=name)
        processed_observations = tf.to_float(observation_ph)
        # rescale to [0, 1] if the bounds are defined and finite
        if (scale and
                not np.any(np.isinf(ob_space.low)) and not np.any(np.isinf(ob_space.high)) and
                np.any((ob_space.high - ob_space.low) != 0)):
            # equivalent to processed_observations / 255.0 when bounds are set to [0, 255]
            processed_observations = ((processed_observations - ob_space.low) /
                                      (ob_space.high - ob_space.low))
        return observation_ph, processed_observations
    elif isinstance(ob_space, MultiBinary):
        # A MultiBinary observation is a vector of 0/1 flags; just cast it to float.
        observation_ph = tf.placeholder(shape=(batch_size, ob_space.n), dtype=tf.int32, name=name)
        processed_observations = tf.to_float(observation_ph)
        return observation_ph, processed_observations
    elif isinstance(ob_space, MultiDiscrete):
        # A MultiDiscrete observation is a vector of integers; one-hot encode each
        # component with its own cardinality ob_space.nvec[i] and concatenate the encodings.
        observation_ph = tf.placeholder(shape=(batch_size, len(ob_space.nvec)), dtype=tf.int32, name=name)
        processed_observations = tf.concat([
            tf.to_float(tf.one_hot(input_split, ob_space.nvec[i])) for i, input_split
            in enumerate(tf.split(observation_ph, len(ob_space.nvec), axis=-1))
        ], axis=-1)
        return observation_ph, processed_observations
    else:
        raise NotImplementedError("Error: the model does not support input space of type {}".format(
            type(ob_space).__name__))
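
To make the observation handling concrete, here is a minimal usage sketch (not part of the gist). It assumes TensorFlow 1.x graph mode and gym are installed and that observation_input above is in scope; the shapes noted in the comments follow from the branches above.

import numpy as np
from gym.spaces import Box, Discrete

# Discrete(4): scalar integer observation, one-hot encoded to width 4.
obs_ph, processed = observation_input(Discrete(4))
print(obs_ph.shape, processed.shape)   # (?,) (?, 4)

# Atari-style image space with scale=True: uint8 pixels rescaled to [0, 1].
image_space = Box(low=0, high=255, shape=(84, 84, 4), dtype=np.uint8)
obs_ph, processed = observation_input(image_space, scale=True)
print(obs_ph.shape, processed.shape)   # (?, 84, 84, 4) (?, 84, 84, 4)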
# https://github.com/hill-a/stable-baselines/blob/06f5843a3254ab7c2f6c927792e00365a778009e/stable_baselines/common/distributions.py#L470
from gym import spaces


def make_proba_dist_type(ac_space):
    """
    Return an instance of ProbabilityDistributionType for the correct type of action space
    (the *ProbabilityDistributionType classes below are defined earlier in the same module)

    :param ac_space: (Gym Space) the input action space
    :return: (ProbabilityDistributionType) the appropriate instance of a ProbabilityDistributionType
    """
    if isinstance(ac_space, spaces.Box):
        # Continuous actions: a diagonal Gaussian with one dimension per action component.
        assert len(ac_space.shape) == 1, "Error: the action space must be a vector"
        return DiagGaussianProbabilityDistributionType(ac_space.shape[0])
    elif isinstance(ac_space, spaces.Discrete):
        # A single discrete choice: a categorical distribution over n actions.
        return CategoricalProbabilityDistributionType(ac_space.n)
    elif isinstance(ac_space, spaces.MultiDiscrete):
        # Several independent discrete choices: one categorical per component of nvec.
        return MultiCategoricalProbabilityDistributionType(ac_space.nvec)
    elif isinstance(ac_space, spaces.MultiBinary):
        # n independent binary choices: a Bernoulli distribution per bit.
        return BernoulliProbabilityDistributionType(ac_space.n)
    else:
        raise NotImplementedError("Error: probability distribution, not implemented for action space of type {}."
                                  .format(type(ac_space)) +
                                  " Must be of type Gym Spaces: Box, Discrete, MultiDiscrete or MultiBinary.")