Created
June 16, 2023 03:04
-
-
Save halflearned/4c727ff0bc8b0ebab7e7501ba67a59c2 to your computer and use it in GitHub Desktop.
VGGSoundDADAPolicy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class VGGSoundDADAPolicy(nn.Module):
    """Transformation policy learned after applying DADA algorithm to VGG-Sound.

    Each forward pass draws ``num_time_subpolicies`` indices into
    ``time_domain_subpolicies``, weighted by ``time_domain_probabilities``,
    and applies the selected sub-policies to the input in the order drawn.

    Args:
        num_time_subpolicies: Number of time-domain sub-policies drawn per
            call to ``forward``.
        num_freq_subpolicies: Number of frequency-domain sub-policies to draw
            per call. Stored on the instance; the frequency stage of
            ``forward`` is still elided, so it is not read yet.
    """

    def __init__(self, num_time_subpolicies=2, num_freq_subpolicies=2):
        super().__init__()
        self.num_time_subpolicies = num_time_subpolicies
        self.num_freq_subpolicies = num_freq_subpolicies

    def forward(self, x):
        """Apply randomly sampled time-domain sub-policies to ``x``.

        Args:
            x: Input handed to each sampled sub-policy in turn.
                NOTE(review): presumably a waveform tensor -- confirm.

        Returns:
            ``x`` after every sampled time-domain sub-policy has been applied.
        """
        # The property builds a fresh list on every access, so read it once
        # and reuse it for both the length and the lookups.
        time_subpolicies = self.time_domain_subpolicies
        time_subpolicy_indices = np.random.choice(
            len(time_subpolicies),
            size=self.num_time_subpolicies,
            p=self.time_domain_probabilities,
        )
        for index in time_subpolicy_indices:
            operation = time_subpolicies[index]
            x = operation(x)
        # ... Etc. Apply spectrogram, apply frequency subpolicies
        return x

    @property
    def time_domain_subpolicies(self):
        """Candidate time-domain sub-policies, indexed by the sampled indices.

        NOTE(review): np.random.choice needs one entry here per value in
        ``time_domain_probabilities``; the remaining entries are elided.
        """
        return [
            T.Compose([T.PitchShift(4), T.TimeShift(3)]),
            # ...
        ]

    @property
    def time_domain_probabilities(self):
        """Sampling weights for ``time_domain_subpolicies``."""
        return np.array([.1, .5, .2, .2])

    @property
    def frequency_domain_probabilities(self):
        """Sampling weights for ``frequency_domain_subpolicies``."""
        return np.array([.2, .4, .4])

    @property
    def frequency_domain_subpolicies(self):
        """Candidate frequency-domain sub-policies.

        NOTE(review): presumably one entry per value in
        ``frequency_domain_probabilities``; the remaining entries are elided.
        """
        return [
            T.Compose([T.FrequencyMasking(2), T.Identity()])
            # etc
        ]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment