Last active
November 7, 2023 14:18
-
-
Save paulgribble/0822a8acc7bda9dbc4f7c2d5453d275c to your computer and use it in GitHub Desktop.
motornet tests
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os, time, sys, json | |
import numpy as np | |
import torch as th | |
import matplotlib.pyplot as plt | |
import motornet as mn | |
class Policy(th.nn.Module):
    """Single-layer GRU followed by a linear readout squashed by a sigmoid.

    Args:
        input_dim: number of features fed to the GRU at each step.
        hidden_dim: number of GRU hidden units.
        output_dim: number of values produced by the readout at each step.
        device: torch device the module is moved to after initialization.

    Raises:
        ValueError: if the module owns a parameter that has no explicit
            initialization rule below.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, device):
        super().__init__()
        self.device = device
        self.hidden_dim = hidden_dim
        self.n_layers = 1

        self.gru = th.nn.GRU(input_dim, hidden_dim, self.n_layers, batch_first=True)
        self.fc = th.nn.Linear(hidden_dim, output_dim)
        self.sigmoid = th.nn.Sigmoid()

        # the default initialization in torch isn't ideal
        for name, param in self.named_parameters():
            if name == "gru.weight_ih_l0":
                th.nn.init.xavier_uniform_(param)
            elif name == "gru.weight_hh_l0":
                th.nn.init.orthogonal_(param)
            elif name == "gru.bias_ih_l0":
                th.nn.init.zeros_(param)
            elif name == "gru.bias_hh_l0":
                th.nn.init.zeros_(param)
            elif name == "fc.weight":
                th.nn.init.xavier_uniform_(param)
            elif name == "fc.bias":
                # strongly negative bias => sigmoid outputs start close to 0
                th.nn.init.constant_(param, -5.)
            else:
                raise ValueError(f"no initialization rule for parameter '{name}'")

        self.to(device)

    def forward(self, x, h0):
        """Advance the network by one step.

        Args:
            x: 2-D input; a length-1 sequence axis is inserted at position 1
                before it is passed to the GRU.
            h0: GRU hidden state to start from.

        Returns:
            Tuple ``(u, h)``: the sigmoid-squashed readout with the length-1
            sequence axis removed, and the updated hidden state.
        """
        y, h = self.gru(x[:, None, :], h0)
        u = self.sigmoid(self.fc(y)).squeeze(dim=1)
        return u, h

    def init_hidden(self, batch_size):
        """Return an all-zero hidden state of shape (n_layers, batch_size, hidden_dim)."""
        reference = next(self.parameters())
        return th.zeros(
            self.n_layers,
            batch_size,
            self.hidden_dim,
            dtype=reference.dtype,
            device=self.device,
        )


def l1(x, y):
    """L1 loss: absolute error summed over the last axis, then averaged."""
    return th.mean(th.sum(th.abs(x - y), dim=-1))


def main_fcn(USE_DUMMY_FUN):
    """Train a GRU policy for a few batches and print timing / loss information.

    Args:
        USE_DUMMY_FUN: truthy to replace the motornet environment with random
            observations/targets and a plain linear layer standing in for the
            plant; falsy to step the real ``RandomTargetReach`` environment.

    Returns:
        None. Progress is reported through ``print``.
    """
    # The flag is read from the parameter. The body previously referenced an
    # undeclared name (`USE_DUMMY_ENV`) and therefore only worked when a global
    # of that name happened to exist.
    use_dummy_env = USE_DUMMY_FUN

    print(f"\nUSE_DUMMY_FUN = {USE_DUMMY_FUN}\n")

    print('All packages imported.')
    print('pytorch version: ' + th.__version__)
    print('numpy version: ' + np.__version__)
    print('motornet version: ' + mn.__version__)

    effector = mn.effector.RigidTendonArm26(muscle=mn.muscle.RigidTendonHillMuscle())
    env = mn.environment.RandomTargetReach(effector=effector, max_ep_duration=1.)

    device = th.device("cpu")
    n_obs = env.observation_space.shape[0]
    n_muscles = env.n_muscles

    policy = Policy(n_obs, 128, n_muscles, device=device)
    optimizer = th.optim.Adam(policy.parameters(), lr=10**-3)

    loss_fcn = l1

    batch_size = 512
    n_batch = 50
    interval = 10
    # Fixed episode length. NOTE(review): presumably chosen to match
    # `max_ep_duration=1.` at the environment's timestep -- confirm.
    n_timesteps = 100

    t0 = time.time()
    losses = []

    # imagine the plant was a simple linear layer; its input width follows the
    # policy's output width so the two always agree
    dummy_env = th.nn.Linear(n_muscles, 2)

    for batch in range(n_batch):
        # initialize batch
        h = policy.init_hidden(batch_size=batch_size)

        # initial environment
        if use_dummy_env:
            # observation width follows the policy's input width
            obs = th.rand(batch_size, n_obs, dtype=th.float32)
            xy = [th.rand(batch_size, 1, 2)]
            tg = [th.rand(batch_size, 1, 2)]
        else:
            obs, info = env.reset(options={"batch_size": batch_size})
            xy = [info["states"]["fingertip"][:, None, :]]
            tg = [info["goal"][:, None, :]]

        # simulate whole episode
        for _ in range(n_timesteps):
            # run the model to get commands to environment
            action, h = policy(obs, h)

            # apply those commands to environment
            if use_dummy_env:
                env_out = dummy_env(action)
                xy.append(env_out[:, None, :])  # trajectories
                tg.append(th.rand(batch_size, 1, 2))  # targets
            else:
                obs, reward, terminated, truncated, info = env.step(action=action)
                xy.append(info["states"]["fingertip"][:, None, :])  # trajectories
                tg.append(info["goal"][:, None, :])  # targets

        # concatenate into a (batch_size, n_timesteps, xy) tensor
        xy = th.cat(xy, dim=1)
        tg = th.cat(tg, dim=1)
        loss = loss_fcn(xy, tg)

        # backward pass & update weights
        optimizer.zero_grad()
        loss.backward()
        th.nn.utils.clip_grad_norm_(policy.parameters(), max_norm=1.)  # important!
        optimizer.step()
        losses.append(loss.item())

        if (batch % interval == 0) and (batch != 0):
            elapsed = round(time.time() - t0, 1)
            mean_loss = round(sum(losses[-interval:]) / interval, 8)
            print(
                "elapsed time: {} s; Batch {}/{} Done, mean policy loss: {}".format(
                    elapsed, batch, n_batch, mean_loss
                )
            )

    print(f"elapsed time: {time.time()-t0:.3f} sec")
def parse_dummy_env_flag(argv):
    """Decide from command-line arguments whether the dummy environment is requested.

    Args:
        argv: argument list in ``sys.argv`` layout (program name first).

    Returns:
        ``False`` when no argument follows the program name, or when the first
        argument is empty or one of ``0``/``false``/``no``/``off`` (case
        insensitive, surrounding whitespace ignored); ``True`` otherwise.
    """
    if len(argv) < 2:
        return False
    # A raw command-line string is truthy even when it spells "0" or "False",
    # so the text has to be interpreted explicitly.
    return argv[1].strip().lower() not in {"", "0", "false", "no", "off"}


if __name__ == "__main__":
    # Kept as a module-level name because code in this file may read the
    # global `USE_DUMMY_ENV` directly.
    USE_DUMMY_ENV = parse_dummy_env_flag(sys.argv)
    main_fcn(USE_DUMMY_ENV)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment