Skip to content

Instantly share code, notes, and snippets.

@johschmidt42
Created November 24, 2020 15:11
Show Gist options
  • Save johschmidt42/4516d27bf77dd01cb6eb5e55a9808eae to your computer and use it in GitHub Desktop.
Save johschmidt42/4516d27bf77dd01cb6eb5e55a9808eae to your computer and use it in GitHub Desktop.
This is a snippet test
import numpy as np
import torch
from skimage.io import imread
from torch.utils import data
class SegmentationDataSet(data.Dataset):
    """Map-style dataset pairing input images with segmentation targets.

    Sample ``i`` is produced by loading ``inputs[i]`` and ``targets[i]``
    (with ``skimage.io.imread`` unless a custom ``loader`` is given),
    optionally passing both through ``transform``, and converting the
    results to tensors: inputs become ``torch.float32`` and targets
    ``torch.long``.

    Args:
        inputs: Identifiers (e.g. file paths) of the input images.
        targets: Identifiers of the target masks, aligned index-by-index
            with ``inputs``.
        transform: Optional callable ``transform(x, y) -> (x, y)`` applied
            to each loaded pair before tensor conversion.
        loader: Optional callable mapping one identifier to an array-like
            (or tensor). Defaults to ``skimage.io.imread`` when ``None``.

    Raises:
        ValueError: If ``inputs`` and ``targets`` differ in length.
    """

    def __init__(self,
                 inputs: list,
                 targets: list,
                 transform=None,
                 loader=None,
                 ):
        if len(inputs) != len(targets):
            # Fail at construction time: otherwise a shorter ``targets``
            # surfaces as an IndexError inside a DataLoader worker, and a
            # longer one silently drops samples (``__len__`` uses ``inputs``).
            raise ValueError(
                f"inputs and targets must have the same length, "
                f"got {len(inputs)} and {len(targets)}"
            )
        self.inputs = inputs
        self.targets = targets
        self.transform = transform
        self.loader = loader
        self.inputs_dtype = torch.float32
        self.targets_dtype = torch.long

    def __len__(self):
        """Return the number of input/target pairs."""
        return len(self.inputs)

    def __getitem__(self,
                    index: int):
        """Load, transform and convert the ``index``-th input/target pair.

        Returns:
            Tuple ``(x, y)`` of tensors with dtypes ``self.inputs_dtype``
            and ``self.targets_dtype``.
        """
        # Select the sample
        input_ID = self.inputs[index]
        target_ID = self.targets[index]

        # Load input and target; the default reader is resolved here so the
        # instance only stores ``None`` unless a custom loader was supplied.
        load = imread if self.loader is None else self.loader
        x, y = load(input_ID), load(target_ID)

        # Preprocessing
        if self.transform is not None:
            x, y = self.transform(x, y)

        # Typecasting
        x = self._to_tensor(x, self.inputs_dtype)
        y = self._to_tensor(y, self.targets_dtype)
        return x, y

    @staticmethod
    def _to_tensor(array, dtype):
        """Convert a loaded/transformed sample to a tensor of ``dtype``."""
        if isinstance(array, torch.Tensor):
            # A transform may already return tensors; just cast them.
            return array.type(dtype)
        array = np.asarray(array)
        if any(stride < 0 for stride in array.strides):
            # torch.from_numpy rejects negative strides (produced by flip
            # augmentations such as ``a[:, ::-1]``); a copy has normal strides.
            array = array.copy()
        return torch.from_numpy(array).type(dtype)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment