Skip to content

Instantly share code, notes, and snippets.

@araffin
Last active September 10, 2024 12:48
Show Gist options
Save araffin/1fb77a8f290ac248b2e76e01164f21e0 to your computer and use it in GitHub Desktop.
Minimal implementation to solve the HalfCheetah env using open-loop oscillators
import gymnasium as gym
import numpy as np
from gymnasium.envs.mujoco.mujoco_env import MujocoEnv
# Number of actuated joints. The PD controller reads the last N_JOINTS
# entries of qpos/qvel as the joint positions/velocities.
N_JOINTS = 6
# PD controller gains
KP, KD = 1.0, 0.05
# Oscillator parameters, one entry per joint. Frequencies are in rad/s.
OMEGA_STANCE = 2 * np.pi * 4.622 * np.ones(N_JOINTS)  # used while sin(theta) <= 0
OMEGA_SWING = 2 * np.pi * 3.865 * np.ones(N_JOINTS)  # used while sin(theta) > 0
PHASE_SHIFTS = 2 * np.pi * np.array([0.00, 0.789, 0.316, 0.294, 0.629, 0.921])
AMPLITUDES = np.array([1.123, -1.91, -1.204, 1.173, 1.196, -0.085])
OFFSETS = np.array([-0.114, 0.075, 0.002, -0.493, -0.501, -0.227])
OSCILLATOR_DT = 0.001  # 1 kHz, target integration step


def integrate_oscillators(
    theta: np.ndarray,
    duration: float,
    *,
    omega_swing: np.ndarray,
    omega_stance: np.ndarray,
    dt: float = OSCILLATOR_DT,
) -> np.ndarray:
    """Advance oscillator phases by ``duration`` seconds using explicit Euler steps.

    A phase advances at ``omega_swing`` while ``sin(theta) > 0`` (swing) and
    at ``omega_stance`` otherwise (stance).

    Args:
        theta: Current phases in radians (not modified).
        duration: Time to integrate over, in seconds.
        omega_swing: Angular frequencies (rad/s) used during swing.
        omega_stance: Angular frequencies (rad/s) used during stance.
        dt: Target integration step, in seconds.

    Returns:
        The new phases, wrapped into ``[0, 2*pi)``.
    """
    # round() instead of int(): int() truncates, so a ratio landing just below
    # an integer (e.g. 0.3 / 0.1 == 2.9999999999999996) would drop a step.
    # The step is rescaled so exactly ``duration`` seconds are integrated,
    # keeping oscillator time in sync with simulation time. At least one step
    # is taken even when duration < dt.
    n_steps = max(1, round(duration / dt))
    step = duration / n_steps
    theta = np.asarray(theta, dtype=float)
    for _ in range(n_steps):
        in_swing_phase = np.sin(theta) > 0
        theta_dot = np.where(in_swing_phase, omega_swing, omega_stance)
        # Integrate and keep theta in [0, 2 * pi)
        theta = (theta + step * theta_dot) % (2 * np.pi)
    return theta


def pd_torques(
    desired_qpos: np.ndarray,
    qpos: np.ndarray,
    qvel: np.ndarray,
    *,
    kp: float = KP,
    kd: float = KD,
    low=-1.0,
    high=1.0,
) -> np.ndarray:
    """Return PD torques tracking ``desired_qpos`` with zero desired velocity.

    Computes ``kp * (desired_qpos - qpos) - kd * qvel`` and clips it
    element-wise to ``[low, high]``. The bounds may be scalars or per-joint
    arrays, e.g. the action-space bounds.
    """
    torques = kp * (desired_qpos - qpos) - kd * qvel
    return np.clip(torques, low, high)


def main(seed: int = 0) -> None:
    """Drive HalfCheetah-v4 with the open-loop oscillators until interrupted.

    Prints the return of every finished episode. Stop with Ctrl+C; the
    environment is closed on exit.
    """
    # In "human" mode the environment renders during env.step(), so an extra
    # env.render() call per step would only draw each frame twice.
    env = gym.make("HalfCheetah-v4", render_mode="human")
    # Wrap to have reward statistics
    env = gym.wrappers.RecordEpisodeStatistics(env)
    try:
        mujoco_env = env.unwrapped
        assert isinstance(mujoco_env, MujocoEnv)
        # Clip to the environment's own action bounds, not a hard-coded +/-1.
        action_low, action_high = env.action_space.low, env.action_space.high

        env.reset(seed=seed)
        # Initial oscillator phases
        theta = PHASE_SHIFTS.copy()
        while True:
            # Integrate the oscillators over one environment control step
            theta = integrate_oscillators(
                theta,
                mujoco_env.dt,
                omega_swing=OMEGA_SWING,
                omega_stance=OMEGA_STANCE,
            )
            # Open-loop control: the oscillators give the desired joint angles
            desired_qpos = AMPLITUDES * np.sin(theta) + OFFSETS
            action = pd_torques(
                desired_qpos,
                mujoco_env.data.qpos[-N_JOINTS:],
                mujoco_env.data.qvel[-N_JOINTS:],
                low=action_low,
                high=action_high,
            )
            _, _, terminated, truncated, info = env.step(action)
            if terminated or truncated:
                # "r" may be a 1-element array (older Gymnasium) or a scalar;
                # np.asarray(...).item() handles both.
                episode_return = np.asarray(info["episode"]["r"]).item()
                print(f"Episode return: {float(episode_return):.2f}")
                env.reset()
                # Restart the oscillators from their initial phases
                theta = PHASE_SHIFTS.copy()
    except KeyboardInterrupt:
        pass  # Ctrl+C is the intended way to stop this demo
    finally:
        env.close()


if __name__ == "__main__":
    main()
@araffin
Copy link
Author

araffin commented Mar 27, 2024

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment