@Algomancer
Created January 16, 2022 00:12
Spectral augmentation using Mozaicing
# Author Adam Hibble @algomancer
import torch
import torch.nn.functional as F
import torch.nn as nn
import tqdm


def get_padding(padding_type, kernel_size):
    assert padding_type in ['SAME', 'VALID']
    if padding_type == 'SAME':
        return tuple((k - 1) // 2 for k in kernel_size)
    return tuple(0 for _ in kernel_size)


def maximum_filter(input, kernel_size=None):
    """Calculate a multidimensional maximum filter.
    Returns a tensor with the same shape as `input` (exact for odd kernel sizes).
    """
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    should_squeeze = False
    if len(input.shape) == 2:
        input = input[None, :, :]
        should_squeeze = True
    x = F.max_pool2d(input, kernel_size, stride=1, padding=get_padding('SAME', kernel_size))
    if should_squeeze:
        x = x.squeeze(0)
    return x


class NMFMozaicing(nn.Module):
    def __init__(self, r_width, c_width, polyphony, iterations):
        """
        r_width: Width of the repeated-activation filter (odd)
        c_width: Half length of the time-continuous activation filter
        polyphony: Number of polyphonic voices
        iterations: Number of NMF update iterations
        """
        super(NMFMozaicing, self).__init__()
        self.r_width = r_width
        self.c_width = c_width
        self.polyphony = polyphony
        self.iterations = iterations

    def step(self, activation_matrix, factor):
        # Step 1: Avoid repeated activations: damp every entry that is not a
        # local maximum within an r_width x r_width neighbourhood.
        K, N = activation_matrix.shape
        activation_filter = maximum_filter(activation_matrix, kernel_size=self.r_width)
        mask = activation_matrix < activation_filter
        activation_matrix[mask] = activation_matrix[mask] * factor

        # Step 2: Restrict simultaneous activations: keep the `polyphony` largest entries per column, damp the rest.
        cut_off = torch.topk(activation_matrix, self.polyphony + 1, dim=0)[0][self.polyphony, :]
        mask = activation_matrix <= cut_off[None, :]
        activation_matrix[mask] = activation_matrix[mask] * factor

        # Step 3: Support time-continuous activations: replace each diagonal
        # with a moving sum of width 2 * c_width along that diagonal.
        di = K - 1
        dj = 0
        for k in range(-activation_matrix.shape[0] + 1, activation_matrix.shape[1]):
            pad = torch.zeros(self.c_width, device=activation_matrix.device)
            z = torch.cumsum(torch.cat((pad, torch.diag(activation_matrix, k), pad)), dim=0)
            x2 = z[2 * self.c_width:] - z[:-2 * self.c_width]
            idx = torch.arange(len(x2), device=activation_matrix.device)
            activation_matrix[di + idx, dj + idx] = x2
            if di == 0:
                dj += 1
            else:
                di -= 1
        return activation_matrix

    def forward(self, target, template):
        # target: (freq_bins, N) magnitude spectrogram to approximate.
        # template: (freq_bins, K) magnitude spectrogram supplying the source frames.
        # Returns the (K, N) activation matrix H.
        N, K = target.shape[1], template.shape[1]
        activation_matrix = torch.rand(K, N).to(target.device)
        for i in tqdm.tqdm(range(self.iterations)):
            # The damping factor decays linearly towards zero over the iterations.
            factor = 1 - (i + 1) / self.iterations
            activation_matrix = self.step(activation_matrix, factor)
            # Multiplicative NMF update for the activations:
            # H <- H * (W^T (V / (W H))) / (W^T 1)
            template_weighted = template.matmul(activation_matrix)
            template_weighted[template_weighted == 0] = 1
            target_lamda = target / template_weighted
            template_denom = torch.sum(template, 0)
            template_denom[template_denom == 0] = 1
            inner = template.t().matmul(target_lamda) / template_denom[:, None]
            activation_matrix = activation_matrix * inner
        return activation_matrix
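

# --- Usage sketch (not part of the original gist) ---
# A minimal example of how NMFMozaicing might be driven. The stand-in audio,
# STFT settings (n_fft, hop) and the hyper-parameter values below are
# assumptions for illustration only; substitute real target/template clips.
if __name__ == "__main__":
    n_fft, hop = 1024, 256
    window = torch.hann_window(n_fft)

    # Stand-in signals; in practice load two audio recordings (mono, same rate).
    target_audio = torch.randn(4 * 22050)
    template_audio = torch.randn(4 * 22050)

    # Magnitude spectrograms with shape (freq_bins, frames).
    target_spec = torch.stft(target_audio, n_fft, hop_length=hop,
                             window=window, return_complex=True).abs()
    template_spec = torch.stft(template_audio, n_fft, hop_length=hop,
                               window=window, return_complex=True).abs()

    mozaicer = NMFMozaicing(r_width=3, c_width=3, polyphony=5, iterations=20)
    activations = mozaicer(target_spec, template_spec)  # (template_frames, target_frames)

    # The mosaic approximates the target spectrogram from template frames;
    # audio could then be resynthesised from it (e.g. via Griffin-Lim).
    mosaic_spec = template_spec.matmul(activations)
    print(mosaic_spec.shape, target_spec.shape)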