Skip to content

Instantly share code, notes, and snippets.

@halflearned
Created June 16, 2023 03:04
Show Gist options
  • Save halflearned/4c727ff0bc8b0ebab7e7501ba67a59c2 to your computer and use it in GitHub Desktop.
Save halflearned/4c727ff0bc8b0ebab7e7501ba67a59c2 to your computer and use it in GitHub Desktop.
VGGSoundDADAPolicy
class VGGSoundDADAPolicy(nn.Module):
    """Transformation policy learned after applying DADA algorithm to VGG-Sound.

    Each forward pass draws ``num_time_subpolicies`` time-domain subpolicies
    at random (with replacement, weighted by ``time_domain_probabilities``)
    and applies them to the input in the order drawn.

    NOTE(review): this is an abbreviated sketch. The subpolicy lists contain
    elided entries (marked ``# ...`` / ``# etc``) and the frequency-domain
    half of ``forward`` is not written out. Each subpolicy list must end up
    with exactly as many entries as its probability vector, otherwise
    ``np.random.choice`` raises ``ValueError``.

    Args:
        num_time_subpolicies: Number of time-domain subpolicies sampled and
            applied per forward pass.
        num_freq_subpolicies: Number of frequency-domain subpolicies intended
            to be sampled per forward pass (stored; the code that uses it is
            elided).
    """

    def __init__(self, num_time_subpolicies=2, num_freq_subpolicies=2):
        super().__init__()
        self.num_time_subpolicies = num_time_subpolicies
        self.num_freq_subpolicies = num_freq_subpolicies

    def forward(self, x):
        """Apply randomly sampled time-domain subpolicies to ``x``.

        Args:
            x: Input passed through each sampled subpolicy in turn.
               (presumably an audio waveform tensor; confirm against caller)

        Returns:
            The input after all sampled subpolicies have been applied.
        """
        # The property builds a fresh list on every access, so fetch it once
        # instead of once per loop iteration.
        time_subpolicies = self.time_domain_subpolicies
        time_subpolicy_indices = np.random.choice(
            len(time_subpolicies),
            # Was `self.num_ops`, which is never assigned on this class.
            size=self.num_time_subpolicies,
            p=self.time_domain_probabilities,
        )
        for index in time_subpolicy_indices:
            # Was `self.time_subpolicies`, which does not exist; the property
            # is named `time_domain_subpolicies`.
            operation = time_subpolicies[index]
            x = operation(x)
        # ... Etc. Apply spectrogram, apply frequency subpolicies
        return x

    @property
    def time_domain_subpolicies(self):
        """Candidate time-domain subpolicies, index-aligned with
        ``time_domain_probabilities``."""
        return [
            T.Compose([T.PitchShift(4), T.TimeShift(3)]),
            # ...
        ]

    @property
    def time_domain_probabilities(self):
        """Sampling weights for ``time_domain_subpolicies`` (sum to 1)."""
        return np.array([.1, .5, .2, .2])

    @property
    def frequency_domain_probabilities(self):
        """Sampling weights for ``frequency_domain_subpolicies`` (sum to 1)."""
        return np.array([.2, .4, .4])

    @property
    def frequency_domain_subpolicies(self):
        """Candidate frequency-domain subpolicies, index-aligned with
        ``frequency_domain_probabilities``."""
        return [
            T.Compose([T.FrequencyMasking(2), T.Identity()])
            # etc
        ]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment