Skip to content

Instantly share code, notes, and snippets.

What would you like to do?
class MySegmentationDataset(Dataset):
    """Segmentation dataset that yields raw, un-normalized image/mask tensors.

    Items are looked up in ``self.images`` and ``self.masks``; neither
    attribute is assigned in this block, so they must be provided elsewhere
    (presumably sequences of file paths -- confirm against the constructor).
    """

    def __getitem__(self, index):
        """Load the image and the mask stored at position ``index``.

        Args:
            index: Position used to index ``self.images`` and ``self.masks``.

        Returns:
            An ``(image, target)`` tuple of tensors. Each one is the array
            returned by ``cv2.imread`` with its last axis moved to the front
            and made contiguous; the dtype is left exactly as loaded.
        """
        image = cv2.imread(self.images[index])
        target = cv2.imread(self.masks[index])
        # NOTE(review): cv2.imread returns None for a path it cannot read, in
        # which case torch.from_numpy below raises a TypeError.

        # No data normalization and type casting here
        # The target is returned alongside the image; previously it was loaded
        # but dropped, and the dangling comma produced a 1-tuple.
        # NOTE(review): the mask uses the same channel-first layout as the
        # image -- confirm this is the shape the loss function expects.
        return (
            torch.from_numpy(image).permute(2, 0, 1).contiguous(),
            torch.from_numpy(target).permute(2, 0, 1).contiguous(),
        )
class Normalize(nn.Module):
    """Normalize a batch as ``(input - mean) / std`` using per-channel statistics.

    Both statistics are registered as buffers of shape ``(1, C, 1, 1)`` where
    ``C == len(mean)`` (respectively ``len(std)``), so they follow the module
    across ``.to()`` / ``.cuda()`` / ``.half()`` calls and are saved in the
    ``state_dict``.
    """

    def __init__(self, mean, std):
        """Store the normalization statistics as buffers.

        Args:
            mean: Sequence of per-channel means.
            std: Sequence of per-channel standard deviations.
        """
        # nn.Module has to be initialised before any buffer can be registered.
        super().__init__()
        self.register_buffer(
            "mean",
            torch.tensor(mean).float().reshape(1, len(mean), 1, 1).contiguous(),
        )
        # The buffer named "std" actually holds 1 / std, so forward() can
        # multiply instead of divide.
        self.register_buffer(
            "std",
            torch.tensor(std).float().reshape(1, len(std), 1, 1).reciprocal().contiguous(),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``(input - mean) / std`` computed in the buffers' dtype.

        Args:
            input: Tensor that broadcasts against a ``(1, C, 1, 1)`` buffer.

        Returns:
            The normalized tensor, with the same dtype as the ``mean`` buffer.
        """
        # Cast first so that integer inputs (e.g. uint8 images) are processed
        # in floating point and match the buffers after .half() / .double().
        return (input.to(self.mean.dtype) - self.mean) * self.std
class MySegmentationModel(nn.Module):
    """Segmentation model that normalizes its input before the backbone.

    The forward pass runs ``self.normalize`` and then ``self.backbone``; when
    a target is supplied it returns the cross-entropy loss instead of the raw
    backbone output.
    """

    def __init__(self, backbone: "nn.Module | None" = None):
        """Build the normalization layer and the loss.

        Args:
            backbone: Optional module applied to the normalized image. When
                omitted, ``self.backbone`` is left unset and must be assigned
                (e.g. by a subclass) before ``forward`` is called.
        """
        # nn.Module has to be initialised before sub-modules can be assigned.
        super().__init__()
        # Single-channel statistics expressed on the 0..255 scale; they
        # broadcast over every channel of the input.
        self.normalize = Normalize([0.221 * 255], [0.242 * 255])
        self.loss = nn.CrossEntropyLoss()
        if backbone is not None:
            self.backbone = backbone

    def forward(self, image, target=None):
        """Run the model and optionally compute the loss.

        Args:
            image: Input batch passed through ``self.normalize`` and then
                ``self.backbone``.
            target: Optional ground truth. When given, it is cast to ``long``
                and compared against the backbone output with
                ``nn.CrossEntropyLoss``.

        Returns:
            The loss when ``target`` is not ``None``, otherwise the backbone
            output.
        """
        image = self.normalize(image)
        output = self.backbone(image)

        if target is None:
            return output

        return self.loss(output, target.long())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment