Created November 24, 2024 17:58
-
-
Save araffin/534f7f1506364eb824c2f4d6a2dd81d1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Train a PPO agent on a DeepMind Control Suite task through the Gymnasium API.
#
# The dm_control environment is adapted to Gymnasium with Shimmy, its
# observation is flattened into a single vector so an ``MlpPolicy`` can consume
# it, the wrapped env is validated with SB3's ``check_env``, and PPO is trained.
import sbx
import shimmy
import stable_baselines3 as sb3
from dm_control import suite
from gymnasium.wrappers import FlattenObservation
from stable_baselines3.common.env_checker import check_env

# Task selection. Available envs can be listed with:
#   suite._DOMAINS     (all domains)
#   suite.dog.SUITE    (tasks of a single domain, here "dog")
DOMAIN_NAME = "dog"
TASK_NAME = "run"
TOTAL_TIMESTEPS = 10_000

env = suite.load(domain_name=DOMAIN_NAME, task_name=TASK_NAME)
# Shimmy exposes the dm_control env through the Gymnasium interface;
# FlattenObservation collapses its structured observation into one flat vector.
gym_env = FlattenObservation(shimmy.DmControlCompatibilityV0(env))
# Fail fast on any API/space mismatch before spending time on training.
check_env(gym_env)

# To train with the JAX-based SBX implementation instead, replace ``sb3.PPO``
# with ``sbx.PPO`` below (``sbx`` is imported for that purpose).
model = sb3.PPO("MlpPolicy", gym_env, verbose=1).learn(
    TOTAL_TIMESTEPS, progress_bar=True
)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.