Skip to content

Instantly share code, notes, and snippets.

View bh1995's full-sized avatar

Bjørn Hansen bh1995

View GitHub Profile
norm_layer = nn.InstanceNorm2d


class ResBlock(nn.Module):
    """Residual block built from two 3x3 convolutions with an identity skip.

    Structure: conv -> ``norm_layer`` -> ReLU -> conv, then the input is added
    back, and the sum goes through ``norm_layer`` and a final ReLU.
    Both convolutions map ``f`` channels to ``f`` channels with kernel 3,
    stride 1 and padding 1, so the output has the same shape as the input.

    Args:
        f: number of input (and output) channels.
    """

    def __init__(self, f):
        super().__init__()
        # Layer order fixes the Sequential indices (conv.0 ... conv.3), which
        # appear in state_dict keys, so it must not be rearranged.
        layers = [
            nn.Conv2d(f, f, 3, 1, 1),
            norm_layer(f),
            nn.ReLU(),
            nn.Conv2d(f, f, 3, 1, 1),
        ]
        self.conv = nn.Sequential(*layers)
        self.norm = norm_layer(f)

    def forward(self, x):
        """Return ReLU(norm(conv(x) + x))."""
        summed = self.conv(x) + x
        return F.relu(self.norm(summed))
# Number of channels in the images fed to the discriminator (its first Conv2d's in_channels).
nc=3
# Number of feature maps produced by the discriminator's first convolution.
ndf=64
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.main = nn.Sequential(
# input is (nc) x 128 x 128
nn.Conv2d(nc,ndf,4,2,1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf) x 64 x 64
def LSGAN_D(real, fake, *, real_target=1.0, fake_target=0.0):
    """Least-squares GAN discriminator loss.

    Computes ``mean((real - real_target)**2) + mean((fake - fake_target)**2)``,
    i.e. it regresses discriminator outputs on real samples toward
    ``real_target`` and outputs on generated samples toward ``fake_target``.

    Args:
        real: discriminator outputs for real samples (tensor).
        fake: discriminator outputs for generated samples (tensor).
        real_target: regression target for real outputs. The default of 1.0
            reproduces the original hard-coded label. Use e.g. 0.9 for label
            smoothing.
        fake_target: regression target for fake outputs. The default of 0.0
            reproduces the original hard-coded label.

    Returns:
        Scalar tensor: sum of the two mean-squared-error terms.
    """
    real_term = torch.mean((real - real_target) ** 2)
    fake_term = torch.mean((fake - fake_target) ** 2)
    return real_term + fake_term
def LSGAN_G(fake, *, target=1.0):
    """Least-squares GAN generator loss.

    Computes ``mean((fake - target)**2)``: the generator is rewarded when the
    discriminator's outputs on generated samples approach ``target``.

    Args:
        fake: discriminator outputs for generated samples (tensor).
        target: value the generator wants the discriminator to output. The
            default of 1.0 reproduces the original hard-coded label.

    Returns:
        Scalar tensor mean-squared-error loss.
    """
    return torch.mean((fake - target) ** 2)
""" Parts of the U-Net model """
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
""" Full assembly of the parts to form the complete network """
import torch.nn.functional as F
# from .unet_parts import *
class UNet(nn.Module):
def __init__(self, n_channels=3, n_classes=1, bilinear=True):
super(UNet, self).__init__()
def get_data(path, start, end, size=(128,128)):
"""
function to load image data of cells and normalize the pictures for dataloader.
"""
images = []
annotations = []
image_names = []
mask_names = []
# First get names of all images to read organized
image_names = [ f.name for f in os.scandir(path+'/image')][start:end]
def train(model, opt, loss_fn, epochs, data_loader, print_status):
loss_ls = []
epoch_ls = []
for epoch in range(epochs):
avg_loss = 0
model.train()
b=0
for X_batch, Y_batch in data_loader:
# Number of input image channels fed to the UNet.
n_channels = 3
# NOTE(review): n_classes is not passed to UNet below, so the model is built
# with UNet's default n_classes=1 (presumably one output logit channel), which
# is what BCEWithLogitsLoss is paired with here. Confirm whether 2 is intended.
n_classes = 2
# Next create an instance of the UNet model. n_channels is passed explicitly
# so that editing the constant above actually takes effect (it equals UNet's
# default of 3, so behaviour is unchanged).
modelUnet = UNet(n_channels=n_channels).to(device)
# Loss on raw logits: BCEWithLogitsLoss applies the sigmoid internally.
criterion = torch.nn.BCEWithLogitsLoss()
# Now define the optimizer
optimizerUnet = optim.Adam(modelUnet.parameters(), lr=1e-5, weight_decay=0)
class WarwickCellDataset(object):
def __init__(self, root, transforms=None): # transforms
self.root = root
# self.transforms = transforms
self.transforms=[]
if transforms!=None:
self.transforms.append(transforms)
# load all image files, sorting them to
# ensure that they are aligned
self.imgs = list(natsorted(os.listdir(os.path.join(root, "image"))))
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
# NOTE(review): presumably background plus one foreground class; confirm.
num_classes = 2
# load an instance segmentation model pre-trained on COCO
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in favour
# of the `weights=` argument; confirm against the installed torchvision version.
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
# get number of input features for the box predictor's classification layer (cls_score)
in_features = model.roi_heads.box_predictor.cls_score.in_features