Skip to content

Instantly share code, notes, and snippets.

@gngdb
Created April 21, 2021 15:20
Show Gist options
  • Save gngdb/13e3b26d9fb1a572f5b87360fb72d7db to your computer and use it in GitHub Desktop.
Save gngdb/13e3b26d9fb1a572f5b87360fb72d7db to your computer and use it in GitHub Desktop.
import torch
from einops import rearrange, repeat, reduce
def relation(input, g, embedding=None, max_pairwise=None):
    """Apply a pairwise relation function ``g`` to all ordered pairs of objects.

    For every batch element, all ``o * o`` ordered pairs of the ``o`` input
    objects (self-pairs included) are formed by concatenating their feature
    vectors, optionally followed by a per-batch ``embedding``. Each pair is
    passed through ``g`` and the per-pair results are summed.

    Args:
        input: Tensor of shape ``(b, o, c)``: batch size, number of objects,
            feature size.
        g: Callable applied to a 2-D tensor of shape ``(b * o * o, 2 * c + e)``
            (``e`` = embedding width, 0 if no embedding) whose first dimension
            indexes pairs. Its output's first dimension must have the same
            length; any trailing shape is allowed.
        embedding: Optional tensor of shape ``(b, e)``. Row ``k`` is appended
            to every pair of batch element ``k``.
        max_pairwise: Optional positive int. If given, ``g`` is called on
            consecutive chunks of at most this many pairs and the results are
            concatenated. The pair tensor itself is still built in full; only
            the batch size seen by ``g`` is bounded.

    Returns:
        Tensor of shape ``(b, *out)``, where ``out`` is the trailing shape of
        ``g``'s output: the sum of ``g`` over all pairs of each batch element.

    Raises:
        ValueError: If ``max_pairwise`` is given but is not positive.
    """
    # Fail fast, before the (potentially large) pair tensor is materialized.
    if max_pairwise is not None and max_pairwise <= 0:
        raise ValueError(
            f"max_pairwise must be a positive integer, got {max_pairwise!r}")

    # Batch size, number of objects, feature size.
    b, o, c = input.size()
    n_pairs = o * o

    # Build all ordered pairs, flattened so that pair k = i * o + j holds
    # cat(input[:, j], input[:, i]): '(m o)' tiles the whole object list o
    # times, while '(o m)' repeats each object o times in a row.
    pairs = torch.cat([repeat(input, 'b o c -> b (m o) c', m=o),
                       repeat(input, 'b o c -> b (o m) c', m=o)], 2)

    if embedding is not None:
        # The same embedding row is appended to every pair of a batch element.
        pairs = torch.cat(
            [pairs, repeat(embedding, 'b e -> b p e', p=n_pairs)], 2)

    # Fold the pair axis into the batch axis so g sees a plain 2-D batch.
    pairs = rearrange(pairs, 'b p c -> (b p) c')

    if max_pairwise is None:
        output = g(pairs)
    else:
        # Bound the number of rows g processes per call.
        output = torch.cat([g(chunk) for chunk in pairs.split(max_pairwise)], 0)

    # Unfold the pair axis and sum over it; '...' keeps whatever per-pair
    # shape g produced.
    return reduce(output, '(b p) ... -> b ...', 'sum', p=n_pairs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment