Skip to content

Instantly share code, notes, and snippets.

@lepotatoguy
Created April 25, 2023 05:39
Show Gist options
  • Save lepotatoguy/12aa82874503f986621b35d51a608d80 to your computer and use it in GitHub Desktop.
Save lepotatoguy/12aa82874503f986621b35d51a608d80 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.nn.functional as F
class CctBlock(nn.Module):
    """Pre-norm transformer block: self-attention, then an MLP, each with a residual.

    Tokens are laid out as ``(seq_len, batch, in_channels)`` because the
    ``nn.MultiheadAttention`` below uses its default ``batch_first=False``.
    The block returns ``(seq_len, batch, out_channels)``.

    Args:
        in_channels: embedding width of the input tokens; must be divisible
            by ``num_heads``.
        out_channels: embedding width of the output tokens.
        num_heads: number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``in_channels``.
    """

    def __init__(self, in_channels, out_channels, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(in_channels)
        self.attn = nn.MultiheadAttention(in_channels, num_heads)
        self.norm2 = nn.LayerNorm(in_channels)
        hidden = int(in_channels * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hidden),
            nn.GELU(),
            nn.Linear(hidden, out_channels),
        )
        # The MLP changes the width to out_channels, so the second residual
        # needs a matching projection. When the widths agree this is an
        # identity with no parameters, so the state_dict is as before.
        self.shortcut = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Linear(in_channels, out_channels)
        )

    def forward(self, x):
        """Apply the block to tokens ``x`` of shape ``(seq_len, batch, in_channels)``."""
        x_norm = self.norm1(x)
        attn_output, _ = self.attn(x_norm, x_norm, x_norm)
        x = x + attn_output
        return self.shortcut(x) + self.mlp(self.norm2(x))
class CctEncoder(nn.Module):
    """3x3 convolutional stem followed by a stack of ``CctBlock`` layers.

    Args:
        in_channels: channels of the input feature map.
        cct_block_params: sequence of ``(in_channels, out_channels, num_heads,
            mlp_ratio)`` tuples, one per layer. The stem projects to the first
            tuple's ``in_channels``.
        num_layers: how many leading tuples of ``cct_block_params`` to build.
    """

    def __init__(self, in_channels, cct_block_params, num_layers):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, cct_block_params[0][0], kernel_size=3, padding=1)
        self.blocks = nn.ModuleList()
        for i in range(num_layers):
            block_in, block_out, num_heads, mlp_ratio = cct_block_params[i]
            self.blocks.append(CctBlock(block_in, block_out, num_heads, mlp_ratio))

    def forward(self, x):
        """Encode ``(B, in_channels, H, W)`` into ``(B, C_out, H, W)``.

        ``C_out`` is the last block's ``out_channels``, or the stem width when
        there are no blocks.
        """
        x = self.conv(x)
        batch, _, height, width = x.shape
        # The transformer blocks consume (seq_len, batch, channels) tokens, so
        # turn every spatial position into one token.
        tokens = x.flatten(2).permute(2, 0, 1)
        for block in self.blocks:
            tokens = block(tokens)
        # Restore the image layout for convolutional consumers.
        return tokens.permute(1, 2, 0).reshape(batch, -1, height, width)
class CnnDecoder(nn.Module):
    """Convolutional decoder that upsamples ``num_blocks`` times by a factor of 2.

    Each stage does three things:

    * nearest-neighbour x2 upsampling;
    * a 3x3 convolution that halves the channel count (floor division);
    * a ReLU.

    A final 3x3 convolution then maps to ``out_channels``. Height and width
    grow by ``2 ** num_blocks``.

    Args:
        in_channels: channels of the input feature map.
        num_blocks: number of upsampling stages.
        out_channels: channels of the output map.

    Raises:
        ValueError: if ``in_channels < 2 ** num_blocks``. The halving would
            otherwise reach zero and silently build zero-channel convolutions.
    """

    def __init__(self, in_channels, num_blocks, out_channels):
        super().__init__()
        self.blocks = nn.ModuleList()
        channels = in_channels
        for _ in range(num_blocks):
            if channels < 2:
                raise ValueError(
                    f"CnnDecoder needs in_channels >= 2 ** num_blocks, got "
                    f"in_channels={in_channels}, num_blocks={num_blocks}"
                )
            self.blocks.append(nn.Conv2d(channels, channels // 2, kernel_size=3, padding=1))
            channels //= 2
        self.conv = nn.Conv2d(channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Decode ``(B, in_channels, H, W)`` to ``(B, out_channels, H * 2**k, W * 2**k)``, with ``k = num_blocks``."""
        for conv in self.blocks:
            x = F.relu(conv(F.interpolate(x, scale_factor=2, mode="nearest")))
        return self.conv(x)
class InpaintingModel(nn.Module):
    """Mask-guided inpainting network: patch stem, CCT encoder, learned warp, CNN decoder.

    For an image ``x`` of shape ``(B, 3, H, W)``:

    1. ``patch_embed`` splits ``x`` into non-overlapping patches of side
       ``2 ** num_blocks``. The decoder's ``num_blocks`` x2 upsamplings then
       land exactly on ``(H, W)``, and the attention sequence has only
       ``H * W / 4 ** num_blocks`` tokens.
    2. ``encoder`` yields features ``(B, C, h, w)`` with
       ``h = H / 2 ** num_blocks`` and ``w = W / 2 ** num_blocks``.
    3. The mask is resized to ``(h, w)``, passed through a 1x1 conv and a
       sigmoid, then multiplied into the features.
    4. ``grid_generator`` predicts per-location (x, y) sampling offsets from
       that soft mask. The offsets are added to an identity grid, and the
       masked features are resampled with ``F.grid_sample``.
    5. ``decoder`` upsamples the result to ``(B, 3, H, W)``.

    Args:
        cct_block_params: one ``(in_channels, out_channels, num_heads,
            mlp_ratio)`` tuple per transformer block. Consecutive tuples must
            chain: each block's ``out_channels`` is the next block's
            ``in_channels``. The default keeps a width of 576 throughout.
            The old default mapped 576 -> 128 in every block, so the second
            block would have received the wrong width.
        num_blocks: number of x2 upsampling stages in the decoder. The stem
            downsamples by the same factor.
    """

    def __init__(self, cct_block_params=((576, 576, 8, 2.0),) * 5, num_blocks=5):
        super().__init__()
        self.patch_size = 2 ** num_blocks
        embed_dim = cct_block_params[0][0]
        self.patch_embed = nn.Conv2d(
            3, embed_dim, kernel_size=self.patch_size, stride=self.patch_size
        )
        self.encoder = CctEncoder(embed_dim, cct_block_params, num_layers=len(cct_block_params))
        # Maps the 1-channel soft mask to a 2-channel offset field in
        # normalised grid units, at the feature resolution (no striding).
        self.grid_generator = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 2, kernel_size=3, padding=1),
            nn.Tanh(),
        )
        # Zero-initialise the offset head. Since tanh(0) == 0, the warp starts
        # as the identity rather than as a random displacement.
        nn.init.zeros_(self.grid_generator[2].weight)
        nn.init.zeros_(self.grid_generator[2].bias)
        # The decoder's input width must equal the encoder's output width.
        self.decoder = CnnDecoder(cct_block_params[-1][1], num_blocks, out_channels=3)
        # The mask has a single channel.
        self.mask_conv = nn.Conv2d(1, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, mask):
        """Inpaint ``x`` guided by ``mask``.

        Args:
            x: image batch ``(B, 3, H, W)``. ``H`` and ``W`` must be multiples
                of ``2 ** num_blocks``.
            mask: ``(Bm, 1, h, w)`` or ``(Bm, h, w)`` mask, where ``Bm`` is
                ``B`` or 1 (shared across the batch). It is resized to the
                feature resolution, so ``h, w`` need not equal ``H, W``.

        Returns:
            ``(B, 3, H, W)`` decoder output. No output activation is applied.

        Raises:
            ValueError: if ``H`` or ``W`` is not a multiple of the patch size.
        """
        height, width = x.shape[-2:]
        if height % self.patch_size or width % self.patch_size:
            raise ValueError(
                f"input height and width must be multiples of {self.patch_size}, "
                f"got {height}x{width}"
            )
        if mask.dim() == 3:
            mask = mask.unsqueeze(1)  # (Bm, h, w) -> (Bm, 1, h, w)

        encoded_x = self.encoder(self.patch_embed(x))
        batch_size, _, feat_h, feat_w = encoded_x.size()

        mask = F.interpolate(
            mask.to(encoded_x.dtype), size=(feat_h, feat_w), mode="bilinear", align_corners=False
        )
        mask = self.sigmoid(self.mask_conv(mask))
        masked_encoded_x = encoded_x * mask

        # Identity sampling grid (B, h, w, 2) plus mask-conditioned offsets.
        theta = torch.eye(2, 3, dtype=encoded_x.dtype, device=encoded_x.device)
        theta = theta.unsqueeze(0).repeat(batch_size, 1, 1)
        base_grid = F.affine_grid(theta, encoded_x.shape, align_corners=False)
        offsets = self.grid_generator(mask).permute(0, 2, 3, 1)  # (Bm, h, w, 2)
        grid = base_grid + offsets
        deformed_masked_encoded_x = F.grid_sample(
            masked_encoded_x, grid, mode="bilinear", align_corners=False
        )
        return self.decoder(deformed_masked_encoded_x)
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt
from PIL import Image
# define device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# create model. Its weights are randomly initialised (nothing is loaded), so
# the output only demonstrates that the shapes line up end to end.
model = InpaintingModel().to(device)
model.eval()  # inference only
# load the image and resize it so both sides are non-zero multiples of 32,
# as the default model requires
image_path = '/content/bac1.jpg'
image = Image.open(image_path).convert('RGB')
width, height = image.size
new_width = max(32, (width // 32) * 32)
new_height = max(32, (height // 32) * 32)
image = image.resize((new_width, new_height))
image_tensor = TF.to_tensor(image).unsqueeze(0).to(device)
# create a fixed (not random) mask covering the top-left quadrant. It is
# built at half resolution; the model resizes it internally.
mask_size = (new_height // 2, new_width // 2)
mask = torch.zeros(1, 1, *mask_size).to(device)
mask[..., :mask_size[0] // 2, :mask_size[1] // 2] = 1.0
# inpaint the image. no_grad keeps the output from requiring grad, because
# Tensor.numpy() raises on tensors that do.
with torch.no_grad():
    inpainted_tensor = model(image_tensor, mask)
# convert tensors to numpy arrays and show images
image_np = image_tensor.squeeze(0).cpu().numpy().transpose((1, 2, 0))
mask_np = mask[0, 0].cpu().numpy()  # imshow needs a 2-D array for a grayscale image
# the decoder output is unbounded; clamp it to imshow's [0, 1] float range
inpainted_np = inpainted_tensor.squeeze(0).clamp(0.0, 1.0).cpu().numpy().transpose((1, 2, 0))
fig, axes = plt.subplots(ncols=3, figsize=(10, 5))
axes[0].imshow(image_np)
axes[0].set_title('Original Image')
axes[1].imshow(mask_np, cmap='gray')
axes[1].set_title('Mask')
axes[2].imshow(inpainted_np)
axes[2].set_title('Inpainted Image')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment