Skip to content

Instantly share code, notes, and snippets.

@khannay
Created April 17, 2023 00:51
Show Gist options
  • Save khannay/08e6b1c3c9ccd0a68d2b101b94128c79 to your computer and use it in GitHub Desktop.
Save khannay/08e6b1c3c9ccd0a68d2b101b94128c79 to your computer and use it in GitHub Desktop.
Differential Equations as a PyTorch Neural Network Layer
#| export
import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint as odeint
import pylab as plt
from torch.utils.data import Dataset, DataLoader
from typing import Callable, List, Tuple, Union, Optional
# Pick the compute device once: the first CUDA GPU when PyTorch can see one, else the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#| exports
class VDP(nn.Module):
    """
    Define the Van der Pol oscillator as a PyTorch module.

    The right-hand side implemented by ``forward`` is

        dx/dt = mu * (x - x**3 / 3 - y)
        dy/dt = x / mu

    where ``mu`` is a learnable ``nn.Parameter``.
    """

    def __init__(self,
                 mu: float,  # Stiffness parameter of the VDP oscillator
                 ):
        super().__init__()
        # Force a floating dtype: torch.tensor(1) would be int64, and nn.Parameter rejects
        # integer tensors because they cannot require gradients.
        self.mu = torch.nn.Parameter(torch.tensor(mu, dtype=torch.get_default_dtype()))

    def forward(self,
                t: float,  # time index (unused: the system is autonomous)
                state: torch.Tensor,  # state of shape (..., 2); leading dims are batch dims
                ) -> torch.Tensor:  # derivative of the state, shape (..., 2)
        """
        Define the right hand side of the VDP oscillator.
        """
        # Index the last axis so batched (batch, 2) and unbatched (2,) states both work.
        x = state[..., 0]
        y = state[..., 1]
        dX = self.mu * (x - 1/3 * x**3 - y)
        dY = 1 / self.mu * x
        # Stack instead of writing into a zeros_like buffer: no in-place ops on the autograd graph.
        return torch.stack((dX, dY), dim=-1)

    def __repr__(self):
        """Print the parameters of the model."""
        return f" mu: {self.mu.item()}"
# Simulate a batch of Van der Pol trajectories and draw their phase portrait.
vdp_model = VDP(mu=0.5)
# Create a time vector, this is the time axis of the ODE
ts = torch.linspace(0, 30.0, 1000)
# Number of trajectories (initial conditions) integrated together as one batch
batch_size = 30
# Random initial conditions: Gaussian perturbations (std 0.2) around (0.01, 0.01)
initial_conditions = torch.tensor([0.01, 0.01]) + 0.2*torch.randn((batch_size, 2))
# Solve the ODE, odeint comes from torchdiffeq; result is indexed as (time, batch, state)
sol = odeint(vdp_model, initial_conditions, ts, method='dopri5').detach().numpy()
# Start a fresh figure so this plot is not drawn over earlier ones when the file is
# run as a plain script (notebook frontends typically open a new figure per cell anyway).
plt.figure()
plt.plot(sol[:, :, 0], sol[:, :, 1], color='black', lw=0.5);
plt.title("Phase plot of the VDP oscillator");
plt.xlabel("x");
plt.ylabel("y");
#| exports
class LotkaVolterra(nn.Module):
    """
    The Lotka-Volterra equations are a pair of first-order, non-linear differential equations
    describing the dynamics of two species interacting in a predator-prey relationship.

    As implemented in ``forward`` (x is the prey, y the predator):

        dx/dt = alpha * x - beta * x * y
        dy/dt = -delta * y + gamma * x * y

    Naming note: here ``delta`` scales the predator's linear decay and ``gamma`` its growth
    from predation. Some references swap these two names, so check the convention before
    comparing fitted values with published ones.

    All four coefficients live in one learnable tensor, ``model_params``, ordered
    (alpha, beta, delta, gamma).
    """

    def __init__(self,
                 alpha: float = 1.5,  # prey growth rate
                 beta: float = 1.0,  # prey loss per predator encounter
                 delta: float = 3.0,  # predator linear decay rate
                 gamma: float = 1.0  # predator gain per prey encounter
                 ) -> None:
        super().__init__()
        # Explicit floating dtype so integer arguments (e.g. delta=3) do not produce an
        # int64 tensor, which nn.Parameter would reject.
        self.model_params = torch.nn.Parameter(
            torch.tensor([alpha, beta, delta, gamma], dtype=torch.get_default_dtype())
        )

    def forward(self,
                t: float,  # time index (unused: the system is autonomous)
                state: torch.Tensor,  # state of shape (..., 2): column 0 = x, column 1 = y
                ) -> torch.Tensor:  # time derivative, shape (..., 2)
        """Evaluate the right-hand side of the Lotka-Volterra system."""
        x = state[..., 0]
        y = state[..., 1]
        alpha, beta, delta, gamma = self.model_params
        dx = alpha*x - beta*x*y
        dy = -delta*y + gamma*x*y
        return torch.stack((dx, dy), dim=-1)

    def __repr__(self):
        return f" alpha: {self.model_params[0].item()}, beta: {self.model_params[1].item()}, delta: {self.model_params[2].item()}, gamma: {self.model_params[3].item()}"
# Simulate a batch of Lotka-Volterra trajectories; plot the x time series and the phase portrait.
lv_model = LotkaVolterra()
ts = torch.linspace(0, 30.0, 1000)
# Batch of 30 initial conditions (batch_dim, state_dim): Gaussian perturbations (std 0.5) around (3, 3)
initial_conditions = torch.tensor([[3, 3]]) + 0.50*torch.randn((30, 2))
sol = odeint(lv_model, initial_conditions, ts, method='dopri5').detach().numpy()
# Each plot gets its own figure; otherwise, in a plain script run, both draw on one axes
# and the second title/labels overwrite the first.
plt.figure()
plt.plot(ts, sol[:, :, 0], lw=0.5);
plt.title("Time series of the Lotka-Volterra system");
plt.xlabel("time");
plt.ylabel("x");
plt.figure()
plt.plot(sol[:, :, 0], sol[:, :, 1], lw=0.5);
plt.title("Phase plot of the Lotka-Volterra system");
plt.xlabel("x");
plt.ylabel("y");
#| exports
class Lorenz(nn.Module):
    """
    Define the Lorenz system as a PyTorch module.

    The right-hand side implemented by ``forward`` is

        dx/dt = sigma * (y - x)
        dy/dt = x * (rho - z) - y
        dz/dt = x * y - beta * z

    with (sigma, rho, beta) stored, in that order, in the learnable tensor ``model_params``.
    """

    def __init__(self,
                 sigma: float = 10.0,  # The sigma parameter of the Lorenz system
                 rho: float = 28.0,  # The rho parameter of the Lorenz system
                 beta: float = 8.0/3,  # The beta parameter of the Lorenz system
                 ):
        super().__init__()
        # Explicit floating dtype: an all-integer argument list (e.g. sigma=10, rho=28, beta=3)
        # would otherwise build an int64 tensor, which nn.Parameter rejects.
        self.model_params = torch.nn.Parameter(
            torch.tensor([sigma, rho, beta], dtype=torch.get_default_dtype())
        )

    def forward(self,
                t: float,  # time index (unused: the system is autonomous)
                state: torch.Tensor,  # state of shape (..., 3): columns are x, y, z
                ) -> torch.Tensor:  # time derivative, shape (..., 3)
        """Evaluate the right-hand side of the Lorenz system."""
        x = state[..., 0]
        y = state[..., 1]
        z = state[..., 2]
        sigma, rho, beta = self.model_params
        dx = sigma*(y - x)
        dy = x*(rho - z) - y
        dz = x*y - beta*z
        return torch.stack((dx, dy, dz), dim=-1)

    def __repr__(self):
        return f" sigma: {self.model_params[0].item()}, rho: {self.model_params[1].item()}, beta: {self.model_params[2].item()}"
