Skip to content

Instantly share code, notes, and snippets.

@johschmidt42
johschmidt42 / customdatasets.py
Created November 24, 2020 15:11
This is a snippet test
import torch
from skimage.io import imread
from torch.utils import data
class SegmentationDataSet(data.Dataset):
def __init__(self,
inputs: list,
targets: list,
transform=None
from typing import List, Callable, Tuple
import numpy as np
import albumentations as A
from sklearn.externals._pilutil import bytescale
from skimage.util import crop
def normalize_01(inp: np.ndarray):
"""Squash image input to the value range [0, 1] (no clipping)"""
import numpy as np
import napari
from transformations import re_normalize
def enable_gui_qt():
"""Performs the magic command %gui qt"""
from IPython import get_ipython
ipython = get_ipython()
from torch import nn
import torch
@torch.jit.script
def autocrop(encoder_layer: torch.Tensor, decoder_layer: torch.Tensor):
"""
Center-crops the encoder_layer to the size of the decoder_layer,
so that merging (concatenation) between levels/blocks is possible.
This is only necessary for input sizes != 2**n for 'same' padding and always required for 'valid' padding.
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
├─ModuleList: 1 [] --
| └─DownBlock: 2-1 [-1, 32, 256, 256] --
| | └─Conv2d: 3-1 [-1, 32, 512, 512] 320
| | └─ReLU: 3-2 [-1, 32, 512, 512] --
| | └─BatchNorm2d: 3-3 [-1, 32, 512, 512] 64
| | └─Conv2d: 3-4 [-1, 32, 512, 512] 9,248
| | └─ReLU: 3-5 [-1, 32, 512, 512] --
import numpy as np
import torch
class Trainer:
def __init__(self,
model: torch.nn.Module,
device: torch.device,
criterion: torch.nn.Module,
optimizer: torch.optim.Optimizer,
import torch
from skimage.io import imread
from torch.utils import data
from tqdm import tqdm
class SegmentationDataSet2(data.Dataset):
"""Image segmentation dataset with caching and pretransforms."""
def __init__(self,
inputs: list,
import pandas as pd
import torch
from torch import nn
from matplotlib import pyplot as plt
from tqdm import tqdm, trange
import math
class LearningRateFinder:
"""
def plot_training(training_losses,
validation_losses,
learning_rate,
gaussian=True,
sigma=2,
figsize=(8, 6)
):
"""
Returns a loss plot with training loss, validation loss and learning rate.
"""
import torch
def predict(img,
model,
preprocess,
postprocess,
device,
):
model.eval()