@oscarknagg
Created May 15, 2019 16:23
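import torch

# Two environments, each with a single-channel 7x7 board
bodies = torch.zeros((2, 1, 7, 7))
heads = torch.zeros((2, 1, 7, 7))
num_envs = bodies.size(0)

# Initialise body as shown in diagram.
# Each occupied cell holds that segment's position along the snake:
# 1 is the tail and the largest value (here 8) is the head.
bodies[:, :, 3, 2] = 1
bodies[:, :, 3, 3] = 2
bodies[:, :, 2, 3] = 3
bodies[:, :, 2, 4] = 4
bodies[:, :, 2, 5] = 5
bodies[:, :, 3, 5] = 6
bodies[:, :, 4, 5] = 7
bodies[:, :, 5, 5] = 8
print(bodies[0, 0], '\n')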
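
# Move tail: subtract 1 from every cell, then clamp negatives to 0.
# The old tail (1) becomes 0, every other segment shifts down by one,
# and empty cells (0 -> -1) are reset to 0 by relu_.
bodies.sub_(1).relu_()
print(bodies[0, 0], '\n')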

# Move head manually
# See https://gist.github.com/oscarknagg/863602483afc83d698ce399a67eb21d4
# for the head movement procedure
heads[:, :, 5, 4] = 1

# Create new front position: the cell marked in `heads` receives
# (largest current body value + 1), so it becomes the new head
bodies.add_(
    heads * (bodies + 1).view(num_envs, -1).max(dim=1, keepdim=True)[0]
    .unsqueeze(-1)  # Add singleton dimensions so broadcasting works in add_:
    .unsqueeze(-1)  # (num_envs, 1) -> (num_envs, 1, 1, 1)
)
print(bodies[0, 0], '\n')
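Because each cell stores its segment's position along the snake (1 for the tail, the largest value for the head), the snake's order can be recovered by sorting the occupied cells by value. A minimal sketch of that decoding, assuming a single 2-D `(H, W)` grid in which every segment value occurs exactly once; `body_coordinates` is a hypothetical helper, not part of the gist:

```python
from typing import List, Tuple

import torch


def body_coordinates(body: torch.Tensor) -> List[Tuple[int, int]]:
    """Return the (row, col) of every segment in a 2-D body grid, tail first.

    Hypothetical helper. Assumes the encoding from the script above:
    0 = empty, 1 = tail, largest value = head, each value used exactly once.
    """
    # Row and column indices of all occupied cells
    rows, cols = torch.nonzero(body, as_tuple=True)
    # Sort the occupied cells by their segment value (tail -> head)
    order = torch.argsort(body[rows, cols])
    return list(zip(rows[order].tolist(), cols[order].tolist()))


# Rebuild the starting snake from the script on a single 7x7 grid
body = torch.zeros((7, 7))
segments = [(3, 2), (3, 3), (2, 3), (2, 4), (2, 5), (3, 5), (4, 5), (5, 5)]
for value, (row, col) in enumerate(segments, start=1):
    body[row, col] = value

print(body_coordinates(body))
# [(3, 2), (3, 3), (2, 3), (2, 4), (2, 5), (3, 5), (4, 5), (5, 5)]
```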
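The tail and head updates can also be wrapped into one function that advances every environment at once. This is a sketch, not the author's code: `move_snake` is a hypothetical name, the new head cells are assumed to be already marked in `heads` (computing them is the job of the head-movement procedure in the linked gist), and `Tensor.amax` (available in PyTorch 1.7 and later) stands in for the `view`/`max`/`unsqueeze` chain. With `keepdim=True` it produces the broadcastable `(num_envs, 1, 1, 1)` shape directly, and `max(bodies) + 1` equals `max(bodies + 1)`.

```python
import torch


def move_snake(bodies: torch.Tensor, heads: torch.Tensor) -> torch.Tensor:
    """Advance every snake by one cell, modifying `bodies` in place.

    Hypothetical helper.
    bodies: (num_envs, 1, H, W) grid, 1 = tail, largest value = head.
    heads:  (num_envs, 1, H, W) mask with a 1 at each snake's new head cell.
    """
    # Move tail: every segment index drops by one; the old tail (1) and the
    # empty cells (0 -> -1) are clamped back to 0 by relu_
    bodies.sub_(1).relu_()
    # Per-environment maximum, kept as shape (num_envs, 1, 1, 1) so it
    # broadcasts against `heads` without extra unsqueeze calls
    new_head_value = bodies.amax(dim=(1, 2, 3), keepdim=True) + 1
    bodies.add_(heads * new_head_value)
    return bodies


# Same starting position as in the script above
segments = [(3, 2), (3, 3), (2, 3), (2, 4), (2, 5), (3, 5), (4, 5), (5, 5)]
bodies = torch.zeros((2, 1, 7, 7))
for value, (row, col) in enumerate(segments, start=1):
    bodies[:, :, row, col] = value

heads = torch.zeros_like(bodies)
heads[:, :, 5, 4] = 1  # new head cell, chosen by hand as in the gist

move_snake(bodies, heads)
print(bodies[0, 0])
```

Printing `bodies[0, 0]` shows the old tail cell `(3, 2)` cleared, the previous head at `(5, 5)` demoted to 7 and the new head at `(5, 4)` set to 8, the same board as the last print of the original script.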