@paulgribble
Created November 5, 2023 15:01
Effect of batch size on simulation time per movement during MotorNet training
# %%
import time
import json
import numpy as np
import torch as th
import matplotlib.pyplot as plt
import motornet as mn
print('All packages imported.')
print('pytorch version: ' + th.__version__)
print('numpy version: ' + np.__version__)
print('motornet version: ' + mn.__version__)
# %%
effector = mn.effector.RigidTendonArm26(muscle=mn.muscle.RigidTendonHillMuscle())
env = mn.environment.RandomTargetReach(effector=effector, max_ep_duration=1.)
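# %%
# Optional sanity check: print the sizes that determine the policy's input
# and output dimensions below (observation vector length and muscle count).
print("observation dim:", env.observation_space.shape[0])
print("n_muscles:", env.n_muscles)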
# %%
class Policy(th.nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, device):
        super().__init__()
        self.device = device
        self.hidden_dim = hidden_dim
        self.n_layers = 1
        self.gru = th.nn.GRU(input_dim, hidden_dim, 1, batch_first=True)
        self.fc = th.nn.Linear(hidden_dim, output_dim)
        self.sigmoid = th.nn.Sigmoid()

        # the default initialization in torch isn't ideal
        for name, param in self.named_parameters():
            if name == "gru.weight_ih_l0":
                th.nn.init.xavier_uniform_(param)
            elif name == "gru.weight_hh_l0":
                th.nn.init.orthogonal_(param)
            elif name == "gru.bias_ih_l0":
                th.nn.init.zeros_(param)
            elif name == "gru.bias_hh_l0":
                th.nn.init.zeros_(param)
            elif name == "fc.weight":
                th.nn.init.xavier_uniform_(param)
            elif name == "fc.bias":
                th.nn.init.constant_(param, -5.)  # sigmoid(-5) ~ 0: start with near-zero muscle drive
            else:
                raise ValueError(f"unexpected parameter: {name}")
        self.to(device)

    def forward(self, x, h0):
        y, h = self.gru(x[:, None, :], h0)
        u = self.sigmoid(self.fc(y)).squeeze(dim=1)
        return u, h

    def init_hidden(self, batch_size):
        weight = next(self.parameters()).data
        hidden = weight.new(self.n_layers, batch_size, self.hidden_dim).zero_().to(self.device)
        return hidden
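
# The benchmark below runs on CPU; to time the same loop on a GPU one could
# swap in th.device("cuda") (when available).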
device = th.device("cpu")
policy = Policy(env.observation_space.shape[0], 128, env.n_muscles, device=device)
optimizer = th.optim.Adam(policy.parameters(), lr=10**-3)
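# %%
# Optional smoke test: one forward pass through the untrained policy to check
# shapes before the long timing loop. It uses the same reset/forward pattern
# as the benchmark below; the batch size of 4 is arbitrary.
h = policy.init_hidden(batch_size=4)
obs, info = env.reset(options={"batch_size": 4})
action, h = policy(obs, h)
print("action shape:", action.shape)  # (4, env.n_muscles)
print("hidden shape:", h.shape)       # (n_layers, 4, hidden_dim) = (1, 4, 128)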
# %%
def l1(x, y):
    """L1 loss"""
    return th.mean(th.sum(th.abs(x - y), dim=-1))
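# %%
# Worked example of the loss above: absolute differences are summed over the
# last (x/y) axis, then averaged over batch and time. With a constant 0.1
# offset in both coordinates, the loss is 0.1 + 0.1 = 0.2.
_x = th.zeros(2, 3, 2)        # (batch, time, xy)
_y = th.full((2, 3, 2), 0.1)
print(l1(_x, _y))             # tensor(0.2000)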
# %%
BS = np.array([8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384])  # batch sizes to test
BSt = np.zeros(np.shape(BS))  # total time (ms) to simulate n_batch batches at each batch size
n_batch = 10                  # number of training batches per batch size
for i, batch_size in enumerate(BS):
    bt0 = time.time()
    losses = []
    for batch in range(n_batch):
        # initialize batch
        h = policy.init_hidden(batch_size=batch_size)
        obs, info = env.reset(options={"batch_size": batch_size})
        terminated = False

        # initial positions and targets
        xy = [info["states"]["fingertip"][:, None, :]]
        tg = [info["goal"][:, None, :]]

        # simulate whole episode
        while not terminated:  # will run until `max_ep_duration` is reached
            action, h = policy(obs, h)
            obs, reward, terminated, truncated, info = env.step(action=action)
            xy.append(info["states"]["fingertip"][:, None, :])  # trajectories
            tg.append(info["goal"][:, None, :])                 # targets

        # concatenate into a (batch_size, n_timesteps, xy) tensor
        xy = th.cat(xy, dim=1)
        tg = th.cat(tg, dim=1)
        loss = l1(xy, tg)  # L1 loss on position

        # backward pass & update weights
        optimizer.zero_grad()
        loss.backward()
        th.nn.utils.clip_grad_norm_(policy.parameters(), max_norm=1.)  # important!
        optimizer.step()
        losses.append(loss.item())
    BSt[i] = (time.time() - bt0) * 1000  # total time for n_batch batches, in ms
    print(f"{n_batch} x batch_size of {batch_size}: {BSt[i]:.0f} ms total, {BSt[i]/batch_size/n_batch:.1f} ms per movement")
# %%
plt.semilogy(BS, BSt/BS/n_batch, 'o-')
plt.xlabel('Batch Size')
plt.ylabel('Time per movement (ms)')
plt.title('Effect of Batch Size on Simulation Time')
plt.show()
# %%
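# Optional: dump the raw batch sizes and total times to JSON (the filename
# here is arbitrary) so timings can be compared across machines or devices;
# this also puts the `json` import above to use.
results = {"batch_size": BS.tolist(), "total_ms": BSt.tolist()}
with open("batch_size_timings.json", "w") as f:
    json.dump(results, f, indent=2)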