Skip to content

Instantly share code, notes, and snippets.

@Logrus
Last active September 13, 2023 13:36
Show Gist options
  • Save Logrus/e5cd1b6f70f3898f56ecdf54fcdfcfa2 to your computer and use it in GitHub Desktop.
Save Logrus/e5cd1b6f70f3898f56ecdf54fcdfcfa2 to your computer and use it in GitHub Desktop.
Memorize an image with NN, made to play around with sin activation, positional encodings and etc.
"""
Implicit Neural Representations (INRs) have recently gained popularity.
One issue with them, though, is that the learned representations are biased
toward low frequencies, resulting in over-smoothed outputs.
A few approaches have been suggested to alleviate this issue,
for example positional encodings.
An alternative is to use Sin/Cos activation functions,
which in essence act as learnable basis functions.
A common toy problem for getting a feel for this is image memorization,
where an MLP has to map (u, v) pixel coordinates to an (r, g, b) color.
Using ReLU in this example results in an overly smoothed representation,
whereas Sin/Cos activations represent higher frequencies better.
Another setting to vary is whether the (u, v) coordinates are normalized
to [0, 1). Normalization helps with ReLU activations, whereas unnormalized
coordinates work better with Sin/Cos activations: the periodic function wraps
large inputs back around, so they do not destroy training, and the network
then converges almost instantly.
In addition, this code demonstrates what the MLP produces when queried at
(u, v) coordinates beyond the image bounds.
A video is available on YouTube:
https://youtu.be/AYFoXcl6zyU
"""
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
# Interactive mode lets the training loop below redraw the figure without
# blocking; one Axes is created here and reused for every frame.
plt.ion()
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# =====================================================================
# A sin activation
class Sin(torch.nn.Module):
    """Element-wise sine activation, ``y = sin(x)``.

    Args:
        inplace: if True, overwrite the input tensor with its sine via
            ``Tensor.sin_()`` instead of allocating a new tensor, mirroring
            the ``inplace`` flag of ``torch.nn.ReLU``. As with any in-place
            op, PyTorch rejects this on a leaf tensor that requires grad, so
            only use it on intermediate activations. Defaults to False.
    """

    __constants__ = ["inplace"]
    inplace: bool

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Honor the flag: it used to be stored and shown by extra_repr
        # but silently ignored.
        if self.inplace:
            return input.sin_()
        return torch.sin(input)

    def extra_repr(self) -> str:
        return "inplace=True" if self.inplace else ""
class Cos(torch.nn.Module):
    """Element-wise cosine activation, ``y = cos(x)``.

    Args:
        inplace: if True, overwrite the input tensor with its cosine via
            ``Tensor.cos_()`` instead of allocating a new tensor, mirroring
            the ``inplace`` flag of ``torch.nn.ReLU``. As with any in-place
            op, PyTorch rejects this on a leaf tensor that requires grad, so
            only use it on intermediate activations. Defaults to False.
    """

    __constants__ = ["inplace"]
    inplace: bool

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Honor the flag: it used to be stored and shown by extra_repr
        # but silently ignored.
        if self.inplace:
            return input.cos_()
        return torch.cos(input)

    def extra_repr(self) -> str:
        return "inplace=True" if self.inplace else ""
# =====================================================================
# =====================================================================
# Network
class StupidNet(torch.nn.Module):
    """Coordinate MLP mapping a 2-D pixel coordinate to a 3-channel color.

    Layout: Linear(2, 512) -> first_activation -> Linear(512, 512) -> act
    -> Linear(512, 512) -> act -> Linear(512, 32) -> act -> Linear(32, 32)
    -> act -> Linear(32, 3).

    Args:
        activation: module used after every hidden layer except the first.
            Defaults to ``torch.nn.ReLU()``. Pass e.g. ``Sin()``, ``Cos()``
            or ``torch.nn.LogSigmoid()`` to compare activations without
            editing the source. The same instance is reused for each of
            those layers, so a module with parameters would share them.
        first_activation: module used right after the first Linear layer.
            Defaults to ``Sin()``, so the first hidden layer is sinusoidal
            regardless of ``activation``.
    """

    def __init__(self, activation=None, first_activation=None) -> None:
        super().__init__()
        if activation is None:
            activation = torch.nn.ReLU()
        if first_activation is None:
            first_activation = Sin()
        # A relatively high number of neurons is needed
        # to learn a proper representation.
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(2, 512),
            first_activation,
            torch.nn.Linear(512, 512),
            activation,
            torch.nn.Linear(512, 512),
            activation,
            torch.nn.Linear(512, 32),
            activation,
            torch.nn.Linear(32, 32),
            activation,
            torch.nn.Linear(32, 3),
        )

    def forward(self, coord):
        """Map coordinates with last dimension 2 to colors with last dimension 3."""
        return self.layers(coord)
# =====================================================================
# Load and normalize an image.
# convert("RGB") guarantees exactly three channels: a grayscale (2-D) or
# RGBA/palette image would otherwise break the (H, W, 3) unpacking below or
# the 3-column color targets. The context manager closes the file handle.
with Image.open("cat_small.jpeg") as image:
    image_array = np.array(image.convert("RGB"), dtype=np.float32)
