@FeryET · Created August 14, 2021 08:10
PneumoniaNet
import torch
import torch.nn as nn
import torchvision

def load_pretrained():
    # Load an ImageNet-pretrained MobileNetV3-Small and keep only its
    # convolutional feature extractor (output: 576 channels).
    pretrained_model = torchvision.models.mobilenetv3.mobilenet_v3_small(
        pretrained=True, progress=True
    )
    return pretrained_model.features

class PneumoniaNet(nn.Module):
    # Two-class pneumonia classifier: a mostly frozen MobileNetV3-Small
    # encoder followed by a small pooling + linear decoder head.
    def __init__(self, input_dim, finetune=False):  # `finetune` is unused here
        super().__init__()
        # Class-balanced cross-entropy; `classification_weights` is a global
        # expected to be defined elsewhere in the surrounding script.
        self.loss_fn = nn.CrossEntropyLoss(
            weight=torch.FloatTensor(list(classification_weights.values()))
        )
        self.input_dim = torch.as_tensor(input_dim)
        self.input_encoder = load_pretrained()
        # Freeze all encoder blocks but the last two.
        for param in self.input_encoder[:-2].parameters():
            param.requires_grad = False
        self.output_decoder = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Dropout(),
            nn.Linear(576, 2),
        )
        self._init_weights()

    def _init_weights(self):
        # `weight_init` is a global initializer expected to be defined elsewhere.
        self.output_decoder.apply(weight_init)

    def forward(self, x):
        # Guard against inputs whose spatial size differs from `input_dim`.
        assert x.shape[-2] == self.input_dim[0] and x.shape[-1] == self.input_dim[1]
        x = self.input_encoder(x)
        x = self.output_decoder(x)
        return x

    def loss(self, outputs, targets):
        return self.loss_fn(outputs, targets)

    def generate_opt(self):
        # Two parameter groups: the fresh decoder head trains at the base LR,
        # while the unfrozen encoder tail uses a separate PRETRAINED_LR.
        params = [
            {"params": self.output_decoder.parameters()},
            {"params": self.input_encoder[-2:].parameters(), "lr": PRETRAINED_LR},
        ]
        # LR, WEIGHT_DECAY, BETAS, and ADAM_EPS are globals defined elsewhere.
        return torch.optim.AdamW(
            params,
            lr=LR,
            weight_decay=WEIGHT_DECAY,
            betas=BETAS,
            eps=ADAM_EPS,
        )