from pathlib import Path
import time

import numpy as np
import torch
import torch.quantization
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image


class BobRossSegmentedImagesDataset(Dataset):
    def __init__(self, dataroot):
        super().__init__()
        self.dataroot = dataroot
        self.imgs = list((self.dataroot / 'train' / 'images').rglob('*.png'))
        self.segs = list((self.dataroot / 'train' / 'labels').rglob('*.png'))
        self.transform = transforms.Compose([
            transforms.Resize((164, 164)),
            transforms.Pad(46, padding_mode='reflect'),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.459387, 0.46603974, 0.4336706),
                std=(0.06098535, 0.05802868, 0.08737113)
            )
        ])
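        # Resize to 164x164, then reflect-pad 46 px on every side:
        # 164 + 2 * 46 = 256, so the network sees 256x256 inputs whose
        # central 164x164 region is the labeled target (see __getitem__).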
        self.color_key = {
            3: 0,
            5: 1,
            10: 2,
            14: 3,
            17: 4,
            18: 5,
            22: 6,
            27: 7,
            61: 8
        }
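        # The keys are the raw ADE20K label ids that occur in this dataset;
        # the values remap them to contiguous class indices 0..8, matching
        # the 9-channel output of the final 1x1 conv.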
        assert len(self.imgs) == len(self.segs)
        # TODO: remean images to N(0, 1)?

    def __len__(self):
        return len(self.imgs)
    def __getitem__(self, i):
        translate = np.vectorize(lambda x: self.color_key[x])

        img = Image.open(self.imgs[i])
        img = self.transform(img)

        seg = Image.open(self.segs[i])
        seg = seg.resize((256, 256), Image.NEAREST)

        # Labels are in the ADE20K ontology and are not consecutive, so we
        # have to remap them to contiguous indices in a just-in-time manner.
        # This slows things down, but that's fine for a demo.
        seg = translate(np.array(seg)).astype('int64')

        # One-hot encode the segmentation mask.
        # def ohe_mat(segmap):
        #     return np.array(
        #         list(
        #             np.array(segmap) == i for i in range(9)
        #         )
        #     ).astype(int).reshape(9, 256, 256)
        # seg = ohe_mat(seg)

        # Additionally, the original UNet implementation outputs a
        # segmentation map for a subset of the overall image, not the image
        # as a whole! With this input size the targeted segmentation map is
        # a (164, 164) center crop.
        seg = seg[46:210, 46:210]

        return img, seg
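

# A minimal sanity-check sketch (a hypothetical helper, not part of the
# pipeline): load one sample and confirm the shapes line up with the
# network's expectations -- a (3, 256, 256) input tensor and a (164, 164)
# int64 target.
def _peek_sample(dataroot):
    ds = BobRossSegmentedImagesDataset(dataroot)
    img, seg = ds[0]
    print(img.shape, seg.shape)  # torch.Size([3, 256, 256]) (164, 164)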


class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant_1 = torch.quantization.QuantStub()
        self.conv_1_1 = nn.Conv2d(3, 64, 3)
        torch.nn.init.kaiming_normal_(self.conv_1_1.weight)
        self.relu_1_2 = nn.ReLU()
        self.norm_1_3 = nn.BatchNorm2d(64)
        self.dequant_1 = torch.quantization.DeQuantStub()
        self.conv_1_4 = nn.Conv2d(64, 64, 3)
        torch.nn.init.kaiming_normal_(self.conv_1_4.weight)
        self.relu_1_5 = nn.ReLU()
        self.norm_1_6 = nn.BatchNorm2d(64)
        self.pool_1_7 = nn.MaxPool2d(2)

        self.conv_2_1 = nn.Conv2d(64, 128, 3)
        torch.nn.init.kaiming_normal_(self.conv_2_1.weight)
        self.relu_2_2 = nn.ReLU()
        self.norm_2_3 = nn.BatchNorm2d(128)
        self.conv_2_4 = nn.Conv2d(128, 128, 3)
        torch.nn.init.kaiming_normal_(self.conv_2_4.weight)
        self.relu_2_5 = nn.ReLU()
        self.norm_2_6 = nn.BatchNorm2d(128)
        self.pool_2_7 = nn.MaxPool2d(2)

        self.conv_3_1 = nn.Conv2d(128, 256, 3)
        torch.nn.init.kaiming_normal_(self.conv_3_1.weight)
        self.relu_3_2 = nn.ReLU()
        self.norm_3_3 = nn.BatchNorm2d(256)
        self.conv_3_4 = nn.Conv2d(256, 256, 3)
        torch.nn.init.kaiming_normal_(self.conv_3_4.weight)
        self.relu_3_5 = nn.ReLU()
        self.norm_3_6 = nn.BatchNorm2d(256)
        self.pool_3_7 = nn.MaxPool2d(2)

        self.conv_4_1 = nn.Conv2d(256, 512, 3)
        torch.nn.init.kaiming_normal_(self.conv_4_1.weight)
        self.relu_4_2 = nn.ReLU()
        self.norm_4_3 = nn.BatchNorm2d(512)
        self.conv_4_4 = nn.Conv2d(512, 512, 3)
        torch.nn.init.kaiming_normal_(self.conv_4_4.weight)
        self.relu_4_5 = nn.ReLU()
        self.norm_4_6 = nn.BatchNorm2d(512)

        # deconv is the '2D transposed convolution operator'
        self.deconv_5_1 = nn.ConvTranspose2d(512, 256, (2, 2), 2)
        # 57x57 -> 48x48 crop (the skip connection out of norm_3_6 is 57x57)
        self.c_crop_5_2 = lambda x: x[:, :, 6:54, 6:54]
        self.concat_5_3 = lambda x, y: torch.cat((x, y), dim=1)
        self.conv_5_4 = nn.Conv2d(512, 256, 3)
        torch.nn.init.kaiming_normal_(self.conv_5_4.weight)
        self.relu_5_5 = nn.ReLU()
        self.norm_5_6 = nn.BatchNorm2d(256)
        self.conv_5_7 = nn.Conv2d(256, 256, 3)
        torch.nn.init.kaiming_normal_(self.conv_5_7.weight)
        self.relu_5_8 = nn.ReLU()
        self.norm_5_9 = nn.BatchNorm2d(256)

        self.deconv_6_1 = nn.ConvTranspose2d(256, 128, (2, 2), 2)
        # 122x122 -> 88x88 crop (the skip connection out of norm_2_6 is 122x122)
        self.c_crop_6_2 = lambda x: x[:, :, 17:105, 17:105]
        self.concat_6_3 = lambda x, y: torch.cat((x, y), dim=1)
        self.conv_6_4 = nn.Conv2d(256, 128, 3)
        torch.nn.init.kaiming_normal_(self.conv_6_4.weight)
        self.relu_6_5 = nn.ReLU()
        self.norm_6_6 = nn.BatchNorm2d(128)
        self.conv_6_7 = nn.Conv2d(128, 128, 3)
        torch.nn.init.kaiming_normal_(self.conv_6_7.weight)
        self.relu_6_8 = nn.ReLU()
        self.norm_6_9 = nn.BatchNorm2d(128)

        self.deconv_7_1 = nn.ConvTranspose2d(128, 64, (2, 2), 2)
        # 252x252 -> 168x168 crop
        self.c_crop_7_2 = lambda x: x[:, :, 44:212, 44:212]
        self.concat_7_3 = lambda x, y: torch.cat((x, y), dim=1)
        self.conv_7_4 = nn.Conv2d(128, 64, 3)
        torch.nn.init.kaiming_normal_(self.conv_7_4.weight)
        self.relu_7_5 = nn.ReLU()
        self.norm_7_6 = nn.BatchNorm2d(64)
        self.conv_7_7 = nn.Conv2d(64, 64, 3)
        torch.nn.init.kaiming_normal_(self.conv_7_7.weight)
        self.relu_7_8 = nn.ReLU()
        self.norm_7_9 = nn.BatchNorm2d(64)

        # 1x1 conv ~= fc; n_classes = 9
        self.conv_8_1 = nn.Conv2d(64, 9, 1)
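
    # Spatial-size bookkeeping for a 256x256 input (each unpadded 3x3 conv
    # trims 2 px, each 2x2 max-pool halves, each stride-2 deconv doubles):
    #   down: 256 -> 254 -> 252 (skip 1) -> 126 -> 124 -> 122 (skip 2)
    #         -> 61 -> 59 -> 57 (skip 3) -> 28 -> 26 -> 24
    #   up:   24 -> 48 (cat skip 3) -> 46 -> 44 -> 88 (cat skip 2)
    #         -> 86 -> 84 -> 168 (cat skip 1) -> 166 -> 164
    # The final 164x164 logit map matches the (164, 164) center crop taken
    # from the segmentation target in __getitem__.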
    def forward(self, x):
        x = self.quant_1(x)
        x = self.conv_1_1(x)
        x = self.relu_1_2(x)
        x = self.norm_1_3(x)
        x = self.dequant_1(x)
        x = self.conv_1_4(x)
        x = self.relu_1_5(x)
        x_residual_1 = self.norm_1_6(x)
        x = self.pool_1_7(x_residual_1)

        x = self.conv_2_1(x)
        x = self.relu_2_2(x)
        x = self.norm_2_3(x)
        x = self.conv_2_4(x)
        x = self.relu_2_5(x)
        x_residual_2 = self.norm_2_6(x)
        x = self.pool_2_7(x_residual_2)

        x = self.conv_3_1(x)
        x = self.relu_3_2(x)
        x = self.norm_3_3(x)
        x = self.conv_3_4(x)
        x = self.relu_3_5(x)
        x_residual_3 = self.norm_3_6(x)
        x = self.pool_3_7(x_residual_3)

        x = self.conv_4_1(x)
        x = self.relu_4_2(x)
        x = self.norm_4_3(x)
        x = self.conv_4_4(x)
        x = self.relu_4_5(x)
        x = self.norm_4_6(x)

        x = self.deconv_5_1(x)
        x = self.concat_5_3(self.c_crop_5_2(x_residual_3), x)
        x = self.conv_5_4(x)
        x = self.relu_5_5(x)
        x = self.norm_5_6(x)
        x = self.conv_5_7(x)
        x = self.relu_5_8(x)
        x = self.norm_5_9(x)

        x = self.deconv_6_1(x)
        x = self.concat_6_3(self.c_crop_6_2(x_residual_2), x)
        x = self.conv_6_4(x)
        x = self.relu_6_5(x)
        x = self.norm_6_6(x)
        x = self.conv_6_7(x)
        x = self.relu_6_8(x)
        x = self.norm_6_9(x)

        x = self.deconv_7_1(x)
        x = self.concat_7_3(self.c_crop_7_2(x_residual_1), x)
        x = self.conv_7_4(x)
        x = self.relu_7_5(x)
        x = self.norm_7_6(x)
        x = self.conv_7_7(x)
        x = self.relu_7_8(x)
        x = self.norm_7_9(x)

        x = self.conv_8_1(x)
        return x
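

# A rough way to gauge what quantization buys in model size (a hypothetical
# helper, not part of the original script; the probe path is arbitrary):
# serialize the state dict and check the file size on disk. The int8 weights
# of the converted model should come out roughly 4x smaller than their fp32
# counterparts.
def _print_size_of_model(model, path=Path('/tmp/model_size_probe.pth')):
    torch.save(model.state_dict(), path)
    print(f"{path.stat().st_size / 1e6:.2f} MB")
    path.unlink()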


def main():
    print("Starting up...")
    dataroot = Path('/mnt/segmented-bob-ross-images/')
    dataset = BobRossSegmentedImagesDataset(dataroot)
    dataloader = DataLoader(dataset, shuffle=True, batch_size=8)

    print("Loading the model...")
    model = UNet()
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    checkpoints_dir = '/spell/checkpoints'
    model.load_state_dict(
        torch.load(f"{checkpoints_dir}/model_50.pth", map_location=torch.device('cpu'))
    )
    model.eval()

    # NEW
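    # Post-training static quantization in eager mode is a three-step
    # process:
    #   1. prepare() instruments the model with observers that record the
    #      ranges of the activations flowing through it;
    #   2. a calibration pass feeds representative data through the model
    #      so the observers can choose quantization parameters;
    #   3. convert() swaps the observed modules out for their int8
    #      equivalents.
    # The QuantStub/DeQuantStub pair in the model marks where tensors cross
    # between float and quantized representations in forward().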
    model = torch.quantization.prepare(model)

    print("Quantizing the model...")
    start_time = time.time()
    with torch.no_grad():
        for i, (batch, segmap) in enumerate(dataloader):
            # batch = batch.cuda()
            # segmap = segmap.cuda()
            model(batch)
    model = torch.quantization.convert(model)
    print(f"Quantization done in {time.time() - start_time} seconds.")
print(f"Evaluating the model...")
start_time = time.time()
for i, (batch, segmap) in enumerate(dataloader):
# batch = batch.cuda()
# segmap = segmap.cuda()
model(batch)
print(f"Evaluation done in {str(time.time() - start_time)} seconds.")
if __name__ == "__main__":
main()