-
-
Save ClashLuke/8f6521deef64789e76334f1b72a70d80 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
from torch import nn
from torch.nn import functional as F
class CosSimConv2d(nn.Conv2d):
    """Cosine-similarity ("sharpened cosine similarity", SCS) 2D convolution.

    Normalizes both the kernel and each input patch to (softened) unit L2 norm
    before convolving, so each output approximates the cosine similarity
    between the kernel and the input patch. The patch norms are obtained by
    convolving the squared input with an all-ones kernel.

    Interface matches ``nn.Conv2d``, plus:
        q_scale: scale of the learnable ``q`` term added to the weight norm
            (larger q -> weaker weight normalization).
        p_scale: scale of the learnable exponent ``p`` used by "true" SCS;
            currently disabled (see commented lines), kept for compatibility.

    Note: ``bias`` is forcibly enabled (see comment in ``__init__``) and only
    ``groups == 1`` (full) or ``groups == in_channels`` (depthwise) are
    supported.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size, stride=1, padding=None, dilation=1,
                 groups: int = 1, bias: bool = False, q_scale: float = 10, p_scale: float = 100):
        if padding is None:
            # padding="same" exists since torch 1.10, but only supports stride 1.
            # Compare (major, minor) — the original checked only the minor
            # version, so torch 2.x (minor < 10) never took this branch.
            version = tuple(int(part) for part in torch.__version__.split("+")[0].split(".")[:2])
            if version >= (1, 10) and stride in (1, (1, 1)):
                padding = "same"
            else:
                # Manual "same" padding. NOTE(review): assumes an int, odd
                # kernel_size here — even or tuple kernels are not supported.
                padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        bias = True  # Disable bias for "true" SCS, add it for better performance
        assert dilation == 1, "Dilation has to be 1 to use AvgPool2d as L2-Norm backend."
        assert groups == in_channels or groups == 1, "Either depthwise or full convolution. Grouped not supported"
        super(CosSimConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups,
                                           bias)
        self.q_scale = q_scale  # q scale is missing at normalization of input. minor difference, but necessary
        # Stored as sqrt(q_scale) so q.square() / q_scale starts at 1 and stays positive.
        self.q = torch.nn.Parameter(torch.full((1,), q_scale ** 0.5))
        # Uncomment for "true" SCS:
        # self.p_scale = p_scale
        # self.p = torch.nn.Parameter(torch.full((1,), p_scale ** 0.5))

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        squared = inp.square()
        if self.groups == 1:
            # Full convolution: one shared patch norm over all input channels.
            patch_sq_sum = F.conv2d(squared.sum(1, keepdim=True), torch.ones_like(self.weight[:1, :1]), None,
                                    self.stride, self.padding, self.dilation)
        else:
            # Depthwise: per-channel patch norm. The original single-channel
            # ones kernel raised a channel-mismatch RuntimeError here; use a
            # grouped conv with one ones-kernel per output channel instead.
            patch_sq_sum = F.conv2d(squared, torch.ones_like(self.weight[:, :1]), None, self.stride, self.padding,
                                    self.dilation, self.groups)
        norm = patch_sq_sum + 1e-6  # epsilon avoids division by zero on all-zero patches
        q = self.q.square() / self.q_scale
        # Soft weight normalization: divide by (||w||_2 + q) per output filter.
        weight = self.weight / (self.weight.square().sum(dim=(1, 2, 3), keepdim=True).sqrt() + q)
        out = F.conv2d(inp, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) / norm.sqrt()
        # Uncomment these lines for "true" SCS (it's ~200x slower):
        # abs = (out.square() + 1e-6).sqrt()
        # sign = out / abs
        # out = abs ** (self.p.square() / self.p_scale)
        # out = out * sign
        return out
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment