Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
An efficient way of generating a large number of trajectories for a QuadrotorEnvironment
import time
import numpy as np
from quadrotor_environment.quadrotor_model import SysState
from quadrotor_environment.simulation_parameter_randomization import SPRWrappedQuadrotorEnvironment
from stable_baselines import PPO2
# Structured dtypes used to store recorded trajectories.
# `np.float` was only an alias of the builtin `float` (i.e. float64); it was deprecated in
# NumPy 1.20 and removed in NumPy 1.24, so `np.float64` is spelled out explicitly.
# NOTE(review): `np.quaternion` is not part of NumPy itself; it is registered by the
# numpy-quaternion package (`import quaternion`). It is presumably registered as a side effect
# of the quadrotor_environment imports above — TODO confirm.

# One snapshot of the simulated quadrotor state.
state_dtype = np.dtype([("position", np.float64, 3),
                        ("velocity", np.float64, 3),
                        ("rotation", np.quaternion),
                        ("angular_velocity", np.float64, 3),
                        ("propeller_speed", np.float64, 4)])
# One trajectory sample: simulation time, the 4-component action that was applied, and the
# state read from the environment afterwards.
trajectory_dtype = np.dtype([("time", np.float64), ("action", np.float64, 4), ("state", state_dtype)])
# A record holding a single trajectory sample under the field name "trajectory".
trajectory_list_dtype = np.dtype([("trajectory", trajectory_dtype)])
def record_episodes(model, env_params, initial_states, num_steps, spr_seeds=None, batch_size=20):
    """
    Record a number of episodes and yield them one by one as structured numpy arrays.

    The episodes are recorded in parallel in batches, to make use of the improved speed of parallel
    model predictions. No more than `batch_size` episodes are kept in memory at any time. This makes
    the function suitable for generating and processing a large number of episodes efficiently:
    model predictions are parallelized without occupying too much memory.
    Note that the last batch might be smaller than `batch_size`.

    :param model: The model that controls the quadrotor. It is used to make deterministic predictions.
    :param env_params: The parameters of the environment to record the episodes in. Refer to the
        documentation of
        quadrotor_environment.simulation_parameter_randomization.SPRWrappedQuadrotorEnvironment
        for allowed values.
    :param initial_states: A sequence containing an initial state for each episode to generate.
        Entries that are not initial states cause a random state to be generated by the environment.
    :param num_steps: The number of steps to simulate forward for each episode.
    :param spr_seeds: Optional sequence of seeds to pass to the SPRWrappedQuadrotorEnvironment, one
        per entry of `initial_states` (extra seeds are ignored). If None, seed 0 is used everywhere.
    :param batch_size: The number of trajectories to generate in parallel. Must be at least 1.
    :return: Yields one structured numpy array per trajectory. Refer to
        trajectory_generation.trajectory_dtype for the actual structure.
    :raises ValueError: If `batch_size` is smaller than 1, or if `spr_seeds` has fewer entries than
        `initial_states`. Because this is a generator, the error is raised when iteration starts.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1, got {}".format(batch_size))
    # Compare against None explicitly: truth-testing a numpy array of seeds would raise.
    # Validate the seed count up front instead of failing in the middle of generation.
    if spr_seeds is not None and len(spr_seeds) < len(initial_states):
        raise ValueError("spr_seeds has {} entries but {} initial states were given".format(
            len(spr_seeds), len(initial_states)))
    for batch_start in range(0, len(initial_states), batch_size):
        # Slicing clamps batch_end to the sequence length, so the last batch may be shorter.
        batch_end = batch_start + batch_size
        seed_batch = None if spr_seeds is None else spr_seeds[batch_start:batch_end]
        yield from _record_episodes(model, env_params, initial_states[batch_start:batch_end],
                                    num_steps, spr_seeds=seed_batch)
def _write_sample(trajectory, index, time, action, state):
    """
    Store one sample (time, action and resulting state) at position `index` of `trajectory`.

    :param trajectory: Structured numpy array with dtype `trajectory_dtype`; modified in place.
    :param index: Index of the sample to write.
    :param time: Simulation time of the sample.
    :param action: The action that was applied before `state` was read.
    :param state: Object exposing the attributes position, velocity, rotation, angular_velocity
        and propeller_speed, as returned by `env.get_current_state()`.
    """
    # A single element of a structured array is a writable view into the array, so writing
    # into `record` fills `trajectory` in place.
    record = trajectory[index]
    record['time'] = time
    record['action'] = action
    for field in ('position', 'velocity', 'rotation', 'angular_velocity', 'propeller_speed'):
        record['state'][field] = getattr(state, field)


def _record_episodes(model: PPO2, env_params: dict, initial_states: list, num_steps: int, spr_seeds: list = None):
    """
    Record a number of episodes and return them as a list of structured numpy arrays.

    The episodes are all recorded in parallel, each in its own copy of the environment, so that
    one batched model prediction per step serves every episode.

    :param model: The model that controls the quadrotor. It is used to make deterministic predictions.
    :param env_params: The parameters of the environment to record the episodes in. Refer to the
        documentation of
        quadrotor_environment.simulation_parameter_randomization.SPRWrappedQuadrotorEnvironment
        for allowed values.
    :param initial_states: A list containing an initial state for each episode to generate. Each
        entry that is not a SysState causes a random state to be generated by its environment.
    :param num_steps: The number of steps to simulate forward for each episode.
    :param spr_seeds: A list of seeds to pass to the SPRWrappedQuadrotorEnvironment, one per entry
        of `initial_states`. If None, seed 0 is used for every environment.
    :return: A list of structured numpy arrays of length `num_steps + 1`, one per episode. Refer to
        trajectory_generation.trajectory_dtype for the actual structure. Index 0 holds the state
        after reset (with a zero action); index k holds the state after the k-th step.
    :raises ValueError: If `spr_seeds` is given and its length differs from `len(initial_states)`.
    """
    if spr_seeds is None:
        spr_seeds = [0] * len(initial_states)
    if len(spr_seeds) != len(initial_states):
        raise ValueError("spr_seeds has {} entries but {} initial states were given".format(
            len(spr_seeds), len(initial_states)))
    if len(initial_states) == 0:
        return []

    # One environment per episode. Each environment is reset to its initial state if that entry
    # is a SysState, otherwise to a random state generated by the environment.
    envs = [SPRWrappedQuadrotorEnvironment(**env_params) for _ in initial_states]
    observations = [
        env.reset(initial_state, spr_seed=seed) if isinstance(initial_state, SysState)
        else env.reset(spr_seed=seed)
        for env, initial_state, seed in zip(envs, initial_states, spr_seeds)
    ]

    # Preallocate every trajectory; all num_steps + 1 entries are written below.
    trajectories = [np.empty(num_steps + 1, dtype=trajectory_dtype) for _ in initial_states]
    for env, trajectory in zip(envs, trajectories):
        # The initial sample has no preceding action, so a zero action is stored.
        _write_sample(trajectory, 0, env.time, np.zeros(4), env.get_current_state())

    # Simulate all trajectories in lock-step.
    for step in range(1, num_steps + 1):
        # One batched prediction for all environments; element 0 of the result holds the actions.
        actions = model.predict(observations, deterministic=True)[0]
        observations = []
        for env, action, trajectory in zip(envs, actions, trajectories):
            # NOTE(review): `done` is ignored, so an environment keeps being stepped after it
            # reports termination — confirm this is intended.
            observation, _reward, _done, _info = env.step(action)
            observations.append(observation)
            env.unwrapped.observation_history.prune_history()  # HACK: this prevents our memory from exploding
            _write_sample(trajectory, step, env.time, action, env.get_current_state())
    return trajectories
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.