This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Normalization layer used by ResBlock. It is looked up when a ResBlock is
# constructed, so every norm inside the block is this class.
norm_layer = nn.InstanceNorm2d


class ResBlock(nn.Module):
    """Residual block of two same-width 3x3 convolutions.

    Computes ``relu(norm(conv2(relu(norm(conv1(x)))) + x))``. The skip
    connection is added *before* the final normalization and activation,
    not after them.

    Args:
        f: Number of input and output channels. Both convolutions use
            kernel 3, stride 1 and padding 1, so the spatial size is
            preserved and the residual addition is shape-compatible.
    """

    def __init__(self, f):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(f, f, 3, 1, 1),
            norm_layer(f),
            nn.ReLU(),
            nn.Conv2d(f, f, 3, 1, 1),
        )
        # Applied to the sum (branch output + input), i.e. after the skip add.
        self.norm = norm_layer(f)

    def forward(self, x):
        """Return ``relu(norm(conv(x) + x))``.

        ``x`` must have ``f`` channels (e.g. shape ``(N, f, H, W)``); the
        output has the same shape as ``x``.
        """
        return F.relu(self.norm(self.conv(x) + x))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Number of channels of the discriminator's input images; it is the
# in_channels of the Discriminator's first Conv2d.
nc = 3
# Number of feature maps produced by the Discriminator's first Conv2d;
# presumably later layers scale from this base width — confirm in the
# full Discriminator definition.
ndf = 64
class Discriminator(nn.Module): | |
def __init__(self): | |
super(Discriminator, self).__init__() | |
self.main = nn.Sequential( | |
# input is (nc) x 128 x 128 | |
nn.Conv2d(nc,ndf,4,2,1, bias=False), | |
nn.LeakyReLU(0.2, inplace=True), | |
# state size. (ndf) x 64 x 64 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def LSGAN_D(real, fake):
    """Least-squares GAN discriminator loss.

    Pushes the discriminator's outputs on real samples toward 1 and on
    generated samples toward 0::

        mean((real - 1) ** 2) + mean(fake ** 2)

    Args:
        real: Tensor of discriminator outputs for real samples.
        fake: Tensor of discriminator outputs for generated samples.

    Returns:
        A scalar tensor. No 0.5 factor is applied to the sum.
    """
    return torch.mean((real - 1) ** 2) + torch.mean(fake ** 2)


def LSGAN_G(fake):
    """Least-squares GAN generator loss: ``mean((fake - 1) ** 2)``.

    Args:
        fake: Tensor of discriminator outputs for generated samples.

    Returns:
        A scalar tensor. It is minimized when the discriminator scores the
        generated samples as 1.
    """
    return torch.mean((fake - 1) ** 2)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" Parts of the U-Net model """ | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class DoubleConv(nn.Module): | |
"""(convolution => [BN] => ReLU) * 2""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" Full assembly of the parts to form the complete network """ | |
import torch.nn.functional as F | |
# from .unet_parts import * | |
class UNet(nn.Module): | |
def __init__(self, n_channels=3, n_classes=1, bilinear=True): | |
super(UNet, self).__init__() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_data(path, start, end, size=(128,128)): | |
""" | |
function to load image data of cells and normalize the pictures for dataloader. | |
""" | |
images = [] | |
annotations = [] | |
image_names = [] | |
mask_names = [] | |
# First get names of all images to read organized | |
image_names = [ f.name for f in os.scandir(path+'/image')][start:end] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train(model, opt, loss_fn, epochs, data_loader, print_status): | |
loss_ls = [] | |
epoch_ls = [] | |
for epoch in range(epochs): | |
avg_loss = 0 | |
model.train() | |
b=0 | |
for X_batch, Y_batch in data_loader: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
n_channels = 3
n_classes = 2
# NOTE(review): n_channels / n_classes are not passed to UNet() below, so the
# model is built with UNet's own defaults (n_channels=3, n_classes=1).
# n_classes=2 disagrees with that default — confirm which is intended before
# wiring these values through.

# Next, create an instance of the UNet model on the target device.
modelUnet = UNet().to(device)
# BCEWithLogitsLoss applies the sigmoid internally, so the model is expected
# to output raw logits.
criterion = torch.nn.BCEWithLogitsLoss()
# Now define the optimizer.
optimizerUnet = optim.Adam(modelUnet.parameters(), lr=1e-5, weight_decay=0)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class WarwickCellDataset(object): | |
def __init__(self, root, transforms=None): # transforms | |
self.root = root | |
# self.transforms = transforms | |
self.transforms=[] | |
if transforms!=None: | |
self.transforms.append(transforms) | |
# load all image files, sorting them to | |
# ensure that they are aligned | |
self.imgs = list(natsorted(os.listdir(os.path.join(root, "image")))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torchvision | |
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor | |
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor | |
# Number of classes for the detection heads. torchvision detection models
# count the background as a class, so this is presumably background + cell —
# confirm against the dataset.
num_classes = 2
# Load an instance segmentation model pre-trained on COCO.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases
# in favour of `weights=...`; it is kept here for compatibility with older
# versions.
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
# Get the number of input features of the box classifier. Presumably this is
# used to replace the box predictor with one sized for num_classes (that code
# is not visible here).
in_features = model.roi_heads.box_predictor.cls_score.in_features
Older Newer