# Image is normalized in [-0.5, 0.5]
image_normalized = (image_array / 255.0) - 0.5
H, W, _ = image_normalized.shape
print(f"Image size, height: {H}, width {W}")
# =====================================================================
# Create training data: one coordinate per pixel in row-major order, so row k
# of X / Y corresponds to pixel (k // W, k % W).
# Built with np.meshgrid instead of a per-pixel Python loop. The coordinates
# are the same float64 quotients i / H and j / W, cast to float32.
rows, cols = np.meshgrid(np.arange(H), np.arange(W), indexing="ij")
# Normalized coordinates in [0, 1), work better with ReLU
X = np.stack([rows / H, cols / W], axis=-1).reshape(H * W, 2).astype(np.float32)
# Unnormalized coordinates, work better with Sin:
# X = np.stack([rows, cols], axis=-1).reshape(H * W, 2).astype(np.float32)
# Explicit (H * W, 3) shape fails loudly if the image is not 3-channel.
Y = image_normalized.reshape(H * W, 3).astype(np.float32)
# =====================================================================
# Query the MLP beyond the learned data with some padding around it.
# The grid covers rows -padding .. H + padding - 1 and columns
# -padding .. W + padding - 1, flattened row-major, so predictions reshape to
# (H + 2 * padding, W + 2 * padding, 3). Built with np.meshgrid instead of a
# per-point Python loop; the values are identical.
padding = 100
pad_rows, pad_cols = np.meshgrid(
    np.arange(-padding, H + padding),
    np.arange(-padding, W + padding),
    indexing="ij",
)
# Normalized coordinates, using the same scaling as the training grid, so
# the padded border falls outside [0, 1)
X_with_padding = (
    np.stack([pad_rows / H, pad_cols / W], axis=-1)
    .reshape(-1, 2)
    .astype(np.float32)
)
# Unnormalized coordinates:
# X_with_padding = np.stack([pad_rows, pad_cols], axis=-1).reshape(-1, 2).astype(np.float32)
X_tensor_padding = torch.tensor(X_with_padding).to(device)
# The dataset is small, so no batching is needed:
# everything can be loaded into GPU memory at once.
X_tensor = torch.tensor(X).to(device)
Y_tensor = torch.tensor(Y).to(device)
# =====================================================================
# Display the ground-truth image. The +0.5 undoes the [-0.5, 0.5] centering
# so colors are back in [0, 1] for imshow.
ax.set_title("Original image")
ax.imshow(Y.reshape(H, W, 3) + 0.5)
plt.pause(1.0)  # run the GUI event loop for a second so the frame is drawn
def train_loop(X, y, model, loss_fn, optimizer):
    """Run one full-batch optimization step, print the loss and return it.

    Args:
        X: model inputs; the script passes the whole coordinate set at once.
        y: regression targets compared against ``model(X)``.
        model: the network being fit.
        loss_fn: callable ``loss_fn(pred, y)`` returning a single-element tensor.
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        The loss computed before the parameter update, as a Python float,
        so callers can log it or stop early.
    """
    # Clear stale gradients, then forward, backward and update.
    optimizer.zero_grad()
    pred = model(X)
    loss = loss_fn(pred, y)
    loss.backward()
    optimizer.step()
    loss_value = loss.item()
    print(f"loss: {loss_value:>7f}")
    return loss_value
def extract_image(model, X_in, shape=None):
    """Evaluate ``model`` on a flat batch of coordinates and return an image.

    Args:
        model: coordinate network producing 3 color values per input row.
        X_in: query coordinates, one row per pixel in row-major order.
        shape: ``(height, width, 3)`` of the returned image. Defaults to the
            module-level padded query grid,
            ``(H + padding * 2, W + padding * 2, 3)``.

    Returns:
        A NumPy array of ``shape`` holding the predictions shifted by +0.5,
        which undoes the [-0.5, 0.5] centering of the training targets.
        Values are not clipped.
    """
    if shape is None:
        shape = (H + padding * 2, W + padding * 2, 3)
    # Inference only: no_grad avoids building an autograd graph over every
    # query point. Calling model(...) instead of model.forward(...) keeps any
    # registered module hooks active.
    with torch.no_grad():
        pred = model(X_in)
    np_arr = pred.cpu().numpy()
    lo, hi = np.min(np_arr), np.max(np_arr)
    print("Image min max ", lo, hi)
    # Predictions may overflow the displayable range [0, 1]. Re-normalizing
    # is possible but not required; the shown image is still fine:
    # image = (np_arr.reshape(shape) - lo) / (hi - lo)
    image = np_arr.reshape(shape) + 0.5
    return image
# =====================================================================
# Build the network, an MSE objective and Adam, then fit the image with
# full-batch steps, redrawing the padded prediction every 100 epochs.
model = StupidNet().to(device)
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
epochs = 50000
for e in range(epochs):
    train_loop(X_tensor, Y_tensor, model, loss_fn, optimizer)
    if e % 100 != 0:
        continue  # only refresh the plot periodically
    image_learned = extract_image(model, X_tensor_padding)
    ax.imshow(image_learned)
    ax.set_title(f"Epoch {e}")
    plt.pause(0.001)
    # plt.savefig("training_beoyond_edges/image_{:06}".format(e))
@Logrus
Copy link
Author

Logrus commented Jun 3, 2022

Test image:
cat_small

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment