Skip to content

Instantly share code, notes, and snippets.

@ahmedbesbes
Created July 11, 2019 21:27
Show Gist options
  • Save ahmedbesbes/03e688adaad2d06835985c99f61c3534 to your computer and use it in GitHub Desktop.
Save ahmedbesbes/03e688adaad2d06835985c99f61c3534 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
from torchvision import models
class MRNet(nn.Module):
    """AlexNet-based model that scores a series of 2D slices with one logit.

    Every slice of a series is passed through AlexNet's convolutional feature
    extractor and spatially average-pooled to a 256-dim vector. The per-slice
    vectors are then max-pooled across the slice axis, and a single linear
    layer produces the output logit.
    """

    def __init__(self, pretrained: bool = True) -> None:
        """Build the backbone, the pooling layer and the linear head.

        Args:
            pretrained: When true (the default, matching the original
                behaviour), load torchvision's ImageNet AlexNet weights, which
                may trigger a download. Pass ``False`` to build a randomly
                initialised backbone, e.g. before loading a checkpoint offline.
        """
        super().__init__()
        # torchvision >= 0.13 replaced the ``pretrained`` flag with the
        # ``weights`` enum and warns on the old form. Fall back to the legacy
        # keyword on older releases that do not provide ``AlexNet_Weights``.
        weights_enum = getattr(models, "AlexNet_Weights", None)
        if weights_enum is not None:
            weights = weights_enum.DEFAULT if pretrained else None
            self.pretrained_model = models.alexnet(weights=weights)
        else:
            self.pretrained_model = models.alexnet(pretrained=pretrained)
        self.pooling_layer = nn.AdaptiveAvgPool2d(1)
        # NOTE: "classifer" is misspelled, but the attribute name is part of the
        # state_dict keys, so renaming it would break existing checkpoints.
        # The input size 256 must match the channel count of AlexNet's
        # ``features`` output.
        self.classifer = nn.Linear(256, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one logit per series of slices.

        Args:
            x: Either a single series shaped ``(slices, C, H, W)`` or a batch
                of series shaped ``(batch, slices, C, H, W)``. The original
                implementation only worked for ``batch == 1``. Larger batches
                are handled by max-pooling each series independently.

        Returns:
            Tensor of shape ``(batch, 1)``; ``(1, 1)`` for a single series.

        Raises:
            ValueError: If ``x`` is neither 4- nor 5-dimensional.
        """
        if x.dim() == 4:
            # A bare series: treat it as a batch containing one series.
            x = x.unsqueeze(0)
        elif x.dim() != 5:
            raise ValueError(
                "expected a 4D (S, C, H, W) or 5D (B, S, C, H, W) tensor, "
                f"got shape {tuple(x.shape)}"
            )
        batch_size, num_slices = x.shape[0], x.shape[1]
        # Push every slice of every series through the backbone in one pass.
        slices = x.flatten(0, 1)
        features = self.pretrained_model.features(slices)
        pooled = self.pooling_layer(features).flatten(1)
        # (B * S, F) -> (B, S, F), then keep each channel's strongest response
        # over the slices of its own series.
        per_series = pooled.reshape(batch_size, num_slices, -1)
        series_features = per_series.max(dim=1).values
        return self.classifer(series_features)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment