Fitting a function with functional PyTorch
import numpy as np
import torch
from torch import nn
from torch import Tensor
from torch.func import functional_call
import torchopt
import matplotlib.pyplot as plt
class SimpleNN(nn.Module):
    def __init__(
        self,
        num_layers: int = 1,
        num_neurons: int = 5,
    ) -> None:
        """Basic neural network architecture with linear layers

        Args:
            num_layers (int, optional): number of hidden layers
            num_neurons (int, optional): neurons for each hidden layer
        """
        super().__init__()

        layers = []

        # input layer
        layers.append(nn.Linear(1, num_neurons))

        # hidden layers
        for _ in range(num_layers):
            layers.extend([nn.Linear(num_neurons, num_neurons), nn.Tanh()])

        # output layer
        layers.append(nn.Linear(num_neurons, 1))

        # build the network
        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.network(x.reshape(-1, 1)).squeeze()
def make_functional_fwd(model: torch.nn.Module):
    """Make a functional forward pass for a generic module.

    This function is compatible with the torchopt library, which
    returns the updated parameters as a tuple, while the `functional_call`
    routine requires a parameters dictionary. This function performs the
    conversion automatically.
    """
    keys = list(dict(model.named_parameters()).keys())

    def fn(data: Tensor, parameters: tuple[Tensor, ...]):
        params_dict = {k: v for k, v in zip(keys, parameters)}
        return functional_call(model, params_dict, (data,))

    return fn
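# Illustrative usage sketch (not part of the original gist): the functional
# forward pass should reproduce a direct call to the module when it is given
# the module's own parameters (`model` and `x` are placeholder names here):
#
#   fwd = make_functional_fwd(model)
#   out = fwd(x, tuple(model.parameters()))  # equivalent to model(x)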
def get_data(n_points: int = 20) -> tuple[Tensor, Tensor]:
    """Prepare the input data for training/test sets"""
    x = torch.rand(n_points) * 2.0 * torch.pi
    y = 2.0 * torch.sin(x + 2.0 * torch.pi)
    return x, y
if __name__ == "__main__":
    torch.manual_seed(42)

    model = SimpleNN(num_layers=2)
    model_func = make_functional_fwd(model)

    x_train, y_train = get_data(n_points=40)
    x_test, y_test = get_data(n_points=10)

    # choose an optimizer with a functional API using torchopt
    num_epochs = 500
    lr = 0.01
    optimizer = torchopt.FuncOptimizer(torchopt.adam(lr=lr))
    loss_fn = torch.nn.MSELoss()

    # train the model
    loss_evolution = []
    params = tuple(model.parameters())
    for i in range(num_epochs):
        # update the parameters
        y = model_func(x_train, params)
        loss = loss_fn(y, y_train)
        params = optimizer.step(loss, params)
        if i % 100 == 0:
            print(f"Iteration {i} with loss {float(loss)}")
        loss_evolution.append(float(loss))

    # evaluate the model on the test set
    y_pred = model_func(x_test, params)
    print(f"Loss on the test set: {loss_fn(y_pred, y_test)}")

    # plot the final predictions
    def fn_to_fit(x):
        return 2.0 * np.sin(x + 2.0 * np.pi)

    x_analyt = np.linspace(0, 2.0 * np.pi, 100)
    y_analyt = fn_to_fit(x_analyt)

    x_test_np = x_test.detach().numpy()
    y_test_np = y_test.detach().numpy()
    y_pred_np = y_pred.detach().numpy()

    plt.figure(1)
    plt.plot(x_analyt, y_analyt, label="Analytical")
    plt.scatter(x_test_np, y_test_np, label="True")
    plt.scatter(x_test_np, y_pred_np, label="Predicted")
    plt.legend()
    plt.show()
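    # Optional addition (not in the original gist): visualize the training
    # loss recorded in `loss_evolution` above, which the script otherwise
    # never plots.
    plt.figure(2)
    plt.plot(loss_evolution)
    plt.xlabel("Epoch")
    plt.ylabel("MSE loss")
    plt.yscale("log")
    plt.title("Training loss evolution")
    plt.show()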