Skip to content

Instantly share code, notes, and snippets.

@rueian
Last active December 19, 2022 04:40
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rueian/fa2ffdc0b5505e9147d20f27095cf23e to your computer and use it in GitHub Desktop.
Save rueian/fa2ffdc0b5505e9147d20f27095cf23e to your computer and use it in GitHub Desktop.
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
import pytorch_lightning as pl
class Block(nn.Module):
    """Two 3x3 convolutions separated by a ReLU.

    The convolutions use PyTorch's default padding of 0, so each one trims one
    pixel from every border: height and width shrink by 4 in total.

    NOTE(review): there is no activation after ``conv2``; the classic UNet
    applies ReLU after both convolutions. Confirm this is intentional.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Attribute names (and their assignment order) define the state_dict
        # keys, so they are kept stable for checkpoint compatibility.
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3)

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)
class Encoder(nn.Module):
    """UNet contracting path: a chain of ``Block``s with 2x2 max-pooling between them.

    Args:
        chs: Channel counts; block ``i`` maps ``chs[i]`` -> ``chs[i + 1]``
            channels, giving ``len(chs) - 1`` blocks.
    """

    def __init__(self, chs=(3, 64, 128, 256, 512, 1024)):
        super().__init__()
        self.enc_blocks = nn.ModuleList([Block(chs[i], chs[i + 1]) for i in range(len(chs) - 1)])
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        """Run the contracting path.

        Returns:
            A list with one entry per block: that block's output before
            pooling, ordered shallowest first.
        """
        ftrs = []
        for i, block in enumerate(self.enc_blocks):
            # Pool only *between* blocks. Pooling after the last block used to
            # compute a result that was never used, and it raised when the
            # deepest feature map had a spatial size of 1.
            if i:
                x = self.pool(x)
            x = block(x)
            ftrs.append(x)
        return ftrs
class Decoder(nn.Module):
    """UNet expanding path: upsample, concatenate a cropped skip feature, refine.

    Args:
        chs: Channel counts; stage ``i`` upsamples ``chs[i]`` -> ``chs[i + 1]``
            channels with a stride-2 transposed conv. It then concatenates the
            matching encoder feature and applies a ``Block`` mapping
            ``chs[i]`` -> ``chs[i + 1]`` channels.
    """

    def __init__(self, chs=(1024, 512, 256, 128, 64)):
        super().__init__()
        self.chs = chs
        pairs = list(zip(chs[:-1], chs[1:]))
        # All upconvs are created before all blocks, matching the parameter
        # initialisation order of the original layout.
        self.upconvs = nn.ModuleList([nn.ConvTranspose2d(c_in, c_out, 2, 2) for c_in, c_out in pairs])
        self.dec_blocks = nn.ModuleList([Block(c_in, c_out) for c_in, c_out in pairs])

    def forward(self, x, encoder_features):
        """Decode ``x`` using ``encoder_features[i]`` as the skip input of stage ``i``.

        ``encoder_features`` must be indexable with at least ``len(self.chs) - 1``
        entries. The caller in this file passes them deepest first, excluding
        the bottleneck feature.
        """
        for i, (upconv, dec_block) in enumerate(zip(self.upconvs, self.dec_blocks)):
            x = upconv(x)
            skip = self.crop(encoder_features[i], x)
            x = dec_block(torch.cat([x, skip], dim=1))
        return x

    def crop(self, enc_ftrs, x):
        """Center-crop ``enc_ftrs`` to the spatial size of 4-D tensor ``x``."""
        _, _, height, width = x.shape
        return transforms.CenterCrop([height, width])(enc_ftrs)
class UNet(nn.Module):
    """Encoder-decoder UNet with a 1x1 convolution head.

    Args:
        enc_chs: Channel counts for ``Encoder``; ``enc_chs[0]`` is the number
            of input channels.
        dec_chs: Channel counts for ``Decoder``. ``dec_chs[0]`` must equal
            ``enc_chs[-1]`` (the deepest encoder feature feeds the decoder). The
            remaining entries are expected to mirror the encoder so the
            concatenated skip features match each decoder block's input width.
        out_ch: Number of output channels produced by the head (default 1,
            which was previously hard-coded).
    """

    def __init__(self, enc_chs=(3, 64, 128, 256, 512, 1024), dec_chs=(1024, 512, 256, 128, 64), out_ch=1):
        super().__init__()
        self.encoder = Encoder(enc_chs)
        self.decoder = Decoder(dec_chs)
        self.head = nn.Conv2d(dec_chs[-1], out_ch, 1)

    def forward(self, x):
        enc_ftrs = self.encoder(x)
        # Reverse once: the deepest feature is the decoder input, and the
        # remaining features (deepest first) are its skip connections.
        deepest, *skips = reversed(enc_ftrs)
        out = self.decoder(deepest, skips)
        return self.head(out)
class PLUnet(pl.LightningModule):
    """LightningModule that trains a ``UNet`` with pixel-wise MSE loss.

    Batches are mappings with an ``'image'`` input tensor and a
    ``'depth_map'`` target tensor.

    Args:
        lr: Adam learning rate (default 1e-3, the previously hard-coded value).
    """

    def __init__(self, lr=1e-3):
        super().__init__()
        self.model = UNet()
        self.lr = lr

    def forward(self, x):
        """Predict for ``x`` and resize the prediction to ``x``'s height/width.

        The UNet's unpadded convolutions can make its output spatially smaller
        than the input. ``F.interpolate`` (default nearest mode) resizes it back
        to ``x.shape[-2:]`` so it lines up with a full-resolution target.
        """
        # Call the module rather than ``.forward`` so registered hooks run.
        y_hat = self.model(x)
        return F.interpolate(y_hat, x.shape[-2:])

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def _forward_loss(self, batch, batch_idx):
        """Return the MSE between the prediction for ``batch['image']`` and ``batch['depth_map']``."""
        y_hat = self(batch['image'])
        return F.mse_loss(y_hat, batch['depth_map'])

    def training_step(self, train_batch, batch_idx):
        loss = self._forward_loss(train_batch, batch_idx)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        loss = self._forward_loss(val_batch, batch_idx)
        self.log('val_loss', loss)
image_to_tensor = transforms.Compose([transforms.ToTensor()])


def transform(x):
    """Replace a sample's ``'image'`` and ``'depth_map'`` entries with tensors.

    The mapping is modified in place and also returned, which is the form the
    dataset ``.map(transform)`` calls in this file use.
    """
    for key in ('image', 'depth_map'):
        x[key] = image_to_tensor(x[key])
    return x
if __name__ == '__main__':
    from datasets import load_dataset
    from torch.utils.data import DataLoader

    def make_loader(split):
        """Build a streaming NYU Depth v2 DataLoader for the given split."""
        dataset = (
            load_dataset("sayakpaul/nyu_depth_v2", split=split, streaming=True)
            .map(transform)
            .with_format("torch")
        )
        return DataLoader(dataset, num_workers=2, batch_size=10)

    train_loader = make_loader('train')
    val_loader = make_loader('validation')
    # 'auto' replaces the hard-coded 'mps'. It should still pick MPS on Apple
    # silicon, but falls back to CUDA/CPU elsewhere instead of failing when the
    # Trainer is constructed.
    # NOTE(review): limit_train_batches=1 trains on a single batch per epoch,
    # which looks like a smoke-test setting; raise it for real training.
    trainer = pl.Trainer(accelerator='auto', devices=1, limit_train_batches=1, max_epochs=100)
    trainer.fit(PLUnet(), train_loader, val_loader)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment