Skip to content

Instantly share code, notes, and snippets.

@araffin
Last active February 7, 2025 21:16
Show Gist options
  • Select an option

  • Save araffin/a7a576ec1453e74d9bb93120918ef7e7 to your computer and use it in GitHub Desktop.

Select an option

Save araffin/a7a576ec1453e74d9bb93120918ef7e7 to your computer and use it in GitHub Desktop.
from typing import Optional
import jax
import numpy as np
from brax import envs
from brax.envs.base import PipelineEnv
from brax.envs.wrappers import training
from brax.io import image as brax_image
from gymnasium import spaces
# from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecEnv, VecMonitor, VecNormalize
from stable_baselines3.common.evaluation import evaluate_policy
from sbx import PPO, TQC, SAC
import optax
class BraxSB3Wrapper(VecEnv):
    """A wrapper that converts batched Brax Env to one that follows SB3 VecEnv API.

    The wrapped env must already be batched (it must expose ``batch_size``),
    e.g. the result of ``envs.create(..., batch_size=n)``. Reset and step are
    jit-compiled once in ``__init__`` and always operate on the whole batch.
    """

    def __init__(
        self,
        env: PipelineEnv,
        seed: int = 0,
        backend: Optional[str] = None,
        keep_infos: bool = True,
    ) -> None:
        """Wrap a batched Brax env.

        :param env: Batched Brax env (must have a ``batch_size`` attribute).
        :param seed: Seed of the JAX PRNG key consumed by ``reset``.
        :param backend: XLA backend forwarded to ``jax.jit`` (None = default).
        :param keep_infos: If True, split the Brax metrics/info dict into one
            dict per env at every step (may be slow with many envs); otherwise
            return empty info dicts.
        :raises ValueError: If ``env`` is not batched.
        """
        self._env = env
        if not hasattr(self._env, "batch_size"):
            raise ValueError("underlying env must be batched")
        num_envs = self._env.batch_size

        high = np.inf * np.ones(self._env.observation_size, dtype=np.float32)
        observation_space = spaces.Box(-high, high, dtype=np.float32)
        ctrl_range = jax.tree.map(np.array, self._env.sys.actuator.ctrl_range)
        action_space = spaces.Box(
            ctrl_range[:, 0], ctrl_range[:, 1], dtype=np.float32
        )

        # Both are read by ``get_attr("render_mode")`` inside the SB3 base
        # constructor, so they must exist before calling it.
        self.num_envs = num_envs
        self.render_mode = "rgb_array"
        # Initializes the SB3 bookkeeping attributes (reset_infos, seeds,
        # options) that were previously missing on this class.
        super().__init__(num_envs, observation_space, action_space)
        # The base constructor sets its own metadata; restore ours (with fps).
        self.metadata = {
            "render_modes": ["human", "rgb_array"],
            "video.frames_per_second": 1 / self._env.dt,
        }

        self.backend = backend
        self.keep_infos = keep_infos
        self._state = None
        self.actions = None
        self.default_infos = [{} for _ in range(self.num_envs)]
        self.seed(seed)

        def reset(key):
            key1, key2 = jax.random.split(key)
            state = self._env.reset(key2)
            # key1 is kept for the next reset, key2 is consumed now.
            return state, state.obs, key1

        def step(state, action):
            state = self._env.step(state, action)
            # Note: Brax truncation is not mapped to SB3's "TimeLimit.truncated".
            info = {**state.metrics, **state.info}
            return state, state.obs, state.reward, state.done, info

        self._reset = jax.jit(reset, backend=self.backend)
        self._step = jax.jit(step, backend=self.backend)

    def reset(self) -> np.ndarray:
        """Reset all envs and return the batched observation as a numpy array."""
        self._state, obs, self._key = self._reset(self._key)
        return np.array(obs)

    def step_async(self, actions: np.ndarray) -> None:
        """Store the batched actions for the next ``step_wait``."""
        self.actions = actions

    def step_wait(self) -> tuple[np.ndarray, np.ndarray, np.ndarray, list]:
        """Step all envs with the stored actions.

        :return: ``(obs, rewards, dones, infos)`` as numpy arrays plus one
            info dict per env.
        :raises RuntimeError: If called before ``reset``.
        """
        if self._state is None:
            raise RuntimeError("Must call reset before stepping")
        # TODO: add last observation too?
        self._state, obs, rewards, dones, info = self._step(self._state, self.actions)
        if self.keep_infos:
            # Convert from dict of batched values to list of dicts
            infos = self.to_list(info)
        else:
            # Fresh dicts every step: callers/wrappers may add keys in place,
            # which must not leak into later steps.
            infos = [{} for _ in range(self.num_envs)]
        return np.array(obs), np.array(rewards), np.array(dones).astype(bool), infos

    def seed(self, seed: Optional[int] = 0) -> None:
        """Re-seed the PRNG key used by ``reset`` (None draws a random seed)."""
        if seed is None:
            seed = int(np.random.randint(0, 2**31 - 1))
        self._key = jax.random.PRNGKey(seed)

    def render(self, mode: str = "human") -> Optional[np.ndarray]:
        """Render the envs.

        ``"rgb_array"`` returns a 256x256 frame of the first env (the internal
        state is batched, so it is indexed before rendering). Any other mode is
        delegated to SB3's ``VecEnv.render`` in ``"human"`` mode.

        :raises RuntimeError: In ``"rgb_array"`` mode, if no state exists yet.
        """
        if mode == "rgb_array":
            if self._state is None:
                raise RuntimeError("Must call reset or step before rendering")
            return brax_image.render_array(
                self._env.sys, self._state.take(0).pipeline_state, 256, 256
            )
        # Use opencv to render (handled by the SB3 base class)
        return super().render(mode="human")

    def get_images(self) -> list:
        """Return one 256x256 frame per env.

        :raises RuntimeError: If no state exists yet.
        """
        if self._state is None:
            raise RuntimeError("Must call reset or step before rendering")
        state_list = [self._state.take(i).pipeline_state for i in range(self.num_envs)]
        return brax_image.render_array(self._env.sys, state_list, width=256, height=256)

    def _num_indices(self, indices) -> int:
        """Return how many envs ``indices`` selects (None = all, int = one)."""
        if indices is None:
            return self.num_envs
        if isinstance(indices, (int, np.integer)):
            return 1
        return len(indices)

    def env_is_wrapped(self, wrapper_class, indices=None) -> list[bool]:
        """Report that no env is wrapped (used by eval and monitor helpers)."""
        return [False] * self._num_indices(indices)

    def set_attr(self, attr_name, value, indices=None):
        raise NotImplementedError("Setting attributes is not supported.")

    def get_attr(self, attr_name, indices=None) -> list:
        """Return this wrapper's attribute, repeated once per selected env."""
        attr_val = getattr(self, attr_name)
        return [attr_val] * self._num_indices(indices)

    def env_method(self, method_name: str, *method_args, indices=None, **method_kwargs):
        raise NotImplementedError()

    def close(self) -> None:
        pass

    def to_list(self, info_dict: dict) -> list[dict]:
        """Convert a dict of batched values into one info dict per env.

        Every key of ``info_dict`` is present in each returned dict. Values that
        cannot be split along the batch axis are left as ``None``. This covers
        non-array objects such as State pytrees, scalars, and arrays whose
        leading dimension is not ``num_envs``.
        """
        infos = [dict.fromkeys(info_dict) for _ in range(self.num_envs)]
        # Adapted from https://github.com/isaac-sim/IsaacLab
        # TODO: use the "truncation" key for "TimeLimit.truncated" and provide
        # "terminal_observation" (the terminal obs is not exposed by step here).
        for key, value in info_dict.items():
            try:
                # One host transfer per key instead of one indexing op per env.
                batched = np.asarray(value)
            except (TypeError, ValueError):
                continue
            if batched.ndim == 0 or batched.shape[0] != self.num_envs:
                continue
            for idx, info in enumerate(infos):
                info[key] = batched[idx]
        return infos
# Brax environments that can be selected here:
# ['ant', 'halfcheetah', 'hopper', 'humanoid', 'humanoidstandup',
#  'inverted_pendulum', 'inverted_double_pendulum', 'pusher', 'reacher', 'walker2d']
env_name = "hopper"
# Physics backend, one of ['mjx', 'generalized', 'positional', 'spring']
backend = "generalized"
n_envs = 1024

# Batched Brax env
# (alternative: training.VmapWrapper(base_env, batch_size=n_envs)).
vmap_env = envs.create(env_name, backend=backend, batch_size=n_envs)
# SB3 adapter (no per-env infos) -> episode stats -> observation
# normalization only (rewards are left unnormalized).
vec_env = VecNormalize(
    VecMonitor(BraxSB3Wrapper(vmap_env, keep_infos=False)),
    norm_reward=False,
)
# Hyperparameters for the SimbaPolicy variant; only referenced by the
# commented-out TQC setup below.
simba_hyperparams = {
    "batch_size": 256,
    # "buffer_size": 100_000,
    "learning_rate": 3e-4,
    "policy_kwargs": {
        "optimizer_class": optax.adamw,
        "net_arch": {"pi": [128], "qf": [256, 256]},
        "n_critics": 2,
    },
    "learning_starts": 10_000,
    # "normalize": {"norm_obs": True, "norm_reward": False},
    # "resets": [50000, 75000],
}

# Alternative agents, kept for reference:
# model = PPO("MlpPolicy", vec_env, n_steps=64, batch_size=1024,
#             n_epochs=4, verbose=1, device="cpu")
# model = TQC("SimbaPolicy", vec_env, train_freq=5,
#             gradient_steps=min(n_envs, 256), policy_delay=10,
#             verbose=1, **simba_hyperparams)

# Gradient steps per update are capped at 256 regardless of n_envs.
model = SAC(
    "MlpPolicy",
    vec_env,
    verbose=1,
    train_freq=5,
    gradient_steps=min(n_envs, 256),
    policy_delay=10,
)
# Training; Ctrl+C stops it early without propagating the interrupt.
try:
    model.learn(
        total_timesteps=30_000_000,
        log_interval=10,
        progress_bar=True,
    )
except KeyboardInterrupt:
    pass

# Evaluation (disabled):
# print(evaluate_policy(model, vec_env, n_eval_episodes=10, render=True))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment