Skip to content

Instantly share code, notes, and snippets.

@paulgribble
Last active November 7, 2023 14:18
Show Gist options
  • Save paulgribble/0822a8acc7bda9dbc4f7c2d5453d275c to your computer and use it in GitHub Desktop.
Save paulgribble/0822a8acc7bda9dbc4f7c2d5453d275c to your computer and use it in GitHub Desktop.
motornet tests
import os, time, sys, json
import numpy as np
import torch as th
import matplotlib.pyplot as plt
import motornet as mn
class Policy(th.nn.Module):
    """Single-layer GRU followed by a sigmoid-squashed linear readout.

    Args:
        input_dim: size of each observation vector fed to the GRU.
        hidden_dim: number of GRU hidden units.
        output_dim: size of the readout vector.
        device: torch device the module (and its hidden state) lives on.

    Raises:
        ValueError: if the module exposes a parameter that has no
            initialisation rule below.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, device):
        super().__init__()
        self.device = device
        self.hidden_dim = hidden_dim
        self.n_layers = 1

        self.gru = th.nn.GRU(input_dim, hidden_dim, self.n_layers, batch_first=True)
        self.fc = th.nn.Linear(hidden_dim, output_dim)
        self.sigmoid = th.nn.Sigmoid()

        # the default initialization in torch isn't ideal
        for name, param in self.named_parameters():
            if name in ("gru.weight_ih_l0", "fc.weight"):
                th.nn.init.xavier_uniform_(param)
            elif name == "gru.weight_hh_l0":
                th.nn.init.orthogonal_(param)
            elif name in ("gru.bias_ih_l0", "gru.bias_hh_l0"):
                th.nn.init.zeros_(param)
            elif name == "fc.bias":
                # the readout goes through a sigmoid, so a strongly negative
                # bias pushes the initial outputs toward 0
                th.nn.init.constant_(param, -5.0)
            else:
                raise ValueError(f"no initialisation rule for parameter {name!r}")

        self.to(device)

    def forward(self, x, h0):
        """Advance the network by one step.

        Args:
            x: observations; a length-1 sequence axis is inserted at
                position 1 before the (batch_first) GRU is called.
            h0: hidden state, e.g. as produced by ``init_hidden``.

        Returns:
            Tuple ``(u, h)``: the sigmoid readout with the sequence axis
            squeezed out again, and the updated hidden state.
        """
        y, h = self.gru(x[:, None, :], h0)
        u = self.sigmoid(self.fc(y)).squeeze(dim=1)
        return u, h

    def init_hidden(self, batch_size):
        """Return a zero hidden state of shape (n_layers, batch_size, hidden_dim).

        The tensor uses the dtype of the module's parameters and is placed
        on ``self.device``.
        """
        reference = next(self.parameters())
        return th.zeros(
            self.n_layers,
            batch_size,
            self.hidden_dim,
            dtype=reference.dtype,
            device=self.device,
        )


def l1(x, y):
    """L1 loss: absolute error summed over the last axis, then averaged."""
    return th.mean(th.sum(th.abs(x - y), dim=-1))


def main_fcn(USE_DUMMY_FUN):
    """Train a GRU policy for 50 batches and print timing information.

    A ``RigidTendonArm26`` effector with a ``RigidTendonHillMuscle`` is
    wrapped in a ``RandomTargetReach`` environment. A ``Policy`` sized from
    that environment is then optimised with Adam on an L1 loss between the
    recorded trajectory and the recorded targets.

    Args:
        USE_DUMMY_FUN: when truthy, the environment is not stepped; a plain
            ``torch.nn.Linear`` layer stands in for the plant and the
            observations/targets are random tensors. When falsy, the real
            environment is reset and stepped.

    Returns:
        None. Progress and elapsed time are printed to stdout.
    """
    # Use the argument itself rather than a module-level global, so the
    # function also works when it is imported and called directly.
    use_dummy_env = bool(USE_DUMMY_FUN)

    print(f"\nUSE_DUMMY_FUN = {USE_DUMMY_FUN}\n")

    print('All packages imported.')
    print('pytorch version: ' + th.__version__)
    print('numpy version: ' + np.__version__)
    print('motornet version: ' + mn.__version__)

    effector = mn.effector.RigidTendonArm26(muscle=mn.muscle.RigidTendonHillMuscle())
    env = mn.environment.RandomTargetReach(effector=effector, max_ep_duration=1.)

    obs_dim = env.observation_space.shape[0]
    n_muscles = env.n_muscles

    device = th.device("cpu")
    policy = Policy(obs_dim, 128, n_muscles, device=device)
    optimizer = th.optim.Adam(policy.parameters(), lr=10**-3)

    loss_fcn = l1

    batch_size = 512
    n_batch = 50
    interval = 10
    # NOTE(review): 100 steps presumably spans max_ep_duration=1. at the
    # environment's timestep -- confirm against the environment settings.
    n_steps = 100

    t0 = time.time()
    losses = []

    # imagine the plant was a simple linear layer (actions in, xy position out);
    # its input size has to equal the policy's output size
    dummy_env = th.nn.Linear(n_muscles, 2)

    for batch in range(n_batch):
        # initialize batch
        h = policy.init_hidden(batch_size=batch_size)

        # initial environment
        if use_dummy_env:
            # random observations; the width has to equal the policy's input size
            obs = th.rand(batch_size, obs_dim, dtype=th.float32)
            xy = [th.rand(batch_size, 1, 2)]
            tg = [th.rand(batch_size, 1, 2)]
        else:
            obs, info = env.reset(options={"batch_size": batch_size})
            xy = [info["states"]["fingertip"][:, None, :]]
            tg = [info["goal"][:, None, :]]

        # simulate a whole episode as a fixed number of steps
        for _ in range(n_steps):
            # run the model to get commands to environment
            action, h = policy(obs, h)

            # apply those commands to environment
            if use_dummy_env:
                # obs is intentionally left unchanged in this branch
                env_out = dummy_env(action)
                xy.append(env_out[:, None, :])  # trajectories
                tg.append(th.rand(batch_size, 1, 2))  # targets
            else:
                obs, reward, terminated, truncated, info = env.step(action=action)
                xy.append(info["states"]["fingertip"][:, None, :])  # trajectories
                tg.append(info["goal"][:, None, :])  # targets

        # concatenate into a (batch_size, n_timesteps, xy) tensor
        xy = th.cat(xy, dim=1)
        tg = th.cat(tg, dim=1)
        loss = loss_fcn(xy, tg)  # L1 loss on position

        # backward pass & update weights
        optimizer.zero_grad()
        loss.backward()
        th.nn.utils.clip_grad_norm_(policy.parameters(), max_norm=1.)  # important!
        optimizer.step()
        losses.append(loss.item())

        if (batch % interval == 0) and (batch != 0):
            elapsed = round(time.time() - t0, 1)
            mean_loss = round(sum(losses[-interval:]) / interval, 8)
            print(
                f"elapsed time: {elapsed} s; Batch {batch}/{n_batch} Done, "
                f"mean policy loss: {mean_loss}"
            )

    print(f"elapsed time: {time.time()-t0:.3f} sec")
def parse_dummy_flag(arg):
    """Interpret a command-line argument as the dummy-plant on/off switch.

    Explicit negative spellings ("0", "false", "no", "off", "n", "f", or an
    empty string; case and surrounding whitespace ignored) give False.
    Any other value gives True, so arguments that enabled the dummy plant
    before still do.

    Args:
        arg: the raw argument, normally ``sys.argv[1]``.

    Returns:
        bool: whether the dummy plant should be used.
    """
    negative = {"", "0", "false", "no", "off", "n", "f"}
    return str(arg).strip().lower() not in negative


if __name__ == "__main__":
    # Without an argument the real environment is used. With one, the text
    # is parsed into a real boolean: a bare string such as "0" or "False"
    # would otherwise count as truthy.
    # The module-level name USE_DUMMY_ENV is kept for backward compatibility.
    if len(sys.argv) < 2:
        USE_DUMMY_ENV = False
    else:
        USE_DUMMY_ENV = parse_dummy_flag(sys.argv[1])
    main_fcn(USE_DUMMY_ENV)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment