Differential Equations as a Pytorch Neural Network Layer
#| export
import torch
import torch.nn as nn
from torchdiffeq import odeint
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from typing import Callable, List, Tuple, Union, Optional
from pathlib import Path
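The only dependency beyond the standard scientific Python stack is torchdiffeq, which supplies the differentiable `odeint` solver used throughout; if it is not already available it can be installed with `pip install torchdiffeq`.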
#| exports
class VDP(nn.Module):
    """
    Define the Van der Pol oscillator as a PyTorch module.
    """
    def __init__(self,
                 mu: float, # stiffness parameter of the VDP oscillator
                 ):
        super().__init__()
        self.mu = torch.nn.Parameter(torch.tensor(mu)) # make mu a learnable parameter

    def forward(self,
                t: float, # time index
                state: torch.Tensor, # state of the system; first dimension is the batch size
                ) -> torch.Tensor: # return the derivative of the state
        """
        Define the right-hand side of the VDP oscillator.
        """
        x = state[..., 0] # first dimension is the batch size
        y = state[..., 1]
        dX = self.mu * (x - x**3/3 - y)
        dY = x / self.mu
        # trick to make sure our return value has the same shape as the input
        dfunc = torch.zeros_like(state)
        dfunc[..., 0] = dX
        dfunc[..., 1] = dY
        return dfunc

    def __repr__(self):
        """Print the parameters of the model."""
        return f" mu: {self.mu.item()}"
vdp_model = VDP(mu=0.5)

# Create a time vector, this is the time axis of the ODE
ts = torch.linspace(0, 30.0, 1000)

# Create a batch of random initial conditions
batch_size = 30
initial_conditions = torch.tensor([0.01, 0.01]) + 0.2*torch.randn((batch_size, 2))

# Solve the ODE, odeint comes from torchdiffeq
sol = odeint(vdp_model, initial_conditions, ts, method='dopri5').detach().numpy()
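Note that odeint returns the trajectory with time as the leading axis, so `sol` here is a NumPy array with the layout (time points, batch, state); the plotting calls below rely on this:

print(sol.shape)  # (1000, 30, 2)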
plt.plot(ts, sol[:, :, 0], lw=0.5);
plt.title("Time series of the VDP oscillator");
plt.xlabel("time");
plt.ylabel("x");
# Check the solution
plt.plot(sol[:, :, 0], sol[:, :, 1], lw=0.5);
plt.title("Phase plot of the VDP oscillator");
plt.xlabel("x");
plt.ylabel("y");
#| exports
class LotkaVolterra(nn.Module):
    """
    The Lotka-Volterra equations are a pair of first-order, non-linear differential equations
    describing the dynamics of two species interacting in a predator-prey relationship.
    """
    def __init__(self,
                 alpha: float = 1.5, # The alpha parameter of the Lotka-Volterra system
                 beta: float = 1.0, # The beta parameter of the Lotka-Volterra system
                 delta: float = 3.0, # The delta parameter of the Lotka-Volterra system
                 gamma: float = 1.0 # The gamma parameter of the Lotka-Volterra system
                 ) -> None:
        super().__init__()
        self.model_params = torch.nn.Parameter(torch.tensor([alpha, beta, delta, gamma]))

    def forward(self, t, state):
        x = state[..., 0] # prey population
        y = state[..., 1] # predator population
        sol = torch.zeros_like(state)
        # coefficients live in the learnable tensor model_params
        alpha, beta, delta, gamma = self.model_params
        sol[..., 0] = alpha*x - beta*x*y
        sol[..., 1] = -delta*y + gamma*x*y
        return sol

    def __repr__(self):
        return (f" alpha: {self.model_params[0].item()}, "
                f"beta: {self.model_params[1].item()}, "
                f"delta: {self.model_params[2].item()}, "
                f"gamma: {self.model_params[3].item()}")
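The forward method encodes the standard predator-prey equations, with $x$ the prey and $y$ the predator population:

$$\dot{x} = \alpha x - \beta x y, \qquad \dot{y} = -\delta y + \gamma x y$$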
lv_model = LotkaVolterra() # use default parameters
ts = torch.linspace(0, 30.0, 1000)
batch_size = 30

# Create a batch of initial conditions (batch_dim, state_dim) as small perturbations around one value
initial_conditions = torch.tensor([[3.0, 3.0]]) + 0.50*torch.randn((batch_size, 2))
sol = odeint(lv_model, initial_conditions, ts, method='dopri5').detach().numpy()

# Check the solution
plt.plot(ts, sol[:, :, 0], lw=0.5);
plt.title("Time series of the Lotka-Volterra system");
plt.xlabel("time");
plt.ylabel("x");
plt.plot(sol[:, :, 0], sol[:, :, 1], lw=0.5);
plt.title("Phase plot of the Lotka-Volterra system");
plt.xlabel("x");
plt.ylabel("y");
#| exports
class Lorenz(nn.Module):
    """
    Define the Lorenz system as a PyTorch module.
    """
    def __init__(self,
                 sigma: float = 10.0, # The sigma parameter of the Lorenz system
                 rho: float = 28.0, # The rho parameter of the Lorenz system
                 beta: float = 8.0/3, # The beta parameter of the Lorenz system
                 ):
        super().__init__()
        self.model_params = torch.nn.Parameter(torch.tensor([sigma, rho, beta]))

    def forward(self, t, state):
        x = state[..., 0]
        y = state[..., 1]
        z = state[..., 2]
        sol = torch.zeros_like(state)
        # coefficients live in the learnable tensor model_params
        sigma, rho, beta = self.model_params
        sol[..., 0] = sigma*(y - x)
        sol[..., 1] = x*(rho - z) - y
        sol[..., 2] = x*y - beta*z
        return sol

    def __repr__(self):
        return (f" sigma: {self.model_params[0].item()}, "
                f"rho: {self.model_params[1].item()}, "
                f"beta: {self.model_params[2].item()}")
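The right-hand side above is the classic Lorenz system, which is chaotic for the default parameters ($\sigma=10$, $\rho=28$, $\beta=8/3$):

$$\dot{x} = \sigma(y - x), \qquad \dot{y} = x(\rho - z) - y, \qquad \dot{z} = xy - \beta z$$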
lorenz_model = Lorenz()
ts = torch.linspace(0, 50.0, 3000)
batch_size = 30

# Create a batch of initial conditions (batch_dim, state_dim) as small perturbations around one value
initial_conditions = torch.tensor([[1.0, 0.0, 0.0]]) + 0.10*torch.randn((batch_size, 3))
sol = odeint(lorenz_model, initial_conditions, ts, method='dopri5').detach().numpy()

# Check the solution
plt.plot(ts[:2000], sol[:2000, :, 0], lw=0.5);
plt.title("Time series of the Lorenz system");
plt.xlabel("time");
plt.ylabel("x");
plt.plot(sol[:, 0, 0], sol[:, 0, 1], color='black', lw=0.5);
plt.title("Phase plot of the Lorenz system");
plt.xlabel("x");
plt.ylabel("y");
#| exports
class SimODEData(Dataset):
    """
    A very simple dataset class for simulated ODE trajectories.
    """
    def __init__(self,
                 ts: List[torch.Tensor], # list of time points as tensors
                 values: List[torch.Tensor], # list of dynamical state values (tensor) at each time point
                 true_model: Optional[torch.nn.Module] = None, # optionally store the model that generated the data
                 ) -> None:
        self.ts = ts
        self.values = values
        self.true_model = true_model

    def __len__(self) -> int:
        return len(self.ts)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.ts[index], self.values[index]
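Because each item is a `(ts, values)` pair, a DataLoader over this dataset yields time grids of shape (batch, len(ts)) and trajectories of shape (batch, len(ts), dim); the train function below relies on this layout. A minimal sketch, where `ts` and `ys` are hypothetical tensors of shape (T,) and (T, dim):

loader = DataLoader(SimODEData([ts]*4, [ys]*4), batch_size=2)
ts_batch, states_batch = next(iter(loader))  # shapes (2, T) and (2, T, dim)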
#| export
def create_sim_dataset(model: nn.Module, # model to simulate from
                       ts: torch.Tensor, # time points to simulate for
                       num_samples: int = 10, # number of trajectories to generate
                       sigma_noise: float = 0.1, # observation noise level added to the data
                       initial_conditions_default: torch.Tensor = torch.tensor([0.0, 0.0]), # default initial conditions
                       sigma_initial_conditions: float = 0.1, # noise level added to the initial conditions
                       ) -> SimODEData:
    ts_list = []
    states_list = []
    dim = initial_conditions_default.shape[0]
    for i in range(num_samples):
        x0 = sigma_initial_conditions * torch.randn((1, dim)).detach() + initial_conditions_default
        ys = odeint(model, x0, ts).squeeze(1).detach()
        ys += sigma_noise*torch.randn_like(ys)
        ys[0, :] = x0 # pin the first value to the exact (noise-free) initial condition
        ts_list.append(ts)
        states_list.append(ys)
    return SimODEData(ts_list, states_list, true_model=model)
#| export
def plot_time_series(true_model: torch.nn.Module, # true underlying model for the simulated data
                     fit_model: torch.nn.Module, # model fit to the data
                     data: Tuple[torch.Tensor, torch.Tensor], # a (ts, values) sample from a SimODEData dataset
                     time_range: tuple = (0.0, 30.0), # range of times to simulate the models for
                     ax: Optional[plt.Axes] = None,
                     dyn_var_idx: int = 0,
                     title: str = "Model fits",
                     *args,
                     **kwargs) -> Tuple[plt.Figure, plt.Axes]:
    """
    Plot the true model and fit model on the same axes.
    """
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()
    ts = torch.linspace(time_range[0], time_range[1], 1000)
    ts_data, y_data = data
    initial_conditions = y_data[0, :].unsqueeze(0)
    sol_pred = odeint(fit_model, initial_conditions, ts, method='dopri5').detach().numpy()
    sol_true = odeint(true_model, initial_conditions, ts, method='dopri5').detach().numpy()
    ax.plot(ts, sol_pred[:, :, dyn_var_idx], color='skyblue', lw=2.0, label='Predicted', **kwargs)
    ax.scatter(ts_data.detach(), y_data[:, dyn_var_idx].detach(), color='black', s=30, label='Data', **kwargs)
    ax.plot(ts, sol_true[:, :, dyn_var_idx], color='black', ls='--', lw=1.0, label='True model', **kwargs)
    ax.set_title(title)
    ax.set_xlabel("t")
    ax.set_ylabel("y")
    ax.legend()
    return fig, ax
#| export
def plot_phase_plane(true_model: torch.nn.Module, # true underlying model for the simulated data
                     fit_model: torch.nn.Module, # model fit to the data
                     data: Tuple[torch.Tensor, torch.Tensor], # a (ts, values) sample from a SimODEData dataset
                     time_range: tuple = (0.0, 30.0), # range of times to simulate the models for
                     ax: Optional[plt.Axes] = None,
                     dyn_var_idx: tuple = (0, 1),
                     title: str = "Model fits",
                     *args,
                     **kwargs) -> Tuple[plt.Figure, plt.Axes]:
    """
    Plot phase-plane trajectories of the true model and fit model on the same axes.
    """
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()
    ts = torch.linspace(time_range[0], time_range[1], 1000)
    ts_data, y_data = data
    initial_conditions = y_data[0, :].unsqueeze(0)
    sol_pred = odeint(fit_model, initial_conditions, ts, method='dopri5').detach().numpy()
    sol_true = odeint(true_model, initial_conditions, ts, method='dopri5').detach().numpy()
    idx1, idx2 = dyn_var_idx
    ax.plot(sol_pred[:, :, idx1], sol_pred[:, :, idx2], color='skyblue', lw=1.0, label='Fit model')
    ax.scatter(y_data[:, idx1].detach(), y_data[:, idx2].detach(), color='black', s=30, label='Data')
    ax.plot(sol_true[:, :, idx1], sol_true[:, :, idx2], color='black', ls='--', lw=1.0, label='True model')
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_title(title)
    ax.legend()
    return fig, ax
#| exports
def train(model: torch.nn.Module, # Model to train
          data: SimODEData, # Data to train on
          lr: float = 1e-2, # Learning rate for the Adam optimizer
          epochs: int = 10, # Number of epochs to train for
          batch_size: int = 5, # Batch size for training
          method = 'rk4', # ODE solver to use
          step_size: float = 0.10, # Step size for fixed-step solvers
          show_every: int = 10, # How often to print the loss
          save_plots_every: Union[int, None] = None, # How often to save a plot of the fit; None disables saving
          model_name: str = "", # String used to label the saved plots
          *args: tuple,
          **kwargs: dict
          ):
    # Create a data loader to iterate over the data. This takes in our dataset and returns batches
    trainloader = DataLoader(data, batch_size=batch_size, shuffle=True)
    # Choose an optimizer. Adam is a good default choice as a fancy gradient descent
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Create a loss function; this computes the error between the predicted and true values
    criterion = torch.nn.MSELoss()
    for epoch in range(epochs):
        running_loss = 0.0
        for batchdata in trainloader:
            optimizer.zero_grad() # reset gradients, a famous gotcha in a pytorch training loop
            ts, states = batchdata # unpack the data
            initial_state = states[:, 0, :] # grab the initial state
            # All samples share the same time grid, so pass the first one to the solver.
            # odeint returns (time, batch, state_dim); flip to (batch, time, state_dim)
            # because PyTorch expects the batch dimension to be first.
            pred = odeint(model, initial_state, ts[0], method=method, options={'step_size': step_size}).transpose(0, 1)
            # Compute the loss
            loss = criterion(pred, states)
            loss.backward() # compute gradients
            optimizer.step() # update parameters
            running_loss += loss.item() # record loss
        if epoch % show_every == 0:
            print(f"Loss at epoch {epoch}: {running_loss}")
        # Use this to save plots of the fit every save_plots_every epochs
        if save_plots_every is not None and epoch % save_plots_every == 0:
            with torch.no_grad():
                fig, ax = plot_time_series(data.true_model, model, data[0])
                ax.set_title(f"Epoch: {epoch}")
                Path("./tmp_plots").mkdir(exist_ok=True) # make sure the output directory exists
                fig.savefig(f"./tmp_plots/{epoch}_{model_name}_fit_plot")
                plt.close()
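For longer trajectories or larger models, backpropagating through every stored solver step can become memory-hungry. torchdiffeq also ships an adjoint-based solver with the same call signature (it requires the right-hand side to be an nn.Module, which all models here are) that trades extra compute for memory. A sketch of the drop-in swap inside the loop above:

from torchdiffeq import odeint_adjoint
pred = odeint_adjoint(model, initial_state, ts[0], method=method, options={'step_size': step_size}).transpose(0, 1)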
true_mu = 0.30
model_sim = VDP(mu=true_mu)
ts_data = torch.linspace(0.0, 10.0, 10)
data_vdp = create_sim_dataset(model_sim,
                              ts=ts_data,
                              num_samples=10,
                              sigma_noise=0.01)
vdp_model = VDP(mu=0.10)
plot_time_series(model_sim,
                 vdp_model,
                 data_vdp[0],
                 dyn_var_idx=1,
                 title="VDP Model: Before Parameter Fits");
train(vdp_model, data_vdp, epochs=50, model_name="vdp")
print(f"After training: {vdp_model}, where the true value is {true_mu}")
print(f"Final parameter recovery error: {vdp_model.mu.item() - true_mu}")
model_sim_lv = LotkaVolterra(1.5, 1.0, 3.0, 1.0)
ts_data = torch.arange(0.0, 10.0, 0.1)
data_lv = create_sim_dataset(model_sim_lv,
                             ts=ts_data,
                             num_samples=10,
                             sigma_noise=0.1,
                             initial_conditions_default=torch.tensor([2.5, 2.5]))
model_lv = LotkaVolterra(alpha=1.6, beta=1.1, delta=2.7, gamma=1.2)
plot_time_series(model_sim_lv, model_lv, data=data_lv[0], title="Lotka-Volterra: Before Fitting");
train(model_lv, data_lv, epochs=60, lr=1e-2, model_name="lotkavolterra")
print(f"Fitted model: {model_lv}")
print(f"True model: {model_sim_lv}")
model_sim_lorenz = Lorenz(sigma=10.0, rho=28.0, beta=8.0/3.0)
ts_data = torch.arange(0, 10.0, 0.05)
data_lorenz = create_sim_dataset(model_sim_lorenz,
                                 ts=ts_data,
                                 num_samples=30,
                                 initial_conditions_default=torch.tensor([1.0, 0.0, 0.0]),
                                 sigma_noise=0.01,
                                 sigma_initial_conditions=0.10)
lorenz_model = Lorenz(sigma=10.2, rho=28.2, beta=9.0/3)
fig, ax = plot_time_series(model_sim_lorenz, lorenz_model, data_lorenz[0], title="Lorenz Model: Before Fitting");
ax.set_xlim((2, 15));
train(lorenz_model,
      data_lorenz,
      epochs=300,
      batch_size=5,
      method='rk4',
      step_size=0.05,
      show_every=50,
      lr=1e-3)
# Remake the data: a longer time series and more samples than were used for the hand-coded ODE layers above
model_sim_vdp = VDP(mu=0.20)
ts_data = torch.linspace(0.0, 30.0, 100)
data_vdp = create_sim_dataset(model_sim_vdp,
                              ts=ts_data,
                              num_samples=30,
                              sigma_noise=0.1,
                              initial_conditions_default=torch.tensor([0.50, 0.10]))
#| exports
class NeuralDiffEq(nn.Module):
    """
    Basic neural ODE model: the right-hand side of the ODE is a small fully connected network.
    """
    def __init__(self,
                 dim: int = 2, # dimension of the state vector
                 ) -> None:
        super().__init__()
        self.ann = nn.Sequential(torch.nn.Linear(dim, 8),
                                 torch.nn.LeakyReLU(),
                                 torch.nn.Linear(8, 16),
                                 torch.nn.LeakyReLU(),
                                 torch.nn.Linear(16, 32),
                                 torch.nn.LeakyReLU(),
                                 torch.nn.Linear(32, dim))

    def forward(self, t, state):
        # the vector field ignores t, so the learned system is autonomous
        return self.ann(state)
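As a quick sanity check (the tensor names here are hypothetical), the module maps a batch of states to derivatives of the same shape, which is the contract odeint expects of a right-hand side:

nde = NeuralDiffEq(dim=2)
dummy_state = torch.randn(5, 2)
assert nde(0.0, dummy_state).shape == dummy_state.shape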
model_vdp_nde = NeuralDiffEq(dim=2)
train(model_vdp_nde,
      data_vdp,
      epochs=1500,
      lr=1e-3,
      batch_size=5,
      show_every=100,
      model_name="nde")