# Simulate a batch of Lorenz trajectories; plot x over time and one trajectory's (x, y) projection.
lorenz_model = Lorenz()
ts = torch.linspace(0, 50.0, 3000)
# Batch of 30 initial conditions (batch_dim, state_dim): Gaussian perturbations (std 0.1) around (1, 0, 0)
initial_conditions = torch.tensor([[1.0, 0.0, 0.0]]) + 0.10*torch.randn((30, 3))
sol = odeint(lorenz_model, initial_conditions, ts, method='dopri5').detach().numpy()
# Separate figures keep the two plots from sharing one axes when run as a plain script.
plt.figure()
# Only the first 2000 time points, to keep the time series readable
plt.plot(ts[:2000], sol[:2000, :, 0], lw=0.5);
plt.title("Time series of the Lorenz system");
plt.xlabel("time");
plt.ylabel("x");
plt.figure()
# Phase projection of the first trajectory in the batch only
plt.plot(sol[:, 0, 0], sol[:, 0, 1], color='black', lw=0.5);
plt.title("Phase plot of the Lorenz system");
plt.xlabel("x");
plt.ylabel("y");
#| exports
class SimODEData(Dataset):
    """
    Map-style dataset of simulated ODE trajectories.

    Item ``i`` is the pair ``(ts[i], values[i])``: a time grid and the states observed on it.
    The dataset's length is the number of time grids supplied.
    """

    def __init__(self,
                 ts: List[torch.Tensor],  # one time-grid tensor per sample
                 values: List[torch.Tensor],  # one state-trajectory tensor per sample
                 ) -> None:
        self.ts, self.values = ts, values

    def __len__(self) -> int:
        return len(self.ts)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        times = self.ts[index]
        states = self.values[index]
        return times, states
#| export
def create_sim_dataset(model: nn.Module,  # model to simulate from
                       ts: torch.Tensor,  # Time points to simulate at
                       num_samples: int = 10,  # Number of samples to generate
                       sigma_noise: float = 0.1,  # Noise level to add to the data
                       initial_conditions_default: Optional[torch.Tensor] = None,  # Default initial conditions (1-D); None means zeros of length 2
                       sigma_initial_conditions: float = 0.1,  # Noise level to add to the initial conditions
                       ) -> SimODEData:
    """
    Simulate ``num_samples`` noisy trajectories of ``model`` on the time grid ``ts``.

    Each sample starts from ``initial_conditions_default`` plus Gaussian noise with std
    ``sigma_initial_conditions``. It is integrated with ``odeint``, and Gaussian observation
    noise with std ``sigma_noise`` is added to the resulting states.

    Returns a ``SimODEData`` whose items are ``(ts, states)``. Every item shares the same
    ``ts`` tensor, and ``states`` has one row per time point.
    """
    if initial_conditions_default is None:
        # Create the default per call rather than sharing one tensor built at definition time.
        initial_conditions_default = torch.tensor([0.0, 0.0])
    # Also accept plain sequences such as [8.0, 8.0, 27.0]; tensors pass through unchanged.
    initial_conditions_default = torch.as_tensor(initial_conditions_default)
    dim = initial_conditions_default.shape[0]
    ts_list = []
    states_list = []
    for _ in range(num_samples):
        x0 = sigma_initial_conditions * torch.randn((1, dim)) + initial_conditions_default
        # odeint returns (len(ts), 1, dim); drop the singleton batch axis -> (len(ts), dim)
        ys = odeint(model, x0, ts).squeeze(1).detach()
        ys += sigma_noise * torch.randn_like(ys)
        ts_list.append(ts)
        states_list.append(ys)
    return SimODEData(ts_list, states_list)
#| exports
def train(model: torch.nn.Module,  # Model to train
          data: SimODEData,  # Data to train on
          lr: float = 1e-2,  # learning rate for the Adam optimizer
          epochs: int = 10,  # Number of epochs to train for
          batch_size: int = 5,  # Batch size for training
          log_every: int = 10,  # Print the running loss every `log_every` epochs; <= 0 disables printing
          ):
    """
    Fit ``model``'s parameters to the trajectories in ``data`` using Adam and an MSE loss.

    For each shuffled mini-batch, the model is integrated with ``odeint`` from each
    trajectory's first observed state. The prediction is compared against the full
    observed trajectory.

    Assumes every sample in a batch shares the same time grid, because only the first
    sample's grid (``ts[0]``) is used for integration. The model is updated in place.
    """
    trainloader = DataLoader(data, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.MSELoss()
    for epoch in range(epochs):
        running_loss = 0.0
        # Named `batch` so the `data` argument is not shadowed inside the loop.
        for batch in trainloader:
            optimizer.zero_grad()  # reset gradients
            ts, states = batch  # states presumably (batch, time, dim) after collation
            initial_state = states[:, 0, :]  # grab the initial state of each trajectory
            # odeint yields (time, batch, dim); transpose to match `states`' (batch, time, dim)
            pred = odeint(model, initial_state, ts[0]).transpose(0, 1)
            loss = criterion(pred, states)
            loss.backward()  # compute gradients
            optimizer.step()  # update parameters
            running_loss += loss.item()  # record loss
        # Guard against log_every <= 0, which would otherwise be a modulo-by-zero error.
        if log_every > 0 and epoch % log_every == 0:
            print(f"Loss at {epoch}: {running_loss}")
# Generate noisy training data from a VDP model with known mu, then compare an untrained
# model (mu=0.10) with the true dynamics, both started from the first (noisy) data point.
true_mu = 0.30
model_sim = VDP(mu=true_mu)
ts_data = torch.linspace(0.0, 10.0, 10)
data_vdp = create_sim_dataset(model_sim,
                              ts=ts_data,
                              num_samples=10,
                              sigma_noise=0.01)
vdp_model = VDP(mu=0.10)
ts = torch.linspace(0, 30.0, 1000)
ts_data, y_data = data_vdp[0]
initial_conditions = y_data[0, :].unsqueeze(0)
sol_initial_guess = odeint(vdp_model, initial_conditions, ts, method='dopri5').detach().numpy()
sol_true = odeint(model_sim, initial_conditions, ts, method='dopri5').detach().numpy()
# New figure so this comparison does not draw over earlier plots in a plain script run
plt.figure()
plt.plot(ts, sol_initial_guess[:, :, 1], color='black', lw=1.0, label='Unfit model');
plt.scatter(ts_data.detach(), y_data[:, 1].detach(), color='red', s=30, label='Data');
plt.plot(ts, sol_true[:, :, 1], color='red', ls='--', lw=1.0, label='True model');
plt.title("VDP Model: Initial Parameter Guess");
plt.xlabel("t");
plt.ylabel("y");
plt.legend();
# Fit mu to the simulated data and compare the fitted model with the ground truth.
train(vdp_model, data_vdp, epochs=50)
print(f"After training: {vdp_model}, where the true value is {true_mu}")
# .item() prints a plain number instead of a tensor repr carrying grad_fn
print(f"Final Parameter Recovery Error: {vdp_model.mu.item() - true_mu}")
sol_after_fitting = odeint(vdp_model, initial_conditions, ts, method='dopri5').detach().numpy()
# Time series of y after fitting (own figure, so script runs don't overlay plots)
plt.figure()
plt.plot(ts, sol_after_fitting[:, :, 1], color='black', lw=1.0, label='Fit model');
plt.scatter(ts_data.detach(), y_data[:, 1].detach(), color='red', s=30, label='Data');
plt.plot(ts, sol_true[:, :, 1], color='red', ls='--', lw=1.0, label='True model');
plt.title("VDP Model: After Training");
plt.xlabel("t");
plt.ylabel("y");
plt.legend();
# Phase portrait (y against x) after fitting
plt.figure()
plt.plot(sol_after_fitting[:, :, 0], sol_after_fitting[:, :, 1], color='black', lw=1.0, label='Fit model');
plt.scatter(y_data[:, 0], y_data[:, 1].detach(), color='red', s=30, label='Data');
plt.plot(sol_true[:, :, 0], sol_true[:, :, 1], color='red', ls='--', lw=1.0, label='True model');
plt.title("VDP Model: After Training");
plt.xlabel("x");
plt.ylabel("y");
plt.legend();
# Generate noisy Lotka-Volterra data from the default parameters, then compare a perturbed
# initial guess with the true dynamics from the first (noisy) data point.
model_sim_lv = LotkaVolterra()
ts_data = torch.linspace(0.0, 10.0, 10)
data_lv = create_sim_dataset(model_sim_lv,
                             ts=ts_data,
                             num_samples=10,
                             sigma_noise=0.1,
                             initial_conditions_default=torch.tensor([2.5, 2.5]))
print(model_sim_lv)
model_lv = LotkaVolterra(alpha=1.6, beta=1.1, delta=2.7, gamma=1.2)
ts = torch.linspace(0, 30.0, 1000)
ts_data, y_data = data_lv[0]
initial_conditions = y_data[0, :].unsqueeze(0)
sol_initial_guess = odeint(model_lv, initial_conditions, ts, method='dopri5').detach().numpy()
sol_true = odeint(model_sim_lv, initial_conditions, ts, method='dopri5').detach().numpy()
# New figure so this comparison does not draw over earlier plots in a plain script run
plt.figure()
plt.plot(ts, sol_initial_guess[:, :, 1], color='black', lw=1.0, label='Unfit model');
plt.scatter(ts_data.detach(), y_data[:, 1].detach(), color='red', s=30, label='Data');
plt.plot(ts, sol_true[:, :, 1], color='red', ls='--', lw=1.0, label='True model');
plt.title("Lotka-Volterra Model: Initial Parameter Guess");
plt.xlabel("t");
plt.ylabel("y");
plt.legend();
# Fit the Lotka-Volterra parameters and compare the fitted model with the ground truth.
train(model_lv, data_lv, epochs=60, lr=1e-2)
print(f"Fitted model: {model_lv}")
print(f"True model: {model_sim_lv}")
initial_conditions = y_data[0, :].unsqueeze(0)
sol_after_fitting = odeint(model_lv, initial_conditions, ts, method='dopri5').detach().numpy()
# Time series of y after fitting (own figure, so script runs don't overlay plots)
plt.figure()
plt.plot(ts, sol_after_fitting[:, :, 1], color='black', lw=2.0, label='Fit model');
plt.scatter(ts_data, y_data[:, 1], color='red', s=30, label='Data');
plt.plot(ts, sol_true[:, :, 1], color='red', ls='--', lw=1.0, label='True model');
plt.title("Lotka-Volterra: After Training");
plt.xlabel("t");
plt.ylabel("y");
plt.legend();
# Phase portrait (y against x) after fitting
plt.figure()
plt.plot(sol_after_fitting[:, :, 0], sol_after_fitting[:, :, 1], color='black', lw=2.0, label='Fit model');
plt.scatter(y_data[:, 0], y_data[:, 1], color='red', s=30, label='Data');
plt.plot(sol_true[:, :, 0], sol_true[:, :, 1], color='red', ls='--', lw=1.0, label='True model');
plt.title("Lotka-Volterra: After Training");
plt.xlabel("x");
plt.ylabel("y");
plt.legend();
# Generate noisy Lorenz training data and plot the x and y components of the first sample.
model_sim = Lorenz(sigma=10.0, rho=28.0, beta=8.0/3.0)
ts_data = torch.linspace(0.0, 10.0, 10)
data_lorenz = create_sim_dataset(model_sim,
                                 ts=ts_data,
                                 num_samples=10,
                                 initial_conditions_default=torch.tensor([8.0, 8.0, 27.0]),
                                 sigma_noise=0.01)
_, y_data = data_lorenz[0]
# Own figure so these curves are not drawn over earlier plots in a plain script run
plt.figure()
plt.plot(ts_data, y_data[:, 0], label='x');
plt.plot(ts_data, y_data[:, 1], label='y');
plt.title("Simulated Data: Lorenz Attractor");
plt.xlabel("t");
plt.legend();
# Start the Lorenz fit from deliberately wrong parameters and plot y before training.
lorenz_model = Lorenz(sigma=11.0, rho=23.0, beta=1.2)
print(f"The starting point: {lorenz_model}")
ts = torch.linspace(0, 30.0, 1000)
ts_data, y_data = data_lorenz[0]
initial_conditions = y_data[0, :].unsqueeze(0)
sol = odeint(lorenz_model, initial_conditions, ts, method='dopri5').detach().numpy()
# New figure so this plot does not overlay earlier ones in a plain script run
plt.figure()
plt.plot(ts, sol[:, :, 1], color='black', lw=1.0);
plt.scatter(ts_data.detach(), y_data[:, 1].detach(), color='red', s=30);
plt.title("Lorenz plot before fitting");
plt.xlabel("t");
plt.ylabel("y");
# Fit the Lorenz parameters, report them, and plot y after training.
train(lorenz_model,
      data_lorenz,
      epochs=10,
      lr=1e-3)
print(f"After training: {lorenz_model}")
ts = torch.linspace(0, 30.0, 1000)
ts_data, y_data = data_lorenz[0]
initial_conditions = y_data[0, :].unsqueeze(0)
sol = odeint(lorenz_model, initial_conditions, ts, method='dopri5').detach().numpy()
# New figure so the post-fit plot is not drawn over the pre-fit one in a plain script run
plt.figure()
plt.plot(ts, sol[:, :, 1], color='black', lw=1.0);
plt.scatter(ts_data.detach(), y_data[:, 1].detach(), color='red', s=30);
plt.title("Lorenz plot after fitting");
plt.xlabel("t");
plt.ylabel("y");
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment