@RuizSerra
Last active March 30, 2022 02:13
"""
Based on @4rtemi5's TF implementation, ported to PyTorch
https://www.rpisoni.dev/posts/cossim-convolution/
"""
import torch
from torch import nn
import torch.nn.functional as F
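

# The layer computes, for every 3x3 input patch x and every output unit's kernel w,
# a "sharpened" cosine similarity (this just restates the forward pass below):
#
#     out = sign(x . w) * ( |x . w| / ((||x|| + q**2) * (||w|| + q**2)) )**(p**2) + b
#
# where w, b and p are learnable per-unit parameters and q is a learnable scalar.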
class CosSimConv2D(nn.Module):
    """Cosine-similarity convolution with a fixed 3x3 kernel, operating on NHWC inputs."""

    def __init__(self, input_shape=(10, 64, 64, 4), units=32, requires_grad=False):
        super(CosSimConv2D, self).__init__()
        self.units = units
        self.kernel_size = 3
        self.build(input_shape, requires_grad)

    def build(self, input_shape, requires_grad=False):
        # input_shape is NHWC: (batch, height, width, channels)
        self.in_shape = input_shape
        self.flat_size = self.in_shape[1] * self.in_shape[2]
        self.channels = self.in_shape[3]
        # Kernel weights: one (channels * 3*3)-dimensional vector per output unit
        self.w = nn.Parameter(
            data=torch.rand(1, self.channels * (self.kernel_size ** 2), self.units),
            requires_grad=requires_grad
        )
        # Per-unit bias
        self.b = nn.Parameter(
            data=torch.zeros(self.units),
            requires_grad=requires_grad
        )
        # Per-unit sharpening exponent (applied as p**2 in forward)
        self.p = nn.Parameter(
            data=torch.ones(self.units),
            requires_grad=requires_grad
        )
        # Scalar norm offset (applied as q**2 in forward)
        self.q = nn.Parameter(
            data=torch.zeros(1),
            requires_grad=requires_grad
        )

    def l2_normal(self, x, axis=None, epsilon=1e-12):
        # L2 norm of x along `axis`, clamped below by epsilon for numerical stability
        square_sum = torch.sum(torch.square(x), axis, keepdim=True)
        return torch.sqrt(torch.clamp(square_sum, min=epsilon))
    def stack3x3(self, image):
        # For every pixel, gather its 3x3 neighbourhood by shifting the image in the
        # eight directions and zero-padding the borders: (N, H, W, C) -> (N, H, W, 9, C).
        # F.pad pads from the last dim: (C, C, W_left, W_right, H_top, H_bottom, N, N).
        image = torch.as_tensor(image)
        stack = torch.stack(
            [
                F.pad(image[:, :-1, :-1, :], pad=(0, 0, 1, 0, 1, 0, 0, 0), value=0),  # top row
                F.pad(image[:, :-1, :, :], pad=(0, 0, 0, 0, 1, 0, 0, 0), value=0),
                F.pad(image[:, :-1, 1:, :], pad=(0, 0, 0, 1, 1, 0, 0, 0), value=0),
                F.pad(image[:, :, :-1, :], pad=(0, 0, 1, 0, 0, 0, 0, 0), value=0),  # middle row
                image,
                F.pad(image[:, :, 1:, :], pad=(0, 0, 0, 1, 0, 0, 0, 0), value=0),
                F.pad(image[:, 1:, :-1, :], pad=(0, 0, 1, 0, 0, 1, 0, 0), value=0),  # bottom row
                F.pad(image[:, 1:, :, :], pad=(0, 0, 0, 0, 0, 1, 0, 0), value=0),
                F.pad(image[:, 1:, 1:, :], pad=(0, 0, 0, 1, 0, 1, 0, 0), value=0)
            ], dim=3)
        return stack
    def forward(self, inputs, training=None):
        # (N, H, W, C) -> (N, H, W, 9, C) neighbourhood stack
        x = self.stack3x3(inputs)
        print(x.shape)  # debug: stacked patch shape
        # Flatten each 3x3 patch into a single vector: (N, H*W, 9*C)
        x = x.reshape((x.shape[0], self.flat_size, self.channels * (self.kernel_size ** 2)))
        q = torch.square(self.q)
        # Patch and kernel norms, offset by q**2
        x_norm = self.l2_normal(x, axis=2) + q
        w_norm = self.l2_normal(self.w, axis=1) + q
        x = x.float()
        x_norm = x_norm.float()
        w_norm = w_norm.float()
        # Cosine similarity between patches and kernels, sharpened by the exponent p**2,
        # with the sign of the raw dot product preserved
        sign = torch.sign(torch.matmul(x, self.w))
        x = torch.matmul(x / x_norm, self.w / w_norm)
        x = torch.abs(x) + 1e-12
        x = torch.pow(x, torch.square(self.p))
        x = sign * x + self.b
        # Restore spatial layout: (N, H, W, units)
        x = x.reshape((-1, self.in_shape[1], self.in_shape[2], self.units))
        return x
# ---------------------------------------------------------------------------
# Quick smoke test / usage example
import numpy as np
import matplotlib.pyplot as plt

imgs = np.random.rand(10, 64, 64, 4)  # NHWC batch of random "images"
c = CosSimConv2D(input_shape=imgs.shape)
out = c(imgs)
plt.imshow(out[0, :, :, 20])  # visualise one output channel of the first image
@RuizSerra (Author)
I initialise the parameters without grad (requires_grad=False) because I intend to train them with an evolutionary strategy.
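
For context, here is a minimal sketch of the kind of gradient-free loop this allows: a (1+1) evolution strategy that mutates the layer's parameters with Gaussian noise and keeps the better candidate. The fitness function and the zero target below are placeholders for illustration, not part of the gist.

import copy
import numpy as np
import torch

def fitness(model, imgs, target):
    # Placeholder fitness: negative MSE against a dummy target (higher is better)
    with torch.no_grad():
        out = model(imgs)
    return -torch.mean((out - target) ** 2).item()

imgs = np.random.rand(10, 64, 64, 4)
layer = CosSimConv2D(input_shape=imgs.shape)
target = torch.zeros(10, 64, 64, 32)  # dummy target, shape (N, H, W, units)
best_score = fitness(layer, imgs, target)

for generation in range(20):
    candidate = copy.deepcopy(layer)
    with torch.no_grad():
        for param in candidate.parameters():
            param.add_(0.02 * torch.randn_like(param))  # Gaussian mutation
    score = fitness(candidate, imgs, target)
    if score > best_score:  # keep the fitter of parent and offspring
        layer, best_score = candidate, score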
