Skip to content

Instantly share code, notes, and snippets.

@BloodAxe
Created June 10, 2020 19:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save BloodAxe/600dc508b3149f7a010a81fbaf468404 to your computer and use it in GitHub Desktop.
Save BloodAxe/600dc508b3149f7a010a81fbaf468404 to your computer and use it in GitHub Desktop.
class MySegmentationDataset(Dataset):
    """Segmentation dataset that yields raw (un-normalized) image/mask tensors.

    Reads ``self.images`` and ``self.masks`` as index-aligned collections of
    paths accepted by ``cv2.imread``; they are populated by code omitted from
    this snippet.
    """

    # (constructor and other members are omitted in the original snippet)
    ...

    def __getitem__(self, index):
        """Load the image/mask pair stored at ``index``.

        Returns:
            A 2-tuple ``(image, target)`` of contiguous tensors, each with the
            last axis of the loaded array moved to the front
            (presumably HWC -> CHW -- confirm against the data on disk).
        """
        image = cv2.imread(self.images[index])
        target = cv2.imread(self.masks[index])
        # No data normalization and type casting here
        # The parentheses are required: a bare ``return a,`` followed by a
        # second expression on the next line ends the statement after ``a,``
        # and returns a 1-tuple, silently dropping the target.
        return (
            torch.from_numpy(image).permute(2, 0, 1).contiguous(),
            torch.from_numpy(target).permute(2, 0, 1).contiguous(),
        )
class Normalize(nn.Module):
    """Per-channel normalization, ``(input - mean) / std``, done inside the model.

    Args:
        mean: Sequence of per-channel means.
        std: Sequence of per-channel standard deviations.
    """

    # https://github.com/BloodAxe/pytorch-toolbelt/blob/develop/pytorch_toolbelt/modules/normalize.py
    def __init__(self, mean, std):
        super().__init__()
        # Both buffers are shaped (1, C, 1, 1) so they broadcast over a
        # 4-D input along dimension 1.
        self.register_buffer("mean", torch.tensor(mean).float().reshape(1, len(mean), 1, 1).contiguous())
        # "std" holds the reciprocal of the standard deviation, so forward()
        # multiplies instead of dividing.
        self.register_buffer("std", torch.tensor(std).float().reshape(1, len(std), 1, 1).reciprocal().contiguous())

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Cast ``input`` to the dtype of the buffers and normalize it.

        Returns:
            ``(input - mean) * (1 / std)`` in the dtype of the ``mean`` buffer.
        """
        # ``self.mean.dtype`` is the dtype object; ``self.mean.type`` is a
        # bound method, which ``Tensor.to`` rejects with a TypeError.
        return (input.to(self.mean.dtype) - self.mean) * self.std
class MySegmentationModel(nn.Module):
    """Segmentation model that normalizes its input and can compute its own loss."""

    def __init__(self):
        # nn.Module must be initialized before any submodule is assigned;
        # without this call the assignments below raise.
        super().__init__()
        self.normalize = Normalize([0.221 * 255], [0.242 * 255])
        self.loss = nn.CrossEntropyLoss()
        # NOTE(review): ``self.backbone`` is used in forward() but is not
        # assigned anywhere in the visible snippet -- it must be set up here.

    def forward(self, image, target=None):
        """Run the model and optionally compute the loss.

        Args:
            image: Input batch; it is normalized by ``self.normalize`` first.
            target: Optional ground truth. It is cast to ``long`` before being
                passed to the loss.

        Returns:
            The loss value when ``target`` is given, otherwise the backbone
            output.
        """
        image = self.normalize(image)
        output = self.backbone(image)
        if target is not None:
            loss = self.loss(output, target.long())
            return loss
        return output
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment