Skip to content

Instantly share code, notes, and snippets.

@apacha
Created June 29, 2020 09:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save apacha/df9c5db38261ba82fc40de6fcf678c5b to your computer and use it in GitHub Desktop.
Save apacha/df9c5db38261ba82fc40de6fcf678c5b to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
from PIL import Image
from pytorch_lightning import LightningModule, Trainer
from torch import optim
from torch.nn import Conv2d, MaxPool2d
from torch.utils.data import DataLoader
from torchvision.datasets import VisionDataset
from torchvision.transforms import transforms
import numpy as np
import urllib.request
class TestDataset(VisionDataset):
    """Dummy dataset that serves one downloaded comic page ``len(self)`` (10) times.

    Used together with the DataLoader in ``Learner.train_dataloader`` to observe how
    ``num_workers`` affects debugging; ``image_index`` and the ``print`` calls make it
    visible how often ``__init__`` and ``__getitem__`` run.
    """

    def __init__(self, root: str = "../example_input_scores", transform=None, target_transform=None):
        """Download the sample image and reset the per-instance sample counter.

        Args:
            root: forwarded to ``VisionDataset``; NOTE it is not used for the image
                location — the image is written to ``self.path`` in the current
                working directory.
            transform: optional callable applied to the PIL image in ``__getitem__``.
            target_transform: optional callable applied to the target vector.
        """
        super().__init__(root, None, transform, target_transform)
        # Counts calls to __getitem__ on this instance. With num_workers > 0 the
        # DataLoader presumably works on per-process copies, so each worker would keep
        # its own counter — TODO confirm against the DataLoader docs.
        self.image_index = 1
        print("Resetting index")
        self.path = "00000001.jpg"
        # Network side effect: re-downloads the image every time the dataset is built.
        urllib.request.urlretrieve("http://www.gunnerkrigg.com//comics/00000001.jpg", self.path)

    def __getitem__(self, index):
        """Return ``(image, target)``; ``index`` is ignored (every sample is identical).

        ``image`` is the downloaded page converted to 8-bit grayscale ('L') and resized
        to 400x600 (width x height), then passed through ``transform`` if set (otherwise
        the PIL image itself is returned). ``target`` is a float32 zero vector of
        length 8, passed through ``target_transform`` if set.
        """
        # Context manager closes the underlying file handle once the image is decoded.
        with Image.open(self.path) as source:
            image = source.convert('L').resize((400, 600))
        # Build the target as float32 up front so target_transform sees the final dtype
        # and its result is returned unchanged.
        target = np.zeros(8, dtype=np.float32)
        print("Writing image {0}".format(self.image_index))
        self.image_index += 1
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return image, target

    def __len__(self):
        """Report a fixed dataset length of 10 samples."""
        return 10
class Learner(LightningModule):
    """Small convolutional regressor mapping a 1-channel image to 8 values in [-1, 1]."""

    def __init__(self):
        super().__init__()
        # Four conv layers, each followed by ReLU + 2x2 max-pooling in forward(),
        # then a 1x1 conv widening the features to 24 channels.
        self.conv1 = Conv2d(1, 4, 5)
        self.conv2 = Conv2d(4, 8, 3)
        self.conv3 = Conv2d(8, 12, 3)
        self.conv4 = Conv2d(12, 16, 3)
        self.pool = MaxPool2d(2, 2)
        self.conv5 = Conv2d(16, 24, 1)
        # Reduce every spatial map to 1x1 so the linear head always sees 24 features.
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(24, 8)

    def forward(self, x):
        """Return a tanh-squashed prediction of shape (batch, 8) for input ``x``."""
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.pool(torch.relu(conv(x)))
        pooled = self.global_avg_pool(torch.relu(self.conv5(x)))
        return torch.tanh(self.fc1(pooled.view(-1, 24)))

    def training_step(self, batch, batch_nb):
        """Compute the smooth-L1 loss of one batch and report it as 'train_loss'.

        NOTE(review): returning a dict with a 'log' key is the older Lightning logging
        API; newer releases use ``self.log`` — confirm against the pinned version.
        """
        images, targets = batch
        loss = torch.nn.functional.smooth_l1_loss(self(images), targets)
        return {'loss': loss, 'log': {'train_loss': loss}}

    def configure_optimizers(self):
        """Optimize all parameters with Adadelta at its default settings."""
        return optim.Adadelta(self.parameters())

    def train_dataloader(self) -> DataLoader:
        """Build a shuffled, batch-size-1 loader over ``TestDataset`` with 2 workers."""
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(0.9293, 0.2435)
        dataset = TestDataset(root='../deepscores/deep-scores-v1-extended-100pages/images_png',
                              transform=transforms.Compose([to_tensor, normalize]))
        # NOTE(review): the original comment said that num_workers = 0, 1, 2 changes the
        # behavior of "the breakpoint in line 68" of the original gist — presumably a
        # breakpoint inside TestDataset.__getitem__; confirm which line was meant.
        return DataLoader(dataset, batch_size=1, num_workers=2, shuffle=True)
if __name__ == '__main__':
    # Build the model before the Trainer (same order as before), then start fitting.
    learner = Learner()
    Trainer().fit(learner)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment