Created
June 29, 2020 09:02
-
-
Save apacha/df9c5db38261ba82fc40de6fcf678c5b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
from PIL import Image | |
from pytorch_lightning import LightningModule, Trainer | |
from torch import optim | |
from torch.nn import Conv2d, MaxPool2d | |
from torch.utils.data import DataLoader | |
from torchvision.datasets import VisionDataset | |
from torchvision.transforms import transforms | |
import numpy as np | |
import urllib.request | |
class TestDataset(VisionDataset):
    """Toy dataset that serves the same downloaded comic page for every index.

    Each ``__getitem__`` call prints and increments ``image_index``, which makes
    it visible which copy of the dataset object served a request (useful when
    comparing different ``DataLoader`` ``num_workers`` settings).
    """

    def __init__(self, root: str = "../example_input_scores", transform=None, target_transform=None):
        """Download the sample image and reset the item counter.

        Args:
            root: Forwarded to ``VisionDataset``. It does not affect which file is
                read; ``__getitem__`` always reads ``self.path`` relative to the
                current working directory.
            transform: Optional callable applied to the grayscale PIL image.
            target_transform: Optional callable applied to the target array.
        """
        super().__init__(root, None, transform, target_transform)
        self.image_index = 1
        print("Resetting index")
        self.path = "00000001.jpg"
        # Network I/O at construction time: the image is downloaded (and
        # overwritten) every time a dataset instance is created.
        urllib.request.urlretrieve("http://www.gunnerkrigg.com//comics/00000001.jpg", self.path)

    def __getitem__(self, index):
        """Return the sample image and an all-zero target.

        ``index`` is ignored: every item is the same image, converted to 8-bit
        grayscale and resized to 400x600 (width x height).

        Returns:
            ``(image, target)``. ``image`` is ``transform(image)`` when a transform
            is set, otherwise the PIL image itself. ``target`` is an 8-element
            array (after the optional ``target_transform``) cast to ``float32``.
        """
        # Close the file handle PIL opens; convert() produces an independent,
        # fully-loaded image that remains valid after the source is closed.
        with Image.open(self.path) as source:
            image = source.convert('L')  # type: Image.Image
        image = image.resize([400, 600])
        target = np.array([0, 0, 0, 0, 0, 0, 0, 0])
        # The counter is per dataset object; with worker processes each worker
        # presumably increments its own copy — confirm against DataLoader docs.
        print("Writing image {0}".format(self.image_index))
        self.image_index += 1
        # Fall back to the untransformed image so a dataset built without a
        # transform does not raise UnboundLocalError.
        image_tensor = image
        if self.transform is not None:
            image_tensor = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return image_tensor, target.astype(np.float32)

    def __len__(self):
        """Report a fixed length of 10 items."""
        return 10
class Learner(LightningModule):
    """Small convolutional regressor producing 8 ``tanh``-bounded outputs.

    Four conv + 2x2 max-pool stages, a 1x1 conv, global average pooling and a
    linear head. The first conv expects single-channel (grayscale) input.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = Conv2d(1, 4, 5)
        self.conv2 = Conv2d(4, 8, 3)
        self.conv3 = Conv2d(8, 12, 3)
        self.conv4 = Conv2d(12, 16, 3)
        self.pool = MaxPool2d(2, 2)
        self.conv5 = Conv2d(16, 24, 1)
        # Collapses any remaining spatial extent to 1x1 so the linear head does
        # not depend on the input resolution.
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(24, 8)

    def forward(self, x):
        """Map 1-channel images to 8 values in [-1, 1].

        Assumes ``x`` is (N, 1, H, W) with H and W large enough to survive four
        conv/pool stages — TODO confirm minimum input size.
        """
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = self.pool(torch.relu(self.conv3(x)))
        x = self.pool(torch.relu(self.conv4(x)))
        x = torch.relu(self.conv5(x))
        x = self.global_avg_pool(x)
        # Reshape the pooled 24-channel features to rows of 24 for the head.
        x = torch.tanh(self.fc1(x.view(-1, 24)))
        return x

    def training_step(self, batch, batch_nb):
        """Compute the smooth-L1 loss for one ``(images, targets)`` batch.

        Returns the dict form ``{'loss': ..., 'log': {...}}``. The ``'log'`` key is
        an older PyTorch Lightning logging convention — confirm the installed
        version still honours it.
        """
        x, y = batch
        y_predicted = self(x)
        loss = torch.nn.functional.smooth_l1_loss(y_predicted, y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        """Use Adadelta with its default hyperparameters."""
        return optim.Adadelta(self.parameters())

    def train_dataloader(self) -> DataLoader:
        """Build a shuffled, batch-size-1 loader over ``TestDataset``.

        Images are converted to tensors and normalised with a fixed
        single-channel mean/std.
        """
        # Normalize documents mean/std as per-channel sequences; pass 1-tuples for
        # the single grayscale channel rather than bare floats, which some
        # torchvision versions cannot index per channel.
        transformations = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.9293,), (0.2435,))])
        dataset = TestDataset(root='../deepscores/deep-scores-v1-extended-100pages/images_png',
                              transform=transformations)
        # Testing with num_workers = 0, 1, 2 changes the behavior of a debugger
        # breakpoint in the dataset code (originally noted by line number,
        # presumably inside TestDataset.__getitem__ — confirm). With workers > 0,
        # items are likely loaded in separate processes.
        return DataLoader(dataset, batch_size=1, num_workers=2, shuffle=True)
if __name__ == '__main__':
    # Script entry point: create the model first, then train it with a
    # default-configured Trainer.
    learner = Learner()
    Trainer().fit(learner)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment