Skip to content

Instantly share code, notes, and snippets.

@willwhitney
Created May 27, 2019 23:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save willwhitney/c79ee2368210fa360316defe46fec090 to your computer and use it in GitHub Desktop.
Save willwhitney/c79ee2368210fa360316defe46fec090 to your computer and use it in GitHub Desktop.
class DmMujocoModel(nn.Module):
    """Non-learned "model" that predicts by replaying actions in the simulator.

    ``forward`` writes each input state into the environment's physics, steps
    the environment through the matching action sequence, and returns the
    observation reported by ``DmData.get_obs``. It returns a
    ``(prediction, mu, log_var)`` triple with all-zero ``mu``/``log_var``
    (NOTE(review): presumably so it can stand in for a latent-variable model;
    confirm against callers).

    Args:
        embed_dim: Width of the all-zero ``mu`` / ``log_var`` outputs.
        env_name: Forwarded to ``DmData``.
        traj_len: Forwarded to ``DmData``.
        qpos_only: Forwarded to ``DmData``; when the dataset reports
            ``qpos_only``, input states are treated as joint positions only.
        qpos_qvel: Forwarded to ``DmData``.
    """

    def __init__(self, embed_dim, env_name, traj_len, qpos_only=False, qpos_qvel=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.dataset = DmData(env_name, traj_len, qpos_only, qpos_qvel)
        self.dataset.make_env()
        self.env = self.dataset.env
        # Zero-initialised parameter added to the prediction. It attaches the
        # output to a trainable tensor and gives the module a device that
        # follows .to()/.cuda() (presumably required by the training loop).
        self.dummy_parameter = nn.Parameter(torch.zeros(1))

    def _set_physics_state(self, state):
        """Write one flat (CPU, detached) state vector into the simulator."""
        physics = self.env.physics
        with physics.reset_context():
            if self.dataset.qpos_only:
                physics.data.qpos[:] = state.numpy()
            else:
                # State is qpos followed by qvel. Split at the simulator's real
                # qpos length rather than half the vector: nq != nv for models
                # with quaternion (ball/free) joints.
                nq = physics.data.qpos.shape[0]
                physics.data.qpos[:] = state[:nq].numpy()
                physics.data.qvel[:] = state[nq:].numpy()

    def forward(self, s, a):
        """Predict observations by simulating each action sequence from its start state.

        Args:
            s: Batch of start states, iterated along dim 0. Each row is written
                into ``physics.data.qpos`` (``qpos_only``) or split into qpos
                followed by qvel.
            a: Batch of action sequences, iterated along dim 0 in lockstep with
                ``s``. Each element of a sequence is passed to ``env.step``.

        Returns:
            ``(prediction, mu, log_var)``. ``prediction`` has the shape of ``s``
            (float32) and holds ``dataset.get_obs()`` after the last action of
            each sequence, plus ``dummy_parameter``. ``mu`` and ``log_var``
            are zeros of shape ``(s.size(1), embed_dim)``. All three live on
            ``dummy_parameter``'s device.
        """
        device = self.dummy_parameter.device
        self.env.reset()
        # The simulator consumes host NumPy arrays. Detach first because
        # Tensor.numpy() refuses tensors that require grad.
        s, a = s.detach().cpu(), a.detach().cpu()
        prediction = torch.zeros_like(s)
        for i, (state, actions) in enumerate(zip(s, a)):
            self._set_physics_state(state)
            for action in actions:
                self.env.step(action.numpy())
            prediction[i] = torch.from_numpy(self.dataset.get_obs())
        # NOTE(review): s.size(1) is the per-sample state dimension, not the
        # batch size. This is kept from the original for compatibility; confirm
        # it is intended.
        mu = torch.zeros(s.size(1), self.embed_dim, device=device)
        log_var = torch.zeros(s.size(1), self.embed_dim, device=device)
        return prediction.float().to(device) + self.dummy_parameter, mu, log_var
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment