Skip to content

Instantly share code, notes, and snippets.

@TheodoreGalanos
Forked from umbra-scientia/FIR.py
Created August 23, 2021 05:34
Show Gist options
  • Save TheodoreGalanos/f9f9aae2a594785d1baf099357b97392 to your computer and use it in GitHub Desktop.
Save TheodoreGalanos/f9f9aae2a594785d1baf099357b97392 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
class FIR(nn.Module):
    """Gated residual mixer over causal, multi-scale moving averages of a sequence.

    Each position of the input is projected to ``hidden_dim`` features (the
    "codebook"), optionally passed through ``activation``, and the features
    are split into ``len(segment_sizes)`` equal-width slices.  For output
    position ``t``, slice ``i`` is the mean of its values over a window of
    ``segment_sizes[i]`` *earlier* positions.  The windows are laid back to
    back, with segment 0 ending at ``t - 1``, so the current position itself
    is never included.  Positions before the start of the sequence count as
    zeros: the divisor is always the full window length.  The concatenated
    averages are projected to a candidate value and a sigmoid gate, and the
    result interpolates between the input and the candidate:
    ``x + gate * (candidate - x)``.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Width of the candidate (and of the gate).  Any falsy value
            means ``in_dim``.  Because the output is a residual blend with
            ``x``, this must equal ``in_dim``, or one of the two must be 1 so
            that broadcasting applies.
        hidden_dim: Requested codebook width.  Any falsy value means
            ``in_dim``.  It is rounded down to a multiple of
            ``len(segment_sizes)``; if it is smaller than that count, every
            slice is empty.
        segment_sizes: Window length of each segment, ordered from the most
            recent window to the oldest.
        activation: Callable applied to the codebook output; pass ``None`` to
            skip it.
        device: Device on which the parameters are created.
    """

    def __init__(self, in_dim, out_dim=None, hidden_dim=None, segment_sizes=(1, 2, 4, 8), activation=nn.functional.gelu, device='cpu'):
        super().__init__()
        out_dim = out_dim or in_dim
        hidden_dim = hidden_dim or in_dim
        # nodes[i] is how many steps back segment i's window ends; the window
        # then spans back to nodes[i + 1] - 1 steps.  Starting at 1 keeps the
        # current position out of every window.
        cursor = 1
        nodes = [cursor]
        for size in segment_sizes:
            cursor += size
            nodes.append(cursor)
        segments = len(segment_sizes)
        self.nodes = nodes
        # Construction-time device only; forward() follows the input's device.
        self.device = device
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.slice_dim = hidden_dim // segments
        self.hidden_dim = self.slice_dim * segments
        self.codebook = nn.Linear(self.in_dim, self.hidden_dim, bias=True, device=device)
        # First out_dim outputs are the candidate, the remaining out_dim are gate logits.
        self.projection = nn.Linear(self.hidden_dim, self.out_dim * 2, bias=False, device=device)
        self.activation = activation

    def forward(self, x):
        """Blend ``x`` with gated multi-scale averages of its own history.

        Args:
            x: Tensor of shape ``(batch, seq, in_dim)``; dimension 1 is the
                time axis.

        Returns:
            ``x + gate * (candidate - x)``, which has the same shape as ``x``
            when ``out_dim == in_dim``.
        """
        seq_len = x.shape[1]
        codes = self.codebook(x)
        if self.activation:
            codes = self.activation(codes)
        # Prefix sums make every window sum a difference of two lookups.
        codes = torch.cumsum(codes, dim=1)
        # Left-pad the time axis with zeros once, so lookups before the start
        # of the sequence read 0.  F.pad keeps the dtype and device of
        # ``codes``, so this still works after .to(...), .half(), .bfloat16().
        offset = self.nodes[-1]
        padded = nn.functional.pad(codes, (0, 0, offset, 0))
        averages = []
        for i, (near, far) in enumerate(zip(self.nodes, self.nodes[1:])):
            cols = slice(self.slice_dim * i, self.slice_dim * (i + 1))
            # padded[:, offset + t - k] holds the prefix sum up to position t - k.
            newer = padded[:, offset - near:offset - near + seq_len, cols]
            older = padded[:, offset - far:offset - far + seq_len, cols]
            averages.append((newer - older) / (far - near))
        stack = torch.cat(averages, dim=2)
        proj = self.projection(stack)
        gate = torch.sigmoid(proj[:, :, self.out_dim:])
        proj = proj[:, :, :self.out_dim]
        return x + gate * (proj - x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment