This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from skimage.io import imread | |
from torch.utils import data | |
class SegmentationDataSet(data.Dataset): | |
def __init__(self, | |
inputs: list, | |
targets: list, | |
transform=None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List, Callable, Tuple | |
import numpy as np | |
import albumentations as A | |
from sklearn.externals._pilutil import bytescale | |
from skimage.util import crop | |
def normalize_01(inp: np.ndarray): | |
"""Squash image input to the value range [0, 1] (no clipping)""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import napari | |
from transformations import re_normalize | |
def enable_gui_qt(): | |
"""Performs the magic command %gui qt""" | |
from IPython import get_ipython | |
ipython = get_ipython() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch import nn | |
import torch | |
@torch.jit.script | |
def autocrop(encoder_layer: torch.Tensor, decoder_layer: torch.Tensor): | |
""" | |
Center-crops the encoder_layer to the size of the decoder_layer, | |
so that merging (concatenation) between levels/blocks is possible. | |
This is only necessary for input sizes != 2**n for 'same' padding and always required for 'valid' padding. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
========================================================================================== | |
Layer (type:depth-idx) Output Shape Param # | |
========================================================================================== | |
├─ModuleList: 1 [] -- | |
| └─DownBlock: 2-1 [-1, 32, 256, 256] -- | |
| | └─Conv2d: 3-1 [-1, 32, 512, 512] 320 | |
| | └─ReLU: 3-2 [-1, 32, 512, 512] -- | |
| | └─BatchNorm2d: 3-3 [-1, 32, 512, 512] 64 | |
| | └─Conv2d: 3-4 [-1, 32, 512, 512] 9,248 | |
| | └─ReLU: 3-5 [-1, 32, 512, 512] -- |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import torch | |
class Trainer: | |
def __init__(self, | |
model: torch.nn.Module, | |
device: torch.device, | |
criterion: torch.nn.Module, | |
optimizer: torch.optim.Optimizer, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from skimage.io import imread | |
from torch.utils import data | |
from tqdm import tqdm | |
class SegmentationDataSet2(data.Dataset): | |
"""Image segmentation dataset with caching and pretransforms.""" | |
def __init__(self, | |
inputs: list, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
import torch | |
from torch import nn | |
from matplotlib import pyplot as plt | |
from tqdm import tqdm, trange | |
import math | |
class LearningRateFinder: | |
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def plot_training(training_losses, | |
validation_losses, | |
learning_rate, | |
gaussian=True, | |
sigma=2, | |
figsize=(8, 6) | |
): | |
""" | |
Returns a loss plot with training loss, validation loss and learning rate. | |
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
def predict(img, | |
model, | |
preprocess, | |
postprocess, | |
device, | |
): | |
model.eval() |
Older Newer