@mirth
Last active October 2, 2020 16:35
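import torch
import torch.nn as nn

# Assumed, not shown in this gist: `vggish` is a VGGish model constructor that
# accepts preprocess/postprocess flags (as in the torchvggish port), and
# WINDOW_MULTIPLIER is an integer constant defined elsewhere in the project.


class NoiseClassifier(nn.Module):
    """Binary noise classifier on top of a VGGish backbone."""

    def __init__(self):
        super().__init__()
        # Disable VGGish's own input preprocessing and embedding postprocessing.
        self.vgg = vggish(preprocess=False, postprocess=False)
        # Replace the embedding head so its input size matches the (wider)
        # conv feature map and its output is a 256-dim vector.
        self.vgg.embeddings = nn.Sequential(
            nn.Linear(512 * 4 * 6 * WINDOW_MULTIPLIER, 256),
            nn.ReLU(True),
            nn.Linear(256, 256),
            nn.ReLU(True),
            nn.Linear(256, 256),
            nn.ReLU(True))
        # Single output unit for the binary decision.
        self.fc = nn.Linear(256, 1)

    def forward(self, x):
        x = self.vgg(x)       # (batch, 256)
        x = self.fc(x)        # (batch, 1)
        x = torch.sigmoid(x)  # probability in (0, 1)
        # Squeeze only the last dim: a plain squeeze() would turn a batch of
        # size 1 into a 0-d tensor.
        x = x.squeeze(-1)     # (batch,)
        return x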
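The first `nn.Linear` expects `512 * 4 * 6 * WINDOW_MULTIPLIER` inputs. The arithmetic sketch below shows where that number plausibly comes from: VGGish takes 96-frame by 64-band log-mel patches, and its four 2x2 max-pool layers shrink them to a 6 x 4 grid with 512 channels. It assumes (this is not confirmed by the gist) that `WINDOW_MULTIPLIER` stacks that many 96-frame patches along the time axis, so only the time dimension grows. The helper names `vggish_feature_size` and `head_parameter_count` are hypothetical.

```python
# Hypothetical helpers; assumes WINDOW_MULTIPLIER widens only the time axis.
VGGISH_FRAMES = 96     # log-mel frames per VGGish example (0.96 s)
VGGISH_MEL_BANDS = 64  # mel bands per frame
VGGISH_CHANNELS = 512  # channels after the last conv block
NUM_POOLS = 4          # 2x2 max-pool layers, each halving height and width


def vggish_feature_size(window_multiplier: int) -> int:
    """Length of the flattened conv output fed into `embeddings`."""
    downsample = 2 ** NUM_POOLS                                     # 16
    time_steps = (VGGISH_FRAMES * window_multiplier) // downsample  # 6 * W
    freq_steps = VGGISH_MEL_BANDS // downsample                     # 4
    return VGGISH_CHANNELS * freq_steps * time_steps


def head_parameter_count(window_multiplier: int, hidden: int = 256) -> int:
    """Weights + biases of the replacement embeddings stack plus `fc`."""
    n_in = vggish_feature_size(window_multiplier)
    first = n_in * hidden + hidden             # Linear(n_in, 256)
    middle = 2 * (hidden * hidden + hidden)    # two Linear(256, 256)
    fc = hidden * 1 + 1                        # Linear(256, 1)
    return first + middle + fc


print(vggish_feature_size(1))   # 12288, i.e. 512 * 4 * 6
print(head_parameter_count(1))  # 3277825
```

Almost all of the head's parameters sit in the first layer, so its size grows roughly linearly with `WINDOW_MULTIPLIER`.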