toy implementation of 'space-time as a contrastive random walk'
# pseudocode impl
# Algorithm 1: Pseudocode in a PyTorch-like style, from Jabri, Owens & Efros,
# "Space-Time Correspondence as a Contrastive Random Walk" (NeurIPS 2020).
# for x in loader:  # x: batch with B sequences
#     # Split image into patches
#     # B x C x T x H x W -> B x C x T x N x h x w
#     x = unfold(x, (patch_size, patch_size))
#     x = spatial_jitter(x)
#     # Embed patches (B x C x T x N)
#     v = l2_norm(resnet(x))
#     # Transitions from t to t+1 (B x T-1 x N x N)
#     A = einsum("bcti,bctj->btij",
#                v[:,:,:-1], v[:,:,1:]) / temperature
#     # Transition energies for palindrome graph
#     AA = cat((A, A[:,::-1].transpose(-1,-2)), 1)
#     AA[rand(AA) < drop_rate] = -1e10  # Edge dropout
#     At = eye(P)  # Init. position
#     # Compute walks
#     for t in range(2*T-2):
#         At = bmm(softmax(AA[:,t], dim=-1), At)
#     # Target is the original node
#     loss = At[[range(P)]*B].log()
import torch
import torch.nn.functional as F
import random
import math
# made up dimensions
NUM_BATCHES = 100
B = 4 # num. sequences in batch
C = 1 # original channel dim.
T = 50 # num. timesteps
H = 6 # frame height
W = 6 # frame width
h, w = 3, 3 # patch size
PATCH_STEP = 1
D = 2 # mock resnet embedding dim
TAU = 1.0 # temperature
mock_loader = [torch.randn(B, C, T, H, W)
               for _ in range(NUM_BATCHES)]
def circularly_polarized_loader():
    """ 'circularly polarized' sequence of frames: a single bright pixel
    traces a circle over the T timesteps, with a random phase offset
    per sequence. """
    loader = []
    for _ in range(NUM_BATCHES):
        x = torch.clamp(torch.zeros(B, C, T, H, W),
                        min=0.01,
                        max=1.)
        for b in range(B):
            offset = 2 * math.pi * random.randint(0, T - 1) / T
            for t in range(T):
                theta = 2 * math.pi * t / T + offset
                sinv = math.sin(theta)
                cosv = math.cos(theta)
                h_ix = min(int((sinv + 1.) * H / 2), H - 1)
                w_ix = min(int((cosv + 1.) * W / 2), W - 1)
                x[b, :, t, h_ix, w_ix] = 1.
        loader.append(x)
    return loader
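
# Illustrative check (not in the original gist): each frame should contain
# at least one bright (1.0) pixel on a dim (0.01) background.
def _demo_circularly_polarized_loader():
    x = circularly_polarized_loader()[0]
    assert x.shape == (B, C, T, H, W)
    assert torch.all((x == 1.).flatten(3).any(dim=-1))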
def pseudocode_impl(loader):
    phi = MockResnetEmbedding()
    optimizer = torch.optim.SGD(
        phi.parameters(),
        lr=0.1
    )
    max_iters = 10000
    for it in range(max_iters):
        random.shuffle(loader)
        cross_ent_loss = 0.
        for x in loader:
            optimizer.zero_grad()
            # x is a batch of B sequences, shape (B, C, T, H, W):
            # B num. sequences, C num. channels,
            # T num. timesteps, H frame height, W frame width
            # Split each frame into patches:
            # B x C x T x H x W -> B x C x T x N x h x w
            xp = frame_to_patches(x,
                                  dim_h=3,
                                  dim_w=4,
                                  patch_height=h,
                                  patch_width=w,
                                  patch_step=PATCH_STEP)
            xp = spatial_jitter(xp)
            # Embed patches: B x C x T x N x h x w -> B x D x T x N
            v = phi(xp)
            # Transitions (all inner products) from t to t+1 (B x T-1 x N x N)
            E = torch.einsum("bcti,bctj->btij",
                             v[:, :, :-1], v[:, :, 1:])
            EP = palindromed_transitions(E)
            AA = transition_probs(EP, out_dim=-1)
            MP = matrix_product(AA)
            loss = cross_entropy_loss(MP)
            cross_ent_loss += loss.item()
            # backpropagate
            loss.backward()
            optimizer.step()
        cross_ent_loss /= NUM_BATCHES
        print(f'(i={it}) cross entropy loss: {cross_ent_loss}')
def cross_entropy_loss(MP):
    """ The target of each walk is its starting node, so the loss is the
    negative log-probability along the diagonal of MP, summed over nodes
    and averaged over the batch. """
    ce = torch.mean(torch.einsum('bii->b', -torch.log(MP)),
                    dim=0)
    return ce
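
# Illustrative check (not in the original gist): on identity matrices the
# diagonal probabilities are all one, so the loss is exactly zero.
def _demo_cross_entropy_loss():
    MP = torch.eye(5).expand(B, 5, 5)  # B walks that always return home
    assert torch.allclose(cross_entropy_loss(MP), torch.tensor(0.))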
def matrix_product(AA):
    """ Returns the matrix product of the individual transition probs.
    Given AA of shape B x T x N x N,
    computes the product of all matrices along the T axis,
    returning a tensor of shape B x N x N.
    """
    T = AA.shape[1]
    MP = AA[:, 0, :, :]
    for t in range(1, T):
        m = AA[:, t, :, :]
        MP = torch.matmul(MP, m)
    return MP
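
# Illustrative check (not in the original gist): with T=2 the chain
# reduces to a single matmul of the two transition matrices.
def _demo_matrix_product():
    AA = F.softmax(torch.randn(2, 2, 3, 3), dim=-1)  # B=2, T=2, N=3
    assert torch.allclose(matrix_product(AA),
                          torch.matmul(AA[:, 0], AA[:, 1]))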
def transition_probs(E, out_dim=-1):
    """ Returns transition probs: softmax over energies, scaled by the
    temperature TAU. """
    # normalization is sanity-checked in the demo below
    return F.softmax(E / TAU, dim=out_dim)
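
# Normalization check suggested by the original '# verify are normalized'
# comment (not in the original gist): rows of the softmax sum to one.
def _demo_transition_probs():
    P = transition_probs(torch.randn(2, 4, 3, 3))
    assert torch.allclose(P.sum(dim=-1), torch.ones(2, 4, 3))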
def palindromed_transitions(E):
    """
    Transition energies with time-reversed transition energies appended.
    """
    # Reverse the time direction; an i -> j transition in the forward pass
    # becomes j -> i on the way back, hence the transpose of the last two dims.
    E_reversed = torch.flip(E, dims=[1]).transpose(-1, -2)
    palindromed = torch.cat([E, E_reversed], dim=1)
    return palindromed
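
# Illustrative check (not in the original gist): the palindrome doubles the
# number of steps, and its first reversed step mirrors the last forward step.
def _demo_palindromed_transitions():
    E = torch.randn(2, 4, 3, 3)  # B=2, T-1=4 forward steps, N=3
    EP = palindromed_transitions(E)
    assert EP.shape == (2, 8, 3, 3)
    assert torch.equal(EP[:, 4], E[:, 3].transpose(-1, -2))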
def spatial_jitter(xp):
    # tbd: no-op placeholder for the spatial jitter augmentation
    # applied to patches in the paper's pseudocode
    return xp
class MockResnetEmbedding(torch.nn.Module):
    """ Mock embedding for x of shape
    B x C x T x N x h x w,
    with embedding dim. D.
    Embeds patches to shape B x D x T x N.
    """
    # TODO("probably want to make this convolutional")
    def __init__(self):
        super(MockResnetEmbedding, self).__init__()
        mock_embed_matrix = torch.randn(size=(D, C, h * w))
        self.embed_matrix = torch.nn.Parameter(mock_embed_matrix)

    def forward(self, x):
        # flatten each h x w patch: (B, C, T, N, h*w)
        xf = x.flatten(4, 5)
        # contract the channel and pixel dims against the embedding matrix,
        # giving (D, B, T, N), then permute to (B, D, T, N)
        xe = torch.tensordot(self.embed_matrix, xf, dims=([1, 2], [1, 4])) \
            .permute(1, 0, 2, 3)
        # l2-normalize along the embedding dim
        nm = torch.linalg.norm(xe, dim=1, keepdim=True)
        xe_normalized = xe.div(nm)
        return xe_normalized
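
# Illustrative check (not in the original gist): the output shape is
# (B, D, T, N) and embeddings are l2-normalized along the D dimension.
def _demo_mock_embedding():
    v = MockResnetEmbedding()(torch.randn(B, C, T, 2, h, w))  # N=2 patches
    assert v.shape == (B, D, T, 2)
    assert torch.allclose(torch.linalg.norm(v, dim=1), torch.ones(B, T, 2))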
def frame_to_patches(x, dim_h, dim_w, patch_height, patch_width, patch_step):
    """
    Given a tensor x of shape ... x H x W, with H, W the frame height and width,
    splits each frame into patches of size patch_height x patch_width, taken
    with step patch_step.
    Returns a tensor of shape ... x N x patch_height x patch_width, where N is
    the number of patches created per frame.
    :param x: input tensor
    :param dim_h: index of the frame height dimension
    :param dim_w: index of the frame width dimension
    :param patch_height: patch height
    :param patch_width: patch width
    :param patch_step: stride between successive patches
    :return: patched tensor
    """
    assert dim_w == dim_h + 1, "error: expected dim_w to follow dim_h"
    return x.unfold(dim_h, patch_height, patch_step) \
        .unfold(dim_w, patch_width, patch_step) \
        .flatten(dim_h, dim_h + 1)
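
# Worked example (not in the original gist): with 6x6 frames, 3x3 patches,
# and step 1, each frame yields (6 - 3 + 1)**2 = 16 patches.
def _demo_frame_to_patches():
    xp = frame_to_patches(torch.randn(B, C, T, H, W),
                          dim_h=3, dim_w=4,
                          patch_height=h, patch_width=w,
                          patch_step=PATCH_STEP)
    assert xp.shape == (B, C, T, 16, h, w)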
def run_pseudocode_impl():
    loader = circularly_polarized_loader()
    pseudocode_impl(loader)


if __name__ == "__main__":
    run_pseudocode_impl()