Skip to content

Instantly share code, notes, and snippets.

@khannay
Created April 22, 2023 19:59
Show Gist options
  • Save khannay/b45760f995fc316d8f60f7f8bf62df71 to your computer and use it in GitHub Desktop.
Save khannay/b45760f995fc316d8f60f7f8bf62df71 to your computer and use it in GitHub Desktop.
Differential Equations as a PyTorch Neural Network Layer
#| export
import torch
import torch.nn as nn
from torchdiffeq import odeint as odeint
import pylab as plt
from torch.utils.data import Dataset, DataLoader
from typing import Callable, List, Tuple, Union, Optional
from pathlib import Path
#| exports
class VDP(nn.Module):
    """
    Define the Van der Pol oscillator as a PyTorch module.

    The right-hand side is
        dx/dt = mu * (x - x**3 / 3 - y)
        dy/dt = x / mu
    where ``mu`` is stored as a learnable ``nn.Parameter``.
    """

    def __init__(self,
                 mu: float,  # Stiffness parameter of the VDP oscillator
                 ):
        super().__init__()
        # Cast to float so integer input (e.g. mu=1) still yields a floating-point
        # tensor: nn.Parameter cannot require gradients on an integer tensor.
        self.mu = torch.nn.Parameter(torch.tensor(float(mu)))  # make mu a learnable parameter

    def forward(self,
                t: torch.Tensor,  # time index (not used: the system is autonomous)
                state: torch.Tensor,  # state of the system; the last axis holds (x, y)
                ) -> torch.Tensor:  # return the derivative of the state
        """
        Define the right hand side of the VDP oscillator.
        """
        x = state[..., 0]  # `...` keeps any leading (e.g. batch) axes
        y = state[..., 1]
        dX = self.mu * (x - 1 / 3 * x ** 3 - y)
        dY = 1 / self.mu * x
        # trick to make sure our return value has the same shape as the input
        dfunc = torch.zeros_like(state)
        dfunc[..., 0] = dX
        dfunc[..., 1] = dY
        return dfunc

    def __repr__(self):
        """Print the parameters of the model."""
        return f" mu: {self.mu.item()}"
