Skip to content

Instantly share code, notes, and snippets.

@sgillen
Created September 16, 2021 02:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sgillen/2d85c8deb9966d52ff0dfe037f9ab583 to your computer and use it in GitHub Desktop.
Save sgillen/2d85c8deb9966d52ff0dfe037f9ab583 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from datetime import datetime\n",
"import functools\n",
"import os\n",
"\n",
"from IPython.display import HTML, clear_output\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import brax\n",
"\n",
"from brax import envs\n",
"from brax.training import ppo, sac\n",
"from brax.io import html\n",
"\n",
"import gym\n",
"import numpy as np\n",
"\n",
"# Environment identifiers used below: the Gym id and the Brax env name\n",
"mj_env_name = \"Ant-v2\"\n",
"bx_env_name = \"ant\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"running build_ext\n"
]
}
],
"source": [
"# Define the Brax env and the RNG key it needs\n",
"rng = jax.random.PRNGKey(seed=0)\n",
"bx_env_fn = envs.create_fn(env_name=bx_env_name)\n",
"bx_env = bx_env_fn()\n",
"\n",
"# Make the standard MuJoCo environment from the configured id instead of a\n",
"# duplicated string literal, so that changing mj_env_name actually takes effect\n",
"mj_env = gym.make(mj_env_name)\n",
"\n",
"# Length of the action vector, taken from the Brax env\n",
"act_size = bx_env.action_size"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Start a fresh episode in each simulator and keep the first observation of both\n",
"mj_obs = mj_env.reset()\n",
"\n",
"bx_state = bx_env.reset(rng)\n",
"bx_obs = bx_state.obs.to_py()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(87,)\n",
"(111,)\n"
]
}
],
"source": [
"# Observation shapes, Brax first and MuJoCo second (one per line)\n",
"print(bx_obs.shape, mj_obs.shape, sep=\"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 floor\n",
"1 torso_geom\n",
"2 aux_1_geom\n",
"3 left_leg_geom\n",
"4 left_ankle_geom\n",
"5 aux_2_geom\n",
"6 right_leg_geom\n",
"7 right_ankle_geom\n",
"8 aux_3_geom\n",
"9 back_leg_geom\n",
"10 third_ankle_geom\n",
"11 aux_4_geom\n",
"12 rightback_leg_geom\n",
"13 fourth_ankle_geom\n"
]
}
],
"source": [
"# ^^ What's the difference? MuJoCo includes 4 extra bodies/geoms, each with a 6-vector\n",
"# containing the contact force and moment on the body's COM.\n",
"# Note that \"bodies\" and \"geoms\" are distinct in MuJoCo: basically bodies carry inertial\n",
"# information, geoms carry contact information.\n",
"# See: http://mujoco.org/book/index.html#BodyGeomSite\n",
"# Now, let's print all the geom names out\n",
"\n",
"mj_model = mj_env.sim.model\n",
"for geom_id in range(mj_model.ngeom):\n",
"    print(geom_id, mj_model.geom_id2name(geom_id))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 $ Torso\n",
"1 Aux 1\n",
"2 $ Body 4\n",
"3 Aux 2\n",
"4 $ Body 7\n",
"5 Aux 3\n",
"6 $ Body 10\n",
"7 Aux 4\n",
"8 $ Body 13\n",
"9 Ground\n"
]
}
],
"source": [
"# Great, now let's see which bodies Brax is working with\n",
"bx_body_names = list(bx_env.env.env.sys.body_idx.keys())\n",
"for idx, name in enumerate(bx_body_names):\n",
"    print(idx, name)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Ok, with this in mind let's see where the sims match up\n",
"\n",
"# Index at which the position/velocity part of both observation vectors ends;\n",
"# the helpers below treat everything from this index on as contact data\n",
"posvel_end = 27\n",
"\n",
"\n",
"# Compare the position and velocity\n",
"def print_ant_diff_state(mj_obs, bx_obs):\n",
"    \"\"\"Print the position/velocity entries of both observations side by side.\n",
"\n",
"    Args:\n",
"        mj_obs: MuJoCo observation vector; only its first ``posvel_end`` entries are read.\n",
"        bx_obs: Brax observation vector; only its first ``posvel_end`` entries are read.\n",
"\n",
"    Prints a header and then one row per index with the MuJoCo value, the Brax\n",
"    value and their difference (MuJoCo minus Brax). Returns nothing.\n",
"    \"\"\"\n",
"    # Slice once and iterate over the slices (previously they were computed but never used)\n",
"    mj_posvel = mj_obs[:posvel_end]\n",
"    bx_posvel = bx_obs[:posvel_end]\n",
"\n",
"    print(\"Position and veclocity ======================================\")\n",
"    for i, (mj_val, bx_val) in enumerate(zip(mj_posvel, bx_posvel)):\n",
"        print(f\"{i:2} mj = {mj_val:7.4f} bx = {bx_val:7.4f} diff = {mj_val - bx_val:7.4f}\")\n",
" \n",
"# Compares contacts\n",
"def print_ant_diff_contact(mj_obs, bx_obs):\n",
"    \"\"\"Print the contact part of both observations, one 6-vector per row.\n",
"\n",
"    Args:\n",
"        mj_obs: MuJoCo observation vector; entries from ``posvel_end`` on are read.\n",
"        bx_obs: Brax observation vector; entries from ``posvel_end`` on are read.\n",
"\n",
"    Uses the notebook-level ``mj_env`` and ``bx_env`` to label the rows.\n",
"    Returns nothing.\n",
"    \"\"\"\n",
"    mj_model = mj_env.sim.model\n",
"    print(\"Mujoco Contact Forces =======================================\")\n",
"    # NOTE(review): rows are labelled with geom names and assume one 6-vector per geom,\n",
"    # in geom-id order -- confirm against the Ant-v2 observation code\n",
"    for geom_id in range(mj_model.ngeom):\n",
"        geom_name = mj_model.geom_id2name(geom_id)\n",
"        start = posvel_end + 6 * geom_id\n",
"        mj_contact = mj_obs[start:start + 6]\n",
"        print(f\"geom {geom_name}: {mj_contact}\")\n",
"\n",
"    print(\"Brax Contact Forces =========================================\")\n",
"    # Slicing the Brax tail as six consecutive values per body mislabels the rows: with that\n",
"    # slicing only the first three slots of a row were ever non-zero. That is consistent with\n",
"    # Brax storing all per-body linear 3-vectors first, followed by all per-body angular\n",
"    # 3-vectors -- NOTE(review): confirm against the Brax ant env's observation code.\n",
"    bx_body_names = list(bx_env.env.env.sys.body_idx.keys())\n",
"    n_bodies = len(bx_body_names)\n",
"    bx_tail = np.asarray(bx_obs[posvel_end:posvel_end + 6 * n_bodies])\n",
"    bx_linear = bx_tail[:3 * n_bodies].reshape(n_bodies, 3)\n",
"    bx_angular = bx_tail[3 * n_bodies:].reshape(n_bodies, 3)\n",
"    for body_name, linear, angular in zip(bx_body_names, bx_linear, bx_angular):\n",
"        # Re-assemble one 6-vector per body so rows are comparable to the MuJoCo ones\n",
"        bx_contact = np.concatenate([linear, angular])\n",
"        print(f\"geom id {body_name} : {bx_contact}\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Position and veclocity ======================================\n",
" 0 mj = 0.7500 bx = 0.5133 diff = 0.2367\n",
" 1 mj = 1.0000 bx = 1.0000 diff = 0.0000\n",
" 2 mj = 0.0000 bx = 0.0000 diff = 0.0000\n",
" 3 mj = 0.0000 bx = 0.0000 diff = 0.0000\n",
" 4 mj = 0.0000 bx = 0.0000 diff = 0.0000\n",
" 5 mj = 0.0000 bx = 0.0000 diff = 0.0000\n",
" 6 mj = 0.0000 bx = 0.8727 diff = -0.8727\n",
" 7 mj = 0.0000 bx = 0.0000 diff = 0.0000\n",
" 8 mj = 0.0000 bx = -0.8727 diff = 0.8727\n",
" 9 mj = 0.0000 bx = 0.0000 diff = 0.0000\n",
"10 mj = 0.0000 bx = -0.8727 diff = 0.8727\n",
"11 mj = 0.0000 bx = 0.0000 diff = 0.0000\n",
"12 mj = 0.0000 bx = 0.8727 diff = -0.8727\n",
"13 mj = 0.0000 bx = 0.0000 diff = 0.0000\n",
"14 mj = 0.0000 bx = 0.0000 diff = 0.0000\n",
"15 mj = 0.0000 bx = 0.0000 diff = 0.0000\n",
"16 mj = 0.0000 bx = 0.0000 diff = 0.0000\n",
"17 mj = 0.0000 bx = 0.0000 diff = 0.0000\n",
"18 mj = 0.0000 bx = 0.0000 diff = 0.0000\n",
"19 mj = 0.0000 bx = 0.0000 diff = 0.0000\n",
"20 mj = 0.0000 bx = 0.0000 diff = 0.0000\n",
"21 mj = 0.0000 bx = 0.0000 diff = 0.0000\n",
"22 mj = 0.0000 bx = 0.0000 diff = 0.0000\n",
"23 mj = 0.0000 bx = 0.0000 diff = 0.0000\n",
"24 mj = 0.0000 bx = 0.0000 diff = 0.0000\n",
"25 mj = 0.0000 bx = 0.0000 diff = 0.0000\n",
"26 mj = 0.0000 bx = 0.0000 diff = 0.0000\n",
"Mujoco Contact Forces =======================================\n",
"geom floor: [0. 0. 0. 0. 0. 0.]\n",
"geom torso_geom: [0. 0. 0. 0. 0. 0.]\n",
"geom aux_1_geom: [0. 0. 0. 0. 0. 0.]\n",
"geom left_leg_geom: [0. 0. 0. 0. 0. 0.]\n",
"geom left_ankle_geom: [0. 0. 0. 0. 0. 0.]\n",
"geom aux_2_geom: [0. 0. 0. 0. 0. 0.]\n",
"geom right_leg_geom: [0. 0. 0. 0. 0. 0.]\n",
"geom right_ankle_geom: [0. 0. 0. 0. 0. 0.]\n",
"geom aux_3_geom: [0. 0. 0. 0. 0. 0.]\n",
"geom back_leg_geom: [0. 0. 0. 0. 0. 0.]\n",
"geom third_ankle_geom: [0. 0. 0. 0. 0. 0.]\n",
"geom aux_4_geom: [0. 0. 0. 0. 0. 0.]\n",
"geom rightback_leg_geom: [0. 0. 0. 0. 0. 0.]\n",
"geom fourth_ankle_geom: [0. 0. 0. 0. 0. 0.]\n",
"Brax Contact Forces =========================================\n",
"geom id $ Torso : [0. 0. 0. 0. 0. 0.]\n",
"geom id Aux 1 : [0. 0. 0. 0. 0. 0.]\n",
"geom id $ Body 4 : [0. 0. 0. 0. 0. 0.]\n",
"geom id Aux 2 : [0. 0. 0. 0. 0. 0.]\n",
"geom id $ Body 7 : [0. 0. 0. 0. 0. 0.]\n",
"geom id Aux 3 : [0. 0. 0. 0. 0. 0.]\n",
"geom id $ Body 10 : [0. 0. 0. 0. 0. 0.]\n",
"geom id Aux 4 : [0. 0. 0. 0. 0. 0.]\n",
"geom id $ Body 13 : [0. 0. 0. 0. 0. 0.]\n",
"geom id Ground : [0. 0. 0. 0. 0. 0.]\n"
]
}
],
"source": [
"# Comparing right after a reset, it looks like the mj ant starts suspended about .25m in the air, and there are some small differences among the other states.\n",
"\n",
"# Let's ignore the initial noise for now\n",
"def mj_reset_no_noise(env):\n",
"    \"\"\"Reset ``env`` and then overwrite its state with its stored initial state.\n",
"\n",
"    Args:\n",
"        env: an environment exposing ``reset``, ``set_state``, ``init_qpos``,\n",
"            ``init_qvel`` and ``_get_obs`` (the notebook passes ``mj_env.unwrapped``).\n",
"\n",
"    Returns:\n",
"        The observation from ``env._get_obs()`` after the state was overwritten.\n",
"    \"\"\"\n",
"    env.reset()\n",
"    # Read the initial state from the env that was passed in, not from the notebook-global mj_env\n",
"    env.set_state(env.init_qpos, env.init_qvel)\n",
"    return env._get_obs()\n",
"\n",
"mj_obs = mj_reset_no_noise(mj_env.unwrapped)\n",
"\n",
"bx_state = bx_env.reset(rng)\n",
"# Refresh the Brax observation from the state just reset; otherwise the comparison\n",
"# below would silently use whatever bx_obs an earlier cell left behind\n",
"bx_obs = bx_state.obs.to_py()\n",
"\n",
"print_ant_diff_state(mj_obs, bx_obs)\n",
"print_ant_diff_contact(mj_obs, bx_obs)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Let both sims run for a bit, feeding them the same action at every step\n",
"for step in range(100):\n",
"    # Zero actions by default; use np.random.random(act_size) here for random actions instead\n",
"    act_np = np.zeros(act_size)\n",
"    act_jp = jnp.array(act_np)\n",
"\n",
"    # Only the observation (first element of the step result) is kept from MuJoCo\n",
"    mj_obs = mj_env.step(act_np)[0]\n",
"\n",
"    bx_state = bx_env.step(bx_state, act_jp)\n",
"    bx_obs = bx_state.obs.to_py()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Position and veclocity ======================================\n",
" 0 mj = 0.5237 bx = 0.4848 diff = 0.0390\n",
" 1 mj = 1.0000 bx = 1.0000 diff = 0.0000\n",
" 2 mj = 0.0000 bx = 0.0000 diff = -0.0000\n",
" 3 mj = 0.0000 bx = 0.0000 diff = -0.0000\n",
" 4 mj = 0.0000 bx = 0.0000 diff = -0.0000\n",
" 5 mj = -0.0000 bx = -0.0000 diff = 0.0000\n",
" 6 mj = 0.8746 bx = 0.8305 diff = 0.0441\n",
" 7 mj = 0.0000 bx = 0.0000 diff = -0.0000\n",
" 8 mj = -0.8746 bx = -0.8305 diff = -0.0441\n",
" 9 mj = 0.0000 bx = -0.0000 diff = 0.0000\n",
"10 mj = -0.8746 bx = -0.8305 diff = -0.0441\n",
"11 mj = -0.0000 bx = 0.0000 diff = -0.0000\n",
"12 mj = 0.8746 bx = 0.8305 diff = 0.0441\n",
"13 mj = 0.0000 bx = -0.0000 diff = 0.0000\n",
"14 mj = 0.0000 bx = -0.0000 diff = 0.0000\n",
"15 mj = -0.0085 bx = -0.0005 diff = -0.0080\n",
"16 mj = -0.0000 bx = -0.0000 diff = 0.0000\n",
"17 mj = 0.0000 bx = -0.0000 diff = 0.0000\n",
"18 mj = -0.0000 bx = 0.0000 diff = -0.0000\n",
"19 mj = -0.0000 bx = 0.0000 diff = -0.0000\n",
"20 mj = -0.0235 bx = 0.0010 diff = -0.0244\n",
"21 mj = 0.0000 bx = 0.0000 diff = -0.0000\n",
"22 mj = 0.0235 bx = -0.0010 diff = 0.0244\n",
"23 mj = 0.0000 bx = -0.0000 diff = 0.0000\n",
"24 mj = 0.0235 bx = -0.0010 diff = 0.0244\n",
"25 mj = -0.0000 bx = -0.0000 diff = 0.0000\n",
"26 mj = -0.0235 bx = 0.0010 diff = -0.0244\n"
]
}
],
"source": [
"# After settling for a bit the states are not *that* bad for a lot of entries,\n",
"# given the differences in mass / inertia, contact, friction, initial position, etc.\n",
"print_ant_diff_state(mj_obs, bx_obs)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mujoco Contact Forces =======================================\n",
"geom floor: [0. 0. 0. 0. 0. 0.]\n",
"geom torso_geom: [0. 0. 0. 0. 0. 0.]\n",
"geom aux_1_geom: [0. 0. 0. 0. 0. 0.]\n",
"geom left_leg_geom: [0. 0. 0. 0. 0. 0.]\n",
"geom left_ankle_geom: [ 1.00000000e+00 -1.00000000e+00 -6.55031585e-15 -8.71563128e-01\n",
" -8.71563128e-01 1.00000000e+00]\n",
"geom aux_2_geom: [0. 0. 0. 0. 0. 0.]\n",
"geom right_leg_geom: [0. 0. 0. 0. 0. 0.]\n",
"geom right_ankle_geom: [ 1.00000000e+00 1.00000000e+00 -1.32116540e-14 8.71563128e-01\n",
" -8.71563128e-01 1.00000000e+00]\n",
"geom aux_3_geom: [0. 0. 0. 0. 0. 0.]\n",
"geom back_leg_geom: [0. 0. 0. 0. 0. 0.]\n",
"geom third_ankle_geom: [-1.00000000e+00 1.00000000e+00 6.77236045e-15 8.71563128e-01\n",
" 8.71563128e-01 1.00000000e+00]\n",
"geom aux_4_geom: [0. 0. 0. 0. 0. 0.]\n",
"geom rightback_leg_geom: [0. 0. 0. 0. 0. 0.]\n",
"geom fourth_ankle_geom: [-1.00000000e+00 -1.00000000e+00 1.32116540e-14 -8.71563128e-01\n",
" 8.71563128e-01 1.00000000e+00]\n",
"Brax Contact Forces =========================================\n",
"geom id $ Torso : [0. 0. 0. 0. 0. 0.]\n",
"geom id Aux 1 : [-0.21253121 -0.21253204 0.44075885 0. 0. 0. ]\n",
"geom id $ Body 4 : [ 0.21252969 -0.21253003 0.4407489 0. 0. 0. ]\n",
"geom id Aux 2 : [0.2125305 0.21253125 0.44075665 0. 0. 0. ]\n",
"geom id $ Body 7 : [-0.21253087 0.21253133 0.44075722 0. 0. 0. ]\n",
"geom id Aux 3 : [0. 0. 0. 0. 0. 0.]\n",
"geom id $ Body 10 : [-1.7521014e-03 1.7518904e-03 -9.8720193e-08 0.0000000e+00\n",
" 0.0000000e+00 0.0000000e+00]\n",
"geom id Aux 4 : [-1.7528944e-03 -1.7527896e-03 4.8894435e-08 0.0000000e+00\n",
" 0.0000000e+00 0.0000000e+00]\n",
"geom id $ Body 13 : [ 1.7522224e-03 -1.7520105e-03 -9.8953024e-08 0.0000000e+00\n",
" 0.0000000e+00 0.0000000e+00]\n",
"geom id Ground : [1.7521055e-03 1.7520073e-03 4.5634806e-08 0.0000000e+00 0.0000000e+00\n",
" 0.0000000e+00]\n"
]
}
],
"source": [
"# But the contacts are not close. It looks like the contact entries in the Brax observation\n",
"# do not follow a simple six-values-per-body layout in bx_env.env.env.sys.body_idx order,\n",
"# and the magnitudes and moments are not very close to MuJoCo's either.\n",
"\n",
"print_ant_diff_contact(mj_obs, bx_obs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Brax",
"language": "python",
"name": "brax"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment