Created
April 17, 2023 00:51
-
-
Save khannay/08e6b1c3c9ccd0a68d2b101b94128c79 to your computer and use it in GitHub Desktop.
Differential Equations as a Pytorch Neural Network Layer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#| export | |
import torch | |
import torch.nn as nn | |
from torchdiffeq import odeint_adjoint as odeint | |
import pylab as plt | |
from torch.utils.data import Dataset, DataLoader | |
from typing import Callable, List, Tuple, Union, Optional |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced by any of the visible cells; tensors and models stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#| exports
class VDP(nn.Module):
    """
    Define the Van der Pol oscillator as a PyTorch module.

    For a state (x, y) the right hand side is
        dx/dt = mu * (x - x**3 / 3 - y)
        dy/dt = x / mu
    with `mu` stored as a learnable parameter.
    """
    def __init__(self,
                 mu: float, # Stiffness parameter of the VDP oscillator
                 ):
        super().__init__()
        # float() keeps the parameter floating point even for an int argument such as mu=1;
        # torch.tensor(1) would be int64 and nn.Parameter rejects integer tensors
        # because they cannot require gradients.
        self.mu = torch.nn.Parameter(torch.tensor(float(mu))) # make mu a learnable parameter

    def forward(self,
                t: float, # time index (unused: the right hand side does not depend on t)
                state: torch.Tensor, # state indexed as (batch, state); columns 0 and 1 are x and y
                ) -> torch.Tensor: # derivative of the state, same shape/dtype/device as `state`
        """
        Define the right hand side of the VDP oscillator.
        """
        x = state[:, 0] # first dimension is the batch size
        y = state[:, 1]
        dX = self.mu*(x-1/3*x**3 - y)
        dY = 1/self.mu*x
        # Output buffer with the same shape, dtype and device as `state`; any columns
        # beyond the first two keep a zero derivative.
        dfunc = torch.zeros_like(state)
        dfunc[:, 0] = dX
        dfunc[:, 1] = dY
        return dfunc

    def __repr__(self):
        """Print the parameters of the model."""
        return f" mu: {self.mu.item()}"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Simulate a batch of Van der Pol trajectories and draw their phase portrait.
vdp_model = VDP(mu=0.5)
# Time axis on which odeint reports the solution.
ts = torch.linspace(0, 30.0, 1000)
batch_size = 30
# Random initial conditions scattered around (0.01, 0.01), one row per trajectory.
initial_conditions = torch.tensor([0.01, 0.01]) + 0.2 * torch.randn((batch_size, 2))
# odeint (from torchdiffeq) returns a time-major tensor; detach it and convert to NumPy for plotting.
sol = (
    odeint(vdp_model, initial_conditions, ts, method='dopri5')
    .detach()
    .numpy()
)
# Plot y against x for every trajectory in the batch.
plt.plot(sol[:, :, 0], sol[:, :, 1], color='black', lw=0.5)
plt.title("Phase plot of the VDP oscillator")
plt.xlabel("x")
plt.ylabel("y");
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#| exports
class LotkaVolterra(nn.Module):
    """
    The Lotka-Volterra equations are a pair of first-order, non-linear, differential equations
    describing the dynamics of two species interacting in a predator-prey relationship.

    For a state (x, y) this module computes
        dx/dt = alpha * x - beta * x * y
        dy/dt = -delta * y + gamma * x * y
    NOTE: `delta` scales the decay of y and `gamma` its growth from x*y, which is
    the reverse of the naming used in many textbooks.
    """
    def __init__(self,
                 alpha: float = 1.5, # The alpha parameter of the Lotka-Volterra system
                 beta: float = 1.0, # The beta parameter of the Lotka-Volterra system
                 delta: float = 3.0, # The delta parameter of the Lotka-Volterra system
                 gamma: float = 1.0 # The gamma parameter of the Lotka-Volterra system
                 ) -> None:
        super().__init__()
        # float() keeps the parameter vector floating point even if every argument is an int
        # (e.g. LotkaVolterra(1, 1, 3, 1)); an int64 tensor would make nn.Parameter raise
        # because integer tensors cannot require gradients.
        self.model_params = torch.nn.Parameter(
            torch.tensor([float(alpha), float(beta), float(delta), float(gamma)])
        )

    def forward(self,
                t: float, # time index (unused: the right hand side does not depend on t)
                state: torch.Tensor, # state indexed as (batch, state); columns 0 and 1 are x and y
                ) -> torch.Tensor: # derivative of the state, same shape/dtype/device as `state`
        x = state[:,0]
        y = state[:,1]
        # Output buffer with the same shape, dtype and device as `state`.
        sol = torch.zeros_like(state)
        # Unpack the learnable parameter vector into its four named coefficients.
        alpha, beta, delta, gamma = self.model_params
        sol[:,0] = alpha*x - beta*x*y
        sol[:,1] = -delta*y + gamma*x*y
        return sol

    def __repr__(self):
        """Print the parameters of the model."""
        return f" alpha: {self.model_params[0].item()}, beta: {self.model_params[1].item()}, delta: {self.model_params[2].item()}, gamma: {self.model_params[3].item()}"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Simulate 30 Lotka-Volterra trajectories with the default parameters and plot x over time.
lv_model = LotkaVolterra()
ts = torch.linspace(0, 30.0, 1000)
# Batch of initial conditions (batch_dim, state_dim): perturbations of scale 0.5 around (3, 3).
initial_conditions = torch.tensor([[3, 3]]) + 0.50 * torch.randn((30, 2))
sol = (
    odeint(lv_model, initial_conditions, ts, method='dopri5')
    .detach()
    .numpy()
)
# One line per trajectory: column 0 of the state is x.
plt.plot(ts, sol[:, :, 0], lw=0.5)
plt.title("Time series of the Lotka-Volterra system")
plt.xlabel("time")
plt.ylabel("x");
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Phase portrait of the Lotka-Volterra batch simulated above: y plotted against x.
plt.plot(sol[:, :, 0], sol[:, :, 1], lw=0.5)
plt.title("Phase plot of the Lotka-Volterra system")
plt.xlabel("x")
plt.ylabel("y");
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#| exports
class Lorenz(nn.Module):
    """
    Define the Lorenz system as a PyTorch module.

    For a state (x, y, z) it computes
        dx/dt = sigma * (y - x)
        dy/dt = x * (rho - z) - y
        dz/dt = x * y - beta * z
    with (sigma, rho, beta) held in one learnable parameter vector `model_params`.
    """
    def __init__(self,
                 sigma: float = 10.0, # The sigma parameter of the Lorenz system
                 rho: float = 28.0, # The rho parameter of the Lorenz system
                 beta: float = 8.0/3, # The beta parameter of the Lorenz system
                 ):
        super().__init__()
        # float() keeps the parameter vector floating point even if every argument is an int
        # (e.g. Lorenz(10, 28, 3)); an int64 tensor would make nn.Parameter raise because
        # integer tensors cannot require gradients.
        self.model_params = torch.nn.Parameter(
            torch.tensor([float(sigma), float(rho), float(beta)])
        )

    def forward(self,
                t: float, # time index (unused: the right hand side does not depend on t)
                state: torch.Tensor, # state indexed as (batch, state); columns 0..2 are x, y, z
                ) -> torch.Tensor: # derivative of the state, same shape/dtype/device as `state`
        x = state[:,0]
        y = state[:,1]
        z = state[:,2]
        # Output buffer with the same shape, dtype and device as `state`.
        sol = torch.zeros_like(state)
        # Unpack the learnable parameter vector into its three named coefficients.
        sigma, rho, beta = self.model_params
        sol[:,0] = sigma*(y-x)
        sol[:,1] = x*(rho-z) - y
        sol[:,2] = x*y - beta*z
        return sol

    def __repr__(self):
        """Print the parameters of the model."""
        return f" sigma: {self.model_params[0].item()}, rho: {self.model_params[1].item()}, beta: {self.model_params[2].item()}"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Simulate 30 nearby Lorenz trajectories over 50 time units and plot x.
lorenz_model = Lorenz()
ts = torch.linspace(0, 50.0, 3000)
# Batch of initial conditions (batch_dim, state_dim): perturbations of scale 0.1 around (1, 0, 0).
initial_conditions = torch.tensor([[1.0, 0.0, 0.0]]) + 0.10 * torch.randn((30, 3))
sol = (
    odeint(lorenz_model, initial_conditions, ts, method='dopri5')
    .detach()
    .numpy()
)
# Only the first 2000 of the 3000 time samples are drawn.
plt.plot(ts[:2000], sol[:2000, :, 0], lw=0.5)
plt.title("Time series of the Lorenz system")
plt.xlabel("time")
plt.ylabel("x");
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Phase portrait (y against x) of the first trajectory of the Lorenz batch only.
plt.plot(sol[:, 0, 0], sol[:, 0, 1], color='black', lw=0.5)
plt.title("Phase plot of the Lorenz system")
plt.xlabel("x")
plt.ylabel("y");
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#| exports
class SimODEData(Dataset):
    """
    A very simple dataset class for simulating ODEs.

    Item `index` is the pair (ts[index], values[index]); the dataset length is
    the number of entries in `ts`.
    """
    def __init__(self,
                 ts: List[torch.Tensor], # List of time points as tensors
                 values: List[torch.Tensor], # List of dynamical state values (tensor) at each time point
                 ) -> None:
        self.ts, self.values = ts, values

    def __len__(self) -> int:
        n_items = len(self.ts)
        return n_items

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        time_points = self.ts[index]
        state_values = self.values[index]
        return time_points, state_values
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#| export
def create_sim_dataset(model: nn.Module, # model to simulate from
                       ts: torch.Tensor, # Time points to simulate at
                       num_samples: int = 10, # Number of samples to generate
                       sigma_noise: float = 0.1, # Noise level to add to the data
                       initial_conditions_default: torch.Tensor = torch.tensor([0.0, 0.0]), # Default initial conditions
                       sigma_initial_conditions: float = 0.1, # Noise level to add to the initial conditions
                       ) -> SimODEData:
    """
    Simulate `num_samples` noisy trajectories of `model` on the time grid `ts`.

    Each trajectory starts at `initial_conditions_default` plus Gaussian noise scaled by
    `sigma_initial_conditions`. It is integrated with `odeint`, then every stored state
    receives Gaussian observation noise scaled by `sigma_noise`. All samples share the same `ts`.
    """
    state_dim = initial_conditions_default.shape[0]
    times, trajectories = [], []
    for _ in range(num_samples):
        # Perturbed starting point with a leading batch axis of size 1.
        start = sigma_initial_conditions * torch.randn((1, state_dim)).detach() + initial_conditions_default
        # Drop the singleton batch axis from the solver output and cut it from the autograd graph.
        clean = odeint(model, start, ts).squeeze(1).detach()
        noisy = clean + sigma_noise * torch.randn_like(clean)
        times.append(ts)
        trajectories.append(noisy)
    return SimODEData(times, trajectories)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#| exports
def train(model: torch.nn.Module, # Model to train
          data: SimODEData, # Data to train on
          lr: float = 1e-2, # learning rate for the Adam optimizer
          epochs: int = 10, # Number of epochs to train for
          batch_size: int = 5, # Batch size for training
          print_every: int = 10, # Print the epoch loss every this many epochs (0 disables printing)
          ) -> List[float]: # summed batch loss of every epoch, in order
    """
    Fit the parameters of `model` to the trajectories in `data` with Adam.

    Each batch is integrated with `odeint` from its first observed state, and the MSE
    between the predicted and observed states is minimised. Returns the list of
    per-epoch summed losses (previously discarded, so existing callers are unaffected).
    """
    trainloader = DataLoader(data, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.MSELoss()
    epoch_losses: List[float] = []
    for epoch in range(epochs):
        running_loss = 0.0
        # Distinct loop names so the `data` parameter is not shadowed.
        for batch_ts, batch_states in trainloader:
            optimizer.zero_grad() # reset gradients
            # batch_states is indexed as (batch, time, state); start from the first observation.
            initial_state = batch_states[:, 0, :]
            # Only the first sample's time grid is used for the whole batch. This assumes every
            # sample shares one grid, which holds for create_sim_dataset output — TODO confirm for other data.
            # The solver output is transposed to (batch, time, state) to match batch_states.
            pred = odeint(model, initial_state, batch_ts[0]).transpose(0, 1)
            loss = criterion(pred, batch_states)
            loss.backward() # compute gradients
            optimizer.step() # update parameters
            running_loss += loss.item() # record loss
        epoch_losses.append(running_loss)
        if print_every and epoch % print_every == 0:
            print(f"Loss at {epoch}: {running_loss}")
    return epoch_losses
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Generate noisy training data from a VDP oscillator whose stiffness is known (mu = 0.30).
true_mu = 0.30
model_sim = VDP(mu=true_mu)
# Ten observation times spread evenly over [0, 10].
ts_data = torch.linspace(0.0, 10.0, 10)
data_vdp = create_sim_dataset(
    model_sim,
    ts=ts_data,
    num_samples=10,
    sigma_noise=0.01,
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Compare an untrained model (mu = 0.10) with the first simulated sample and the true model.
vdp_model = VDP(mu=0.10)
ts = torch.linspace(0, 30.0, 1000)
ts_data, y_data = data_vdp[0]
# First observed state of sample 0, given a batch axis of size 1.
initial_conditions = y_data[0].unsqueeze(dim=0)
sol_initial_guess = odeint(vdp_model, initial_conditions, ts, method='dopri5').detach().numpy()
sol_true = odeint(model_sim, initial_conditions, ts, method='dopri5').detach().numpy()
# y component over time: unfit model (solid black), observations (red dots), true model (dashed red).
plt.plot(ts, sol_initial_guess[:, :, 1], color='black', lw=1.0, label='Unfit model')
plt.scatter(ts_data.detach(), y_data[:, 1].detach(), color='red', s=30, label='Data')
plt.plot(ts, sol_true[:, :, 1], color='red', ls='--', lw=1.0, label='True model')
plt.title("VDP Model: Initial Parameter Guess")
plt.xlabel("t")
plt.ylabel("y")
plt.legend();
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Fit mu of the VDP model to the simulated data and report how close it got to the true value.
train(vdp_model, data_vdp, epochs=50)
print(f"After training: {vdp_model}, where the true value is {true_mu}")
# .item() turns the 0-d parameter tensor into a float, so the message shows a plain number
# instead of a tensor repr carrying grad_fn=<SubBackward0>.
print(f"Final Parameter Recovery Error: {vdp_model.mu.item() - true_mu}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Re-simulate with the trained mu from the same starting point and compare against the data.
sol_after_fitting = (
    odeint(vdp_model, initial_conditions, ts, method='dopri5')
    .detach()
    .numpy()
)
# y component over time: fitted model (solid black), observations (red dots), true model (dashed red).
plt.plot(ts, sol_after_fitting[:, :, 1], color='black', lw=1.0, label='Fit model')
plt.scatter(ts_data.detach(), y_data[:, 1].detach(), color='red', s=30, label='Data')
plt.plot(ts, sol_true[:, :, 1], color='red', ls='--', lw=1.0, label='True model')
plt.title("VDP Model: After Training")
plt.xlabel("t")
plt.ylabel("y")
plt.legend();
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Phase portrait (y against x) of the fitted model, the observations and the true model.
plt.plot(sol_after_fitting[:, :, 0], sol_after_fitting[:, :, 1], color='black', lw=1.0, label='Fit model')
plt.scatter(y_data[:, 0].detach(), y_data[:, 1].detach(), color='red', s=30, label='Data')
plt.plot(sol_true[:, :, 0], sol_true[:, :, 1], color='red', ls='--', lw=1.0, label='True model')
plt.title("VDP Model: After Training")
# Both axes are state variables here; the horizontal axis is x, not time.
plt.xlabel("x")
plt.ylabel("y")
plt.legend();
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Generate noisy Lotka-Volterra training data from the default ("true") parameters.
model_sim_lv = LotkaVolterra()
ts_data = torch.linspace(0.0, 10.0, 10)
data_lv = create_sim_dataset(
    model_sim_lv,
    ts=ts_data,
    num_samples=10,
    sigma_noise=0.1,
    initial_conditions_default=torch.tensor([2.5, 2.5]),
)
print(model_sim_lv)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Start the fit from parameters slightly off the true ones and compare with sample 0.
model_lv = LotkaVolterra(alpha=1.6, beta=1.1, delta=2.7, gamma=1.2)
ts = torch.linspace(0, 30.0, 1000)
ts_data, y_data = data_lv[0]
# First observed state of sample 0, given a batch axis of size 1.
initial_conditions = y_data[0].unsqueeze(dim=0)
sol_initial_guess = odeint(model_lv, initial_conditions, ts, method='dopri5').detach().numpy()
sol_true = odeint(model_sim_lv, initial_conditions, ts, method='dopri5').detach().numpy()
# y component over time: unfit model (solid black), observations (red dots), true model (dashed red).
plt.plot(ts, sol_initial_guess[:, :, 1], color='black', lw=1.0, label='Unfit model')
plt.scatter(ts_data.detach(), y_data[:, 1].detach(), color='red', s=30, label='Data')
plt.plot(ts, sol_true[:, :, 1], color='red', ls='--', lw=1.0, label='True model')
plt.title("Lotka-Volterra Model: Initial Parameter Guess")
plt.xlabel("t")
plt.ylabel("y")
plt.legend();
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Fit the four Lotka-Volterra parameters, then print the fitted and true values side by side.
train(model_lv, data_lv, lr=1e-2, epochs=60)
print(f"Fitted model: {model_lv}")
print(f"True model: {model_sim_lv}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Re-simulate the fitted Lotka-Volterra model from sample 0's first observed state.
initial_conditions = y_data[0, :].unsqueeze(0)
sol_after_fitting = odeint(model_lv, initial_conditions, ts, method='dopri5').detach().numpy()
# y component over time: fitted model (thick black), observations (red dots), true model (dashed red).
plt.plot(ts, sol_after_fitting[:, :, 1], color='black', lw=2.0, label='Fit model')
plt.scatter(ts_data, y_data[:, 1], color='red', s=30, label='Data')
plt.plot(ts, sol_true[:, :, 1], color='red', ls='--', lw=1.0, label='True model')
# Title spelling fixed ("Voltera" -> "Volterra") to match the rest of the file.
plt.title("Lotka-Volterra: After Training")
plt.xlabel("t")
plt.ylabel("y")
plt.legend();
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Phase portrait (y against x) of the fitted model, the observations and the true model.
plt.plot(sol_after_fitting[:, :, 0], sol_after_fitting[:, :, 1], color='black', lw=2.0, label='Fit model')
plt.scatter(y_data[:, 0], y_data[:, 1], color='red', s=30, label='Data')
plt.plot(sol_true[:, :, 0], sol_true[:, :, 1], color='red', ls='--', lw=1.0, label='True model')
# Title spelling fixed ("Voltera" -> "Volterra").
plt.title("Lotka-Volterra: After Training")
# Both axes are state variables here; the horizontal axis is x, not time.
plt.xlabel("x")
plt.ylabel("y")
plt.legend();
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Generate noisy Lorenz training data from the classic parameters (sigma=10, rho=28, beta=8/3).
# NOTE: this rebinds `model_sim`, which previously held the VDP simulator.
model_sim = Lorenz(sigma=10.0, rho=28.0, beta=8.0/3.0)
ts_data = torch.linspace(0.0, 10.0, 10)
data_lorenz = create_sim_dataset(
    model_sim,
    ts=ts_data,
    num_samples=10,
    initial_conditions_default=torch.tensor([8.0, 8.0, 27.0]),
    sigma_noise=0.01,
)
# Plot the x and y components of the first simulated sample.
_, y_data = data_lorenz[0]
plt.plot(ts_data, y_data[:, 0])
plt.plot(ts_data, y_data[:, 1])
plt.title("Simulated Data: Lorenz Attractor");
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Starting guess for the Lorenz fit, deliberately away from the true parameters.
lorenz_model = Lorenz(sigma=11.0, rho=23.0, beta=1.2)
print(f"The starting point: {lorenz_model}")
ts = torch.linspace(0, 30.0, 1000)
ts_data, y_data = data_lorenz[0]
# First observed state of sample 0, given a batch axis of size 1.
initial_conditions = y_data[0].unsqueeze(dim=0)
sol = (
    odeint(lorenz_model, initial_conditions, ts, method='dopri5')
    .detach()
    .numpy()
)
# y component of the unfit model (black line) against the observed y values (red dots).
plt.plot(ts, sol[:, :, 1], color='black', lw=1.0)
plt.scatter(ts_data.detach(), y_data[:, 1].detach(), color='red', s=30)
plt.title("Lorenz plot before fitting")
plt.xlabel("t")
plt.ylabel("y");
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Fit the Lorenz parameters; this example uses a smaller learning rate (1e-3) than the others.
train(lorenz_model, data_lorenz, lr=1e-3, epochs=10);
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Re-simulate the trained Lorenz model from sample 0's first observed state and compare with the data.
ts = torch.linspace(0, 30.0, 1000)
ts_data, y_data = data_lorenz[0]
initial_conditions = y_data[0, :].unsqueeze(0)
sol = odeint(lorenz_model, initial_conditions, ts, method='dopri5').detach().numpy()
# y component of the fitted model (black line) against the observed y values (red dots).
plt.plot(ts, sol[:, :, 1], color='black', lw=1.0)
plt.scatter(ts_data.detach(), y_data[:, 1].detach(), color='red', s=30)
# This cell runs after train(), so the title must say "after" (it was copied as "before").
plt.title("Lorenz plot after fitting")
plt.xlabel("t")
plt.ylabel("y");
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment