@cloneofsimo
Created August 1, 2023 07:56
from rg2.gym import Rg2UEnv, WalkerEnvConfig
from gym.wrappers import TimeLimit
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import VecNormalize, SubprocVecEnv
from stable_baselines3.common.callbacks import CheckpointCallback

if __name__ == "__main__":
    # Nominal base pose: xyz position + wxyz quaternion.
    NOM_POS = [0, 0, 0.01, 1.0, 0.0, 0.0, 0.0]

    def make_env(rank):
        def _init():
            env = Rg2UEnv(
                WalkerEnvConfig(
                    resource_dir="/home/user/Downloads/wcollision/urdf/robot.urdf",
                    gc_init=NOM_POS + [0.0] * 15,  # base pose + 15 joint positions
                    gv_init=[0.0] * (15 + 6),      # 15 joint + 6 base velocities
                    action_mean=[0.0] * 15,
                    action_std=[0.3] * 15,
                    p_gain=50,
                    d_gain=0.2,
                    env_params=[-1, -1, -1, -1, -1, -1],
                ).get_cpp_object(),
                seed=rank,
                visualizable=(rank == 0),  # only the first worker renders
            )
            env = TimeLimit(env, 400)  # cap episodes at 400 steps
            env = Monitor(env)         # record episode reward/length statistics
            return env

        return _init

    # 8 parallel workers, with observation and reward normalization.
    envs = SubprocVecEnv([make_env(i) for i in range(8)])
    envs = VecNormalize(envs, norm_obs=True, norm_reward=True, clip_obs=10.0)
    # envs.env_method("turn_on_visualization", indices=0)

    # Checkpoint every 10k vectorized steps (80k transitions with 8 workers).
    # save_replay_buffer is a no-op for PPO, which keeps no replay buffer.
    checkpoint_callback = CheckpointCallback(
        save_freq=10_000,
        save_path="./logs/",
        name_prefix="rl_model",
        save_replay_buffer=True,
        save_vecnormalize=True,
    )

    model = PPO("MlpPolicy", envs, verbose=1, batch_size=64, learning_rate=2e-4)
    model.learn(total_timesteps=100_000_000, callback=checkpoint_callback)
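
# ---------------------------------------------------------------------------
# Usage sketch: loading a checkpoint back for evaluation, as a separate
# script. This assumes the default CheckpointCallback naming convention
# (rl_model_<steps>_steps.zip / rl_model_vecnormalize_<steps>_steps.pkl) and
# a hypothetical checkpoint at 80_000 steps; adjust the paths to whatever
# files your run actually produced. Reuses make_env from the script above.
# ---------------------------------------------------------------------------
# from stable_baselines3 import PPO
# from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
#
# eval_env = DummyVecEnv([make_env(0)])  # single worker, rank 0 is visualizable
# # Restore the normalization statistics saved alongside the model.
# eval_env = VecNormalize.load("./logs/rl_model_vecnormalize_80000_steps.pkl", eval_env)
# eval_env.training = False      # freeze the running mean/std
# eval_env.norm_reward = False   # report raw, unnormalized rewards
#
# model = PPO.load("./logs/rl_model_80000_steps.zip", env=eval_env)
#
# obs = eval_env.reset()
# for _ in range(400):  # one TimeLimit-length episode
#     action, _ = model.predict(obs, deterministic=True)
#     obs, reward, done, info = eval_env.step(action)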