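"""
Train a small recurrent neural network on trajectories of the chaotic
Lorenz-63 system. Ground-truth data are generated by integrating the Lorenz
equations with a fourth-order Runge-Kutta scheme; the RNN is fit on sliding
windows of the trajectories and its closed-loop predictions are visualized
and animated with matplotlib.
"""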
import numpy as np
import numba
from numba import njit
import torch
from torch import nn
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.animation as animation
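
# Select the compute device: CUDA if available, otherwise Apple MPS, else CPU.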
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

# 3d visualization
def visualize_3d(obs=None, noiseless_traj=None, times=None, trajs=None, save=None, title=''):
    fig = plt.figure(figsize=(5, 5))
    ax = fig.add_subplot(1, 1, 1, projection='3d')
    if title is not None:
        ax.set_title('Trajectory\n' + title)
    if noiseless_traj is not None:
        z = np.array([o for o in noiseless_traj])
        z = np.reshape(z, [-1, 3])
        for i in range(len(z)):
            ax.plot(z[i:i+10, 0], z[i:i+10, 1], z[i:i+10, 2], color=plt.cm.jet(i/len(z)/1.6))
    if obs is not None:
        z = np.array([o for o in obs])
        z = np.reshape(z, [-1, 3])
        ax.scatter(z[:, 0], z[:, 1], z[:, 2], marker='.', color='k', alpha=0.5, linewidths=0, s=45)
    if trajs is not None:
        z = np.array([o for o in trajs])
        z = np.reshape(z, [-1, 3])
        for i in range(len(z)):
            ax.plot(z[i:i+10, 0], z[i:i+10, 1], z[i:i+10, 2], color='r', alpha=0.3)
    fig.canvas.draw()
    fig.canvas.flush_events()
    if save is not None:
        plt.savefig(save + '.png', format='png', dpi=400, bbox_inches='tight', pad_inches=0.1)
    plt.show()

# Runge Kutta 4th order integration function
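# RK4 integrates dydt from the initial condition y_0 over the time grid
# t_array; the returned trajectory has len(t_array) + 1 rows because the
# initial condition is kept as row 0.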
def RK4(t_array, y_0, dydt):
    dt = t_array[1] - t_array[0]
    y_array = torch.zeros([len(t_array) + 1, 3]).to(device)
    y_array[0, :] = y_0
    for i in range(len(t_array)):
        _, y_array[i+1, :] = rk4_step(t_array[i], y_array[i:i+1], dt, dydt)
    return y_array


def rk4_step(t, y, dt, dydt):
    # Calculate slopes
    k1 = dt * dydt(t, y)
    k2 = dt * dydt(t + dt/2., y + k1/2.)
    k3 = dt * dydt(t + dt/2., y + k2/2.)
    k4 = dt * dydt(t + dt, y + k3)
    # Calculate new t and y
    y = y + 1./6 * (k1 + 2*k2 + 2*k3 + k4)
    t = t + dt
    return t, y

# Lorenz system dynamics
class Lorenz(nn.Module):
    """
    Chaotic Lorenz system (Lorenz 63).
    """
    def __init__(self):
        super(Lorenz, self).__init__()
        self.lin = nn.Linear(5, 3, bias=False)
        W = torch.tensor([[-10., 10., 0., 0., 0.],
                          [28., -1., 0., -1., 0.],
                          [0., 0., -8. / 3., 0., 1.]])
        self.lin.weight = nn.Parameter(W)

    def forward(self, t, x):
        y = torch.ones([1, 5]).to(device)
        y[0][0] = x[0][0]
        y[0][1] = x[0][1]
        y[0][2] = x[0][2]
        y[0][3] = x[0][0] * x[0][2]
        y[0][4] = x[0][0] * x[0][1]
        return self.lin(y)

# recurrent neural network module
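# A plain (Elman) RNN with ReLU activations followed by a linear readout.
# Given a batch of state sequences it returns only the output of the last
# time step, i.e. the predicted next state.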
class fcRNN(nn.Module):
    def __init__(self, input_size, hidden_dim, output_size, n_layers):
        super(fcRNN, self).__init__()
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.rnn = nn.RNN(input_size,
                          hidden_dim, n_layers,
                          nonlinearity='relu',
                          batch_first=True)  # RNN hidden units
        self.fc = nn.Linear(hidden_dim, output_size)  # output layer

    def forward(self, x):
        bs, _, _ = x.shape
        h0 = torch.zeros(self.n_layers, bs, self.hidden_dim).requires_grad_().to(device)
        out, hidden = self.rnn(x, h0.detach())
        out = out.view(bs, -1, self.hidden_dim)
        out = self.fc(out)
        return out[:, -1, :]

# split the ground truth into training sequences
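# Sliding-window split: for each index i the tw consecutive states starting at
# i form an input sequence and the single state at i + tw is its label.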
def create_inout_sequences(input_data, tw):
    train_data = []
    labels = []
    L = len(input_data)
    for i in range(L - tw):
        train_seq = input_data[i:i+tw]
        train_label = input_data[i+tw:i+tw+1]
        train_data.append(train_seq.squeeze(1).unsqueeze(0))
        labels.append(train_label.squeeze(1).unsqueeze(0))
    return torch.cat(train_data, 0), torch.cat(labels, 0)

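# Closed-loop prediction: seed the rollout with the first `train_window` ground
# truth states, then repeatedly feed the model its own most recent window of
# predictions to generate `pred_len` further states.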
def model_predict(pred_len, model, initial_condition, train_window):
    test_inputs = initial_condition
    pred_traj = []
    for i in range(train_window):
        pred_traj.append(test_inputs[i])
    with torch.no_grad():
        for i in range(pred_len):
            seq = torch.stack(pred_traj[-train_window:])
            # print(seq.shape)
            model_out = model(seq.unsqueeze(0)).squeeze().detach().reshape(-1)
            pred_traj.append(model_out)
    pred_traj = torch.stack(pred_traj)
    return pred_traj

# main program
def main():
    true_y0 = torch.tensor([[-8., 7., 27.]]).to(device)  # initial condition
    t = torch.linspace(0., 5., 500).to(device)  # t is from 0 to 5 for 500 data points
    L = Lorenz().to(device)
    # generating training data
    with torch.no_grad():
        true_lorenz = RK4(t, true_y0, L)
    visualize_3d(noiseless_traj=true_lorenz.cpu())
    # additional training data
    training_data = [true_lorenz]
    with torch.no_grad():
        for i in range(10):
            random_perturbation = torch.rand(3) * 3
            random_perturbation = random_perturbation.to(device)
            training_data.append(RK4(t, true_y0 * random_perturbation, L))
    visualize_3d(noiseless_traj=torch.cat(training_data).cpu())
    observations = true_lorenz
    feature_size = observations.shape[1]
    model = fcRNN(feature_size, 128, feature_size, 2).to(device)
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print("Model has a total of {0} parameters.".format(pytorch_total_params))
    list_of_predictions = []
    train_window = 15
    # generating training data and labels for RNN
    data_train, labels = [], []
    for observations in training_data:
        d_t, lab = create_inout_sequences(observations, train_window)
        data_train.append(d_t)
        labels.append(lab)
    data_train, labels = torch.cat(data_train), torch.cat(labels)
    print(data_train.shape)  # (num_windows, train_window, 3)
    epochs = 10000
    loss_arr = []
    loss_function = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    # training loop
    for i in range(epochs + 1):
        if i > 1200:
            optimizer.param_groups[0]['lr'] = 1e-3
        optimizer.zero_grad()
        y_pred = model(data_train)
        # print(y_pred.shape)
        loss = loss_function(y_pred, labels.squeeze())
        loss.backward()
        optimizer.step()
        loss_arr.append(loss.item())
        if i % 100 == 0:
            print(f'epoch: {i:3} loss: {loss_arr[-1]:10.8f}')
            # log a full closed-loop prediction at every checkpoint
            pred_len = 490
            test_inputs = data_train[0].squeeze().detach()
            pred_traj = model_predict(pred_len, model, test_inputs, train_window)
            save_name = "PNG/" + str(i)
            list_of_predictions.append(pred_traj.cpu())
            # visualize_3d(obs=true_lorenz.cpu().detach().numpy(),
            #              noiseless_traj=pred_traj.cpu(),
            #              save=None)
    print(f'epoch: {i:3} loss: {loss_arr[-1]:10.10f}')
    # animation and visualization
    # ---
    # multicolored temporal evolution
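    # Each frame shows the closed-loop prediction logged at one training
    # checkpoint; segments are colored by their position along the trajectory
    # (magma colormap) and the ground-truth states appear as a black scatter.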
    noiseless_traj = list_of_predictions[0]
    obs = true_lorenz.cpu()
    list_of_lines = []
    fig = plt.figure(figsize=(5, 5))
    sns.set(style="ticks", context="paper")
    sns.set_palette("bright")
    plt.style.use("dark_background")
    ax = fig.add_subplot(1, 1, 1, projection='3d')
    ax.set_facecolor('#30363f')
    ax.set_title('Lorenz 63')
    z = np.array([o for o in noiseless_traj])
    z = np.reshape(z, [-1, 3])
    for i in range(len(z)):
        line = ax.plot(z[i:i+10, 0], z[i:i+10, 1], z[i:i+10, 2], color=plt.cm.magma(i/len(z)))
        list_of_lines.append(line)
    z = np.array([o for o in obs])
    z = np.reshape(z, [-1, 3])
    scat = ax.scatter(z[:, 0], z[:, 1], z[:, 2], marker='.', color='k', alpha=0.5, linewidths=0, s=45)
    ax.set_xlim([-20, 20])
    ax.set_ylim([-20, 20])  # this shouldn't be necessary but the limits are usually enlarged per default
    ax.set_zlim([0, 40])  # this shouldn't be necessary but the limits are usually enlarged per default
    def update(frame):
        noiseless_traj = list_of_predictions[frame]
        z = np.array([o for o in noiseless_traj])
        z = np.reshape(z, [-1, 3])
        for i in range(len(z)):
            list_of_lines[i][0].set_data_3d(z[i:i+10, 0], z[i:i+10, 1], z[i:i+10, 2])
            # print(z[i:i+10, 0])

    ani = animation.FuncAnimation(fig=fig, func=update, frames=len(list_of_predictions), interval=100)
    plt.show()
    ani.save(filename="./lorenz_rrn_temp.mp4", writer="ffmpeg", savefig_kwargs={'facecolor': '#30363f'})
    # spaghettiplot animation
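    # All logged predictions are drawn simultaneously, one magma color per
    # checkpoint; each frame extends the curves by one point while a blue
    # marker traces the ground-truth observation at the same time index.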
    obs = true_lorenz.cpu()
    list_of_lines = []
    fig = plt.figure(figsize=(5, 5))
    sns.set(style="ticks", context="paper")
    sns.set_palette("bright")
    color_palette = sns.color_palette("magma", len(list_of_predictions))
    plt.style.use("dark_background")
    ax = fig.add_subplot(1, 1, 1, projection='3d')
    ax.set_facecolor('#30363f')
    ax.set_title('Lorenz 63')
    for i, noiseless_traj in enumerate(list_of_predictions):
        z = np.array([o for o in noiseless_traj])
        z = np.reshape(z, [-1, 3])
        line = ax.plot(z[0:1, 0], z[0:1, 1], z[0:1, 2], color=color_palette[i])
        list_of_lines.append(line)
    z = np.array([o for o in obs])
    z = np.reshape(z, [-1, 3])
    ax.scatter(z[:, 0], z[:, 1], z[:, 2], marker='.', color='k', alpha=0.5, linewidths=0, s=45)
    scat = ax.plot(z[:1, 0], z[:1, 1], z[:1, 2], marker='.', color='blue', linestyle="", alpha=1.0, markersize=5)
    ax.set_xlim([-20, 20])
    ax.set_ylim([-20, 20])  # this shouldn't be necessary but the limits are usually enlarged per default
    ax.set_zlim([0, 40])  # this shouldn't be necessary but the limits are usually enlarged per default
    def update(frame):
        for i, noiseless_traj in enumerate(list_of_predictions):
            z = np.array([o for o in noiseless_traj])
            z = np.reshape(z, [-1, 3])
            list_of_lines[i][0].set_data_3d(z[:frame+1, 0], z[:frame+1, 1], z[:frame+1, 2])
        # the predictions are a few steps longer than the ground-truth
        # trajectory, so clamp the index to stay within obs
        j = min(frame, len(obs) - 1)
        scat[0].set_data([obs[j][0]], [obs[j][1]])
        scat[0].set_3d_properties([obs[j][2]])
        # scat.set_3d_properties(obs[frame][2])
        # print(z[i:i+10, 0])

    ani = animation.FuncAnimation(fig=fig, func=update, frames=len(list_of_predictions[0]), interval=30)
    plt.show()
    ani.save(filename="./lorenz_rrn_spaghetti_training.mp4", writer="ffmpeg", savefig_kwargs={'facecolor': '#30363f'})
if __name__ == "__main__":
    main()