vdp_model = VDP(mu=0.5)
# Create a time vector, this is the time axis of the ODE
ts = torch.linspace(0, 30.0, 1000)
# Create a batch of initial conditions
batch_size = 30
# Creates some random initial conditions: shape (batch_size, 2), scattered around (0.01, 0.01)
initial_conditions = torch.tensor([0.01, 0.01]) + 0.2 * torch.randn((batch_size, 2))
# Solve the ODE, odeint comes from torchdiffeq; sol is indexed as sol[time, trajectory, state_var]
sol = odeint(vdp_model, initial_conditions, ts, method='dopri5').detach().numpy()
# Each plot gets its own figure; otherwise, when run as a script, the two plots
# are drawn on the same axes and the second title overwrites the first.
plt.figure()
plt.plot(ts, sol[:, :, 0], lw=0.5)
plt.title("Time series of the VDP oscillator")
plt.xlabel("time")
plt.ylabel("x")
# Check the solution
plt.figure()
plt.plot(sol[:, :, 0], sol[:, :, 1], lw=0.5)
plt.title("Phase plot of the VDP oscillator")
plt.xlabel("x")
plt.ylabel("y")
#| exports
class LotkaVolterra(nn.Module):
    """
    The Lotka-Volterra equations are a pair of first-order, non-linear, differential equations
    describing the dynamics of two species interacting in a predator-prey relationship.

    This module implements
        dx/dt = alpha * x - beta * x * y
        dy/dt = -delta * y + gamma * x * y
    with all four coefficients stored together in the learnable ``model_params``.
    """

    def __init__(self,
                 alpha: float = 1.5,  # The alpha parameter of the Lotka-Volterra system
                 beta: float = 1.0,  # The beta parameter of the Lotka-Volterra system
                 delta: float = 3.0,  # The delta parameter of the Lotka-Volterra system
                 gamma: float = 1.0,  # The gamma parameter of the Lotka-Volterra system
                 ) -> None:
        super().__init__()
        # float() casts guarantee a floating-point tensor even when every argument is an int;
        # nn.Parameter cannot require gradients on an integer tensor.
        self.model_params = torch.nn.Parameter(
            torch.tensor([float(alpha), float(beta), float(delta), float(gamma)]))

    def forward(self, t: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
        """Return d(state)/dt; ``t`` is unused because the system is autonomous."""
        x = state[..., 0]  # the last axis holds (x, y)
        y = state[..., 1]
        alpha, beta, delta, gamma = self.model_params
        # Same-shape output buffer, filled one state variable at a time
        sol = torch.zeros_like(state)
        sol[..., 0] = alpha * x - beta * x * y
        sol[..., 1] = -delta * y + gamma * x * y
        return sol

    def __repr__(self) -> str:
        """Print the current parameter values."""
        alpha, beta, delta, gamma = self.model_params.tolist()
        return f" alpha: {alpha}, beta: {beta}, delta: {delta}, gamma: {gamma}"
lv_model = LotkaVolterra()  # use default parameters
ts = torch.linspace(0, 30.0, 1000)
batch_size = 30
# Create a batch of initial conditions (batch_dim, state_dim) as small perturbations around one value
initial_conditions = torch.tensor([[3, 3]]) + 0.50 * torch.randn((batch_size, 2))
sol = odeint(lv_model, initial_conditions, ts, method='dopri5').detach().numpy()
# Check the solution. Each plot opens its own figure so that, when run as a
# script, it does not draw over the previous plot.
plt.figure()
plt.plot(ts, sol[:, :, 0], lw=0.5)
plt.title("Time series of the Lotka-Volterra system")
plt.xlabel("time")
plt.ylabel("x")
plt.figure()
plt.plot(sol[:, :, 0], sol[:, :, 1], lw=0.5)
plt.title("Phase plot of the Lotka-Volterra system")
plt.xlabel("x")
plt.ylabel("y")
#| exports
class Lorenz(nn.Module):
    """
    Define the Lorenz system as a PyTorch module.

        dx/dt = sigma * (y - x)
        dy/dt = x * (rho - z) - y
        dz/dt = x * y - beta * z

    The three coefficients are stored together in the learnable ``model_params``.
    """

    def __init__(self,
                 sigma: float = 10.0,  # The sigma parameter of the Lorenz system
                 rho: float = 28.0,  # The rho parameter of the Lorenz system
                 beta: float = 8.0 / 3,  # The beta parameter of the Lorenz system
                 ):
        super().__init__()
        # float() casts guarantee a floating-point tensor even for all-int input;
        # nn.Parameter cannot require gradients on an integer tensor.
        self.model_params = torch.nn.Parameter(
            torch.tensor([float(sigma), float(rho), float(beta)]))

    def forward(self, t: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
        """Return d(state)/dt; ``t`` is unused because the system is autonomous."""
        x = state[..., 0]  # the last axis holds (x, y, z)
        y = state[..., 1]
        z = state[..., 2]
        sigma, rho, beta = self.model_params
        sol = torch.zeros_like(state)
        sol[..., 0] = sigma * (y - x)
        sol[..., 1] = x * (rho - z) - y
        sol[..., 2] = x * y - beta * z
        return sol

    def __repr__(self) -> str:
        """Print the current parameter values."""
        sigma, rho, beta = self.model_params.tolist()
        return f" sigma: {sigma}, rho: {rho}, beta: {beta}"
lorenz_model = Lorenz()
ts = torch.linspace(0, 50.0, 3000)
batch_size = 30
# Create a batch of initial conditions (batch_dim, state_dim) as small perturbations around one value
initial_conditions = torch.tensor([[1.0, 0.0, 0.0]]) + 0.10 * torch.randn((batch_size, 3))
sol = odeint(lorenz_model, initial_conditions, ts, method='dopri5').detach().numpy()
# Check the solution. Each plot opens its own figure so that, when run as a
# script, it does not draw over the previous plot.
plt.figure()
plt.plot(ts[:2000], sol[:2000, :, 0], lw=0.5)  # first 2000 time points, all trajectories
plt.title("Time series of the Lorenz system")
plt.xlabel("time")
plt.ylabel("x")
plt.figure()
plt.plot(sol[:, 0, 0], sol[:, 0, 1], color='black', lw=0.5)  # first trajectory only
plt.title("Phase plot of the Lorenz system")
plt.xlabel("x")
plt.ylabel("y")
#| exports
class SimODEData(Dataset):
    """
    Minimal map-style dataset of simulated ODE trajectories.

    Item ``i`` is the pair ``(ts[i], values[i])``, and the dataset length is
    ``len(ts)``. ``true_model`` is only stored for callers that want to compare
    against the generating model; the dataset itself never uses it.
    """

    def __init__(self,
                 ts: List[torch.Tensor],  # time grid for each sample
                 values: List[torch.Tensor],  # state trajectory for each sample
                 true_model: Optional[torch.nn.Module] = None,  # generating model, if known
                 ) -> None:
        self.true_model = true_model
        self.values = values
        self.ts = ts

    def __len__(self) -> int:
        return len(self.ts)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        sample_times = self.ts[index]
        sample_states = self.values[index]
        return sample_times, sample_states
#| export
def create_sim_dataset(model: nn.Module,  # model to simulate from
                       ts: torch.Tensor,  # Time points to simulate for
                       num_samples: int = 10,  # Number of samples to generate
                       sigma_noise: float = 0.1,  # Noise level to add to the data
                       initial_conditions_default: Optional[torch.Tensor] = None,  # Default initial conditions (None -> [0.0, 0.0])
                       sigma_initial_conditions: float = 0.1,  # Noise level to add to the initial conditions
                       ) -> SimODEData:
    """
    Simulate ``model`` from randomly perturbed initial conditions and return a noisy dataset.

    For each sample:
    1. Draw ``x0 = initial_conditions_default + sigma_initial_conditions * N(0, 1)``.
    2. Integrate the model over ``ts``.
    3. Add Gaussian observation noise with standard deviation ``sigma_noise``.
    4. Reset the first row to the noise-free ``x0``.

    Every sample shares the same ``ts`` tensor.
    """
    # A None sentinel avoids a tensor default that is built once at import time
    # and shared between calls.
    if initial_conditions_default is None:
        initial_conditions_default = torch.tensor([0.0, 0.0])
    ts_list = []
    states_list = []
    dim = initial_conditions_default.shape[0]
    for _ in range(num_samples):
        x0 = sigma_initial_conditions * torch.randn((1, dim)) + initial_conditions_default
        # The reference trajectory is data, not something to differentiate through,
        # so don't build an autograd graph for it.
        with torch.no_grad():
            ys = odeint(model, x0, ts).squeeze(1).detach()  # drop the singleton batch axis
        ys += sigma_noise * torch.randn_like(ys)
        ys[0, :] = x0.squeeze(0)  # Set the first value to the (noise-free) initial condition
        ts_list.append(ts)
        states_list.append(ys)
    return SimODEData(ts_list, states_list, true_model=model)
#| export
def plot_time_series(true_model: torch.nn.Module,  # true underlying model for the simulated data
                     fit_model: torch.nn.Module,  # model fit to the data
                     data: Tuple[torch.Tensor, torch.Tensor],  # one (ts, states) item of a SimODEData, plotted as a scatter
                     time_range: tuple = (0.0, 30.0),  # range of times to simulate the models for
                     ax: Optional[plt.Axes] = None,  # axes to draw on; a new figure is created when None
                     dyn_var_idx: int = 0,  # index of the state variable to plot
                     title: str = "Model fits",
                     *args,
                     **kwargs) -> Tuple[plt.Figure, plt.Axes]:
    """
    Plot the true model and fit model on the same axes.

    Both models are integrated with dopri5 from the first observed state in
    ``data`` over 1000 points spanning ``time_range``. The observations are
    overlaid as a scatter. Extra ``**kwargs`` are forwarded to every
    plot/scatter call. Returns ``(fig, ax)``.
    """
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()
    ts = torch.linspace(time_range[0], time_range[1], 1000)
    ts_data, y_data = data
    initial_conditions = y_data[0, :].unsqueeze(0)
    # Only the numbers are plotted, so skip building autograd graphs.
    with torch.no_grad():
        sol_pred = odeint(fit_model, initial_conditions, ts, method='dopri5').detach().numpy()
        sol_true = odeint(true_model, initial_conditions, ts, method='dopri5').detach().numpy()
    ax.plot(ts, sol_pred[:, :, dyn_var_idx], color='skyblue', lw=2.0, label='Predicted', **kwargs)
    ax.scatter(ts_data.detach(), y_data[:, dyn_var_idx].detach(), color='black', s=30, label='Data', **kwargs)
    ax.plot(ts, sol_true[:, :, dyn_var_idx], color='black', ls='--', lw=1.0, label='True model', **kwargs)
    ax.set_title(title)
    ax.set_xlabel("t")
    ax.set_ylabel("y")
    # ax.legend (not plt.legend) so the legend lands on the axes we drew on,
    # even when a non-current `ax` was passed in.
    ax.legend()
    return fig, ax
#| export
def plot_phase_plane(true_model: torch.nn.Module,  # true underlying model for the simulated data
                     fit_model: torch.nn.Module,  # model fit to the data
                     data: Tuple[torch.Tensor, torch.Tensor],  # one (ts, states) item of a SimODEData, plotted as a scatter
                     time_range: tuple = (0.0, 30.0),  # range of times to simulate the models for
                     ax: Optional[plt.Axes] = None,  # axes to draw on; a new figure is created when None
                     dyn_var_idx: tuple = (0, 1),  # indices of the state variables on the x and y axes
                     title: str = "Model fits",
                     *args,
                     **kwargs) -> Tuple[plt.Figure, plt.Axes]:
    """
    Plot the true model and fit model trajectories in the phase plane on the same axes.

    Both models are integrated with dopri5 from the first observed state in
    ``data`` over 1000 points spanning ``time_range``. Observations are overlaid
    as a scatter. Returns ``(fig, ax)``.
    """
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()
    ts = torch.linspace(time_range[0], time_range[1], 1000)
    _, y_data = data
    initial_conditions = y_data[0, :].unsqueeze(0)
    with torch.no_grad():
        sol_pred = odeint(fit_model, initial_conditions, ts, method='dopri5').detach().numpy()
        sol_true = odeint(true_model, initial_conditions, ts, method='dopri5').detach().numpy()
    idx1, idx2 = dyn_var_idx
    # Detach both coordinates (previously only one was) so matplotlib can
    # convert them even if the data carries gradients.
    y_obs = y_data.detach()
    ax.plot(sol_pred[:, :, idx1], sol_pred[:, :, idx2], color='skyblue', lw=1.0, label='Fit model')
    ax.scatter(y_obs[:, idx1], y_obs[:, idx2], color='black', s=30, label='Data')
    # '--' must be an ASCII double hyphen; the en dash '–' is not a valid matplotlib linestyle.
    ax.plot(sol_true[:, :, idx1], sol_true[:, :, idx2], color='black', ls='--', lw=1.0, label='True model')
    ax.set_xlabel(r'x')
    ax.set_ylabel(r'y')
    ax.set_title(title)
    ax.legend()
    return fig, ax
#| exports
def train(model: torch.nn.Module,  # Model to train
          data: SimODEData,  # Data to train on
          lr: float = 1e-2,  # learning rate for the Adam optimizer
          epochs: int = 10,  # Number of epochs to train for
          batch_size: int = 5,  # Batch size for training
          method='rk4',  # ODE solver to use
          step_size: float = 0.10,  # for fixed diffeq solver set the step size
          show_every: int = 10,  # How often to print the loss function message
          save_plots_every: Union[int, None] = None,  # save a plot of the fit, to disable make this None
          model_name: str = "",  # string for the model, used to reference the saved plots
          *args: tuple,
          **kwargs: dict
          ):
    """
    Fit ``model`` to ``data`` by integrating it through torchdiffeq's ``odeint``
    and minimising the MSE against the observed trajectories with Adam.

    The time grid of the first sample in each batch (``ts[0]``) is used for the
    whole batch. Samples in a batch are therefore assumed to share one time grid.
    When ``save_plots_every`` is set, a fit plot is saved to ``./tmp_plots/``;
    the directory is created if it does not exist.
    """
    # Create a data loader to iterate over the data. This takes in our dataset and returns batches of data
    trainloader = DataLoader(data, batch_size=batch_size, shuffle=True)
    # Choose an optimizer. Adam is a good default choice as a fancy gradient descent
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Create a loss function this computes the error between the predicted and true values
    criterion = torch.nn.MSELoss()
    plot_dir = Path("./tmp_plots")
    for epoch in range(epochs):
        running_loss = 0.0
        for ts, states in trainloader:
            optimizer.zero_grad()  # reset gradients, famous gotcha in a pytorch training loop
            initial_state = states[:, 0, :]  # grab the initial state
            # odeint returns (time, batch, state_dim); transpose to (batch, time, state_dim)
            # so the prediction lines up with `states`.
            pred = odeint(model, initial_state, ts[0], method=method,
                          options={'step_size': step_size}).transpose(0, 1)
            loss = criterion(pred, states)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()  # record loss
        if epoch % show_every == 0:
            print(f"Loss at {epoch}: {running_loss}")
        # Use this to save plots of the fit every save_plots_every epochs
        if save_plots_every is not None and epoch % save_plots_every == 0:
            # savefig raises if the target directory is missing, so create it first.
            plot_dir.mkdir(parents=True, exist_ok=True)
            with torch.no_grad():
                fig, ax = plot_time_series(data.true_model, model, data[0])
            ax.set_title(f"Epoch: {epoch}")
            fig.savefig(plot_dir / f"{epoch}_{model_name}_fit_plot")
            plt.close(fig)
# Fit the single VDP parameter `mu` to noisy data simulated from a known true value.
true_mu = 0.30
model_sim = VDP(mu=true_mu)
ts_data = torch.linspace(0.0, 10.0, 10)
data_vdp = create_sim_dataset(model_sim,
                              ts=ts_data,
                              num_samples=10,
                              sigma_noise=0.01)
vdp_model = VDP(mu=0.10)  # deliberately wrong starting guess
plot_time_series(model_sim,
                 vdp_model,
                 data_vdp[0],
                 dyn_var_idx=1,
                 title="VDP Model: Before Parameter Fits")
train(vdp_model, data_vdp, epochs=50, model_name="vdp")
print(f"After training: {vdp_model}, where the true value is {true_mu}")
# .item() prints a plain number instead of a tensor repr carrying a grad_fn.
print(f"Final Parameter Recovery Error: {vdp_model.mu.item() - true_mu}")
# Ground-truth Lotka-Volterra system that generates the synthetic observations
model_sim_lv = LotkaVolterra(alpha=1.5, beta=1.0, delta=3.0, gamma=1.0)
ts_data = torch.arange(0.0, 10.0, 0.1)
data_lv = create_sim_dataset(
    model_sim_lv,
    ts=ts_data,
    num_samples=10,
    sigma_noise=0.1,
    initial_conditions_default=torch.tensor([2.5, 2.5]),
)
# Perturbed starting guess whose parameters are recovered by training
model_lv = LotkaVolterra(alpha=1.6, beta=1.1, delta=2.7, gamma=1.2)
plot_time_series(model_sim_lv, model_lv, data=data_lv[0], title="Lotka Volterra: Before Fitting")
train(model_lv, data_lv, epochs=60, lr=1e-2, model_name="lotkavolterra")
print(f"Fitted model: {model_lv}")
print(f"True model: {model_sim_lv}")
# Ground-truth Lorenz system with the classic chaotic parameter values
model_sim_lorenz = Lorenz(sigma=10.0, rho=28.0, beta=8.0 / 3.0)
ts_data = torch.arange(0, 10.0, 0.05)
data_lorenz = create_sim_dataset(
    model_sim_lorenz,
    ts=ts_data,
    num_samples=30,
    initial_conditions_default=torch.tensor([1.0, 0.0, 0.0]),
    sigma_noise=0.01,
    sigma_initial_conditions=0.10,
)
# Slightly perturbed starting guess to be fitted
lorenz_model = Lorenz(sigma=10.2, rho=28.2, beta=9.0 / 3)
fig, ax = plot_time_series(
    model_sim_lorenz, lorenz_model, data_lorenz[0], title="Lorenz Model: Before Fitting"
)
ax.set_xlim((2, 15))
train(
    lorenz_model,
    data_lorenz,
    epochs=300,
    batch_size=5,
    method='rk4',
    step_size=0.05,
    show_every=50,
    lr=1e-3,
)
# Remake the VDP data for the neural-ODE experiment: a longer time series
# (0..30 with 100 points) and more samples (30) than the custom ODE layer used.
model_sim_vdp = VDP(mu=0.20)
ts_data = torch.linspace(0.0, 30.0, 100)
data_vdp = create_sim_dataset(
    model_sim_vdp,
    ts=ts_data,
    num_samples=30,
    sigma_noise=0.1,
    initial_conditions_default=torch.tensor([0.50, 0.10]),
)
#| exports
class NeuralDiffEq(nn.Module):
    """
    Basic Neural ODE model.

    The vector field d(state)/dt is a fully connected network of the state alone
    (``t`` is ignored). Hidden layers use LeakyReLU activations, and the output
    layer maps back to ``dim``.
    """

    def __init__(self,
                 dim: int = 2,  # dimension of the state vector
                 hidden_dims: Tuple[int, ...] = (8, 16, 32),  # widths of the hidden layers, in order
                 ) -> None:
        super().__init__()
        # Build Linear -> LeakyReLU pairs for each hidden width, then a final
        # Linear back to `dim`. The default widths reproduce the original
        # 8 -> 16 -> 32 architecture (same layer order and indices).
        layers = []
        in_features = dim
        for width in hidden_dims:
            layers.append(torch.nn.Linear(in_features, width))
            layers.append(torch.nn.LeakyReLU())
            in_features = width
        layers.append(torch.nn.Linear(in_features, dim))
        self.ann = nn.Sequential(*layers)

    def forward(self, t, state):
        """Return the network's estimate of d(state)/dt (``t`` is unused)."""
        return self.ann(state)
# Neural ODE that learns the VDP vector field directly from data_vdp.
# It must be constructed before training; previously the name was never defined (NameError).
model_vdp_nde = NeuralDiffEq(dim=2)
train(model_vdp_nde,
      data_vdp,
      epochs=1500,
      lr=1e-3,
      batch_size=5,
      show_every=100,
      model_name="nde")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment