Created
February 16, 2024 10:45
-
-
Save mrernst/aa92bf55f3889e366453572c48781e01 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import numba | |
from numba import njit | |
import torch | |
from torch import nn | |
import seaborn as sns | |
import matplotlib.pyplot as plt | |
import matplotlib.animation as animation | |
# Pick the compute backend once at import time: CUDA first, then Apple MPS,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
# 3d visualization | |
def visualize_3d(obs=None, noiseless_traj=None, times=None, trajs=None, save=None, title=''):
    """Draw trajectories and/or observations in a single 3D figure and show it.

    Args:
        obs: optional iterable of points, drawn as black scatter markers.
        noiseless_traj: optional iterable of points, drawn as overlapping
            10-point segments coloured along the jet colormap.
        times: accepted for call compatibility; not used by this function.
        trajs: optional iterable of points, drawn as translucent red segments.
        save: optional path stem; when given the figure is written to
            ``save + '.png'`` before being shown.
        title: text appended below the 'Trajectory' heading; pass ``None`` to
            omit the title entirely.

    Every non-``None`` point collection is reshaped to ``(-1, 3)``.
    """
    def _as_points(series):
        # Materialise the iterable element by element, then force an (N, 3) layout.
        return np.reshape(np.array([point for point in series]), [-1, 3])

    figure = plt.figure(figsize=(5, 5))
    axes = figure.add_subplot(1, 1, 1, projection='3d')
    if title is not None:
        axes.set_title('Trajectory\n' + title)

    if noiseless_traj is not None:
        pts = _as_points(noiseless_traj)
        total = len(pts)
        for start in range(total):
            # One short segment per start index so the colour can vary along the path.
            stop = start + 10
            axes.plot(pts[start:stop, 0], pts[start:stop, 1], pts[start:stop, 2],
                      color=plt.cm.jet(start / total / 1.6))

    if obs is not None:
        pts = _as_points(obs)
        axes.scatter(pts[:, 0], pts[:, 1], pts[:, 2],
                     marker='.', color='k', alpha=0.5, linewidths=0, s=45)

    if trajs is not None:
        pts = _as_points(trajs)
        for start in range(len(pts)):
            stop = start + 10
            axes.plot(pts[start:stop, 0], pts[start:stop, 1], pts[start:stop, 2],
                      color='r', alpha=0.3)

    figure.canvas.draw()
    figure.canvas.flush_events()
    if save is not None:
        plt.savefig(save + '.png', format='png', dpi=400, bbox_inches='tight', pad_inches=0.1)
    plt.show()
# Runge Kutta 4th order integration function | |
def RK4(t_array, y_0, dydt):
    """Integrate ``dydt`` over ``t_array`` with fixed-step fourth-order Runge-Kutta.

    The step size is taken from the first two entries of ``t_array`` and is
    reused for every step. One step is taken per entry of ``t_array``, so the
    result has ``len(t_array) + 1`` rows: row 0 is ``y_0`` and row ``i + 1`` is
    the state after stepping from ``t_array[i]``.

    Args:
        t_array: sequence of time points (at least two entries).
        y_0: initial state written into row 0 of the result.
        dydt: callable ``dydt(t, y)`` returning the time derivative.

    Returns:
        Tensor of shape ``(len(t_array) + 1, 3)`` on the module-level ``device``.
    """
    step = t_array[1] - t_array[0]
    n_steps = len(t_array)
    states = torch.zeros([n_steps + 1, 3]).to(device)
    states[0, :] = y_0
    for idx, t_now in enumerate(t_array):
        # Slice (not index) so the derivative function sees a (1, 3) batch.
        current = states[idx:idx + 1]
        states[idx + 1, :] = rk4_step(t_now, current, step, dydt)[1]
    return states
def rk4_step(t, y, dt, dydt):
    """Advance ``y`` by one classical fourth-order Runge-Kutta step.

    Args:
        t: current time.
        y: current state.
        dt: step size.
        dydt: callable ``dydt(t, y)`` returning the time derivative.

    Returns:
        Tuple ``(t + dt, y_next)``.
    """
    half_dt = dt/2.
    # The four stage increments (each already scaled by dt).
    inc_a = dt*dydt(t, y)
    inc_b = dt*dydt(t+half_dt, y+inc_a/2.)
    inc_c = dt*dydt(t+half_dt, y+inc_b/2.)
    inc_d = dt*dydt(t+dt, y+inc_c)
    # Weighted combination 1/6 * (k1 + 2*k2 + 2*k3 + k4).
    weighted = inc_a+2*inc_b+2*inc_c+inc_d
    return t + dt, y + 1./6*weighted
# Lorenz system dynamics | |
class Lorenz(nn.Module):
    """
    Chaotic Lorenz system written as a fixed linear map over the feature
    vector ``[x, y, z, x*z, x*y]``.

    With the weight matrix below the module evaluates

        dx/dt = -10*x + 10*y
        dy/dt =  28*x - y - x*z
        dz/dt = -8/3*z + x*y
    """

    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(5, 3, bias=False)
        W = torch.tensor([[-10., 10., 0., 0., 0.],
                          [28., -1., 0., -1., 0.],
                          [0., 0., -8. / 3., 0., 1.]])
        self.lin.weight = nn.Parameter(W)

    def forward(self, t, x):
        """Return the time derivative of the state ``x``.

        Args:
            t: current time; unused because the system is autonomous, but kept
                so the module can be passed as ``dydt(t, y)`` to an integrator.
            x: state whose last axis holds ``(x, y, z)``, e.g. shape ``(1, 3)``.
                Any leading batch axes are preserved in the output.

        Returns:
            Tensor with the same leading shape as ``x`` and a last axis of 3.
        """
        x = torch.as_tensor(x)
        x0, x1, x2 = x[..., 0], x[..., 1], x[..., 2]
        # Build all features in one shot instead of writing into a scratch
        # tensor element by element; this handles every row of a batch.
        features = torch.stack((x0, x1, x2, x0 * x2, x0 * x1), dim=-1)
        # Match the weight's device and dtype so the module does not depend on
        # a module-level device setting.
        return self.lin(features.to(self.lin.weight))
# recurrent neural network module | |
class fcRNN(nn.Module):
    """Stacked ReLU RNN followed by a linear read-out of the last time step.

    Args:
        input_size: number of features per time step.
        hidden_dim: width of each recurrent layer.
        output_size: number of features produced by the read-out layer.
        n_layers: number of stacked recurrent layers.
    """

    def __init__(self, input_size, hidden_dim, output_size, n_layers):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.rnn = nn.RNN(input_size,
                          hidden_dim, n_layers,
                          nonlinearity='relu',
                          batch_first=True)  # RNN hidden units
        self.fc = nn.Linear(hidden_dim, output_size)  # output layer

    def forward(self, x):
        """Map a batch of sequences to one prediction per sequence.

        Args:
            x: tensor of shape ``(batch, seq_len, input_size)``.

        Returns:
            Tensor of shape ``(batch, output_size)`` computed from the RNN
            output at the final time step.
        """
        bs, _, _ = x.shape
        # Zero initial state created on the input's own device/dtype, so the
        # module works wherever the caller placed it and its inputs.
        h0 = x.new_zeros(self.n_layers, bs, self.hidden_dim)
        out, _ = self.rnn(x, h0)
        # Only the last step is returned, so only that step goes through the
        # read-out layer.
        return self.fc(out[:, -1, :])
# split the ground truth into training sequences | |
def create_inout_sequences(input_data, tw):
    """Slice a series into sliding windows and their next-step targets.

    For every start index ``s`` in ``range(len(input_data) - tw)`` the window
    is ``input_data[s:s+tw]`` and the target is the single following row
    ``input_data[s+tw:s+tw+1]``. Each slice has axis 1 squeezed (a no-op
    unless that axis has size 1) and gains a leading batch axis.

    Args:
        input_data: tensor indexed by time along axis 0.
        tw: window length.

    Returns:
        Tuple ``(windows, targets)`` concatenated along axis 0.
    """
    n_windows = len(input_data) - tw
    windows = [input_data[s:s + tw].squeeze(1).unsqueeze(0)
               for s in range(n_windows)]
    targets = [input_data[s + tw:s + tw + 1].squeeze(1).unsqueeze(0)
               for s in range(n_windows)]
    return torch.cat(windows, 0), torch.cat(targets, 0)
def model_predict(pred_len, model, initial_condition, train_window):
    """Roll the model forward autoregressively from a seed window.

    The first ``train_window`` rows of ``initial_condition`` seed the
    trajectory. Each of the ``pred_len`` steps feeds the most recent
    ``train_window`` rows (with a leading batch axis of 1) to ``model`` and
    appends the flattened prediction.

    Args:
        pred_len: number of steps to predict.
        model: callable taking a ``(1, train_window, features)`` tensor.
        initial_condition: indexable holding at least ``train_window`` rows.
        train_window: number of past rows the model sees at each step.

    Returns:
        Stacked tensor of the seed rows followed by the predictions.
    """
    history = [initial_condition[k] for k in range(train_window)]
    with torch.no_grad():
        for _ in range(pred_len):
            context = torch.stack(history[-train_window:]).unsqueeze(0)
            prediction = model(context)
            history.append(prediction.squeeze().detach().reshape(-1))
    return torch.stack(history)
# main program | |
def main():
    """Train an RNN on simulated Lorenz trajectories and animate its forecasts.

    Steps:
      1. Integrate the Lorenz system with ``RK4`` from a reference initial
         condition and from 10 randomly rescaled copies of it.
      2. Cut every trajectory into (window, next-state) pairs and fit an
         ``fcRNN`` with Adam on the full batch; the learning rate drops from
         1e-2 to 1e-3 after epoch 1200.
      3. Every 100 epochs, roll the model out from the first training window
         and store the forecast.
      4. Render two animations and save them with the ffmpeg writer:
         ``./lorenz_rrn_temp.mp4`` (one forecast per frame) and
         ``./lorenz_rrn_spaghetti_training.mp4`` (all forecasts growing in
         time, with a marker following the reference trajectory).
    """
    true_y0 = torch.tensor([[-8., 7., 27.]]).to(device)  # initial condition
    t = torch.linspace(0., 5., 500).to(device)  # t is from 0 to 5 for 500 data points
    L = Lorenz().to(device)

    # generating training data
    with torch.no_grad():
        true_lorenz = RK4(t, true_y0, L)
    visualize_3d(noiseless_traj=true_lorenz.cpu())

    # additional training data: the same system started from randomly
    # rescaled initial conditions
    training_data = [true_lorenz]
    with torch.no_grad():
        for _ in range(10):
            random_perturbation = (torch.rand(3) * 3).to(device)
            training_data.append(RK4(t, true_y0 * random_perturbation, L))
    visualize_3d(noiseless_traj=torch.cat(training_data).cpu())

    feature_size = true_lorenz.shape[1]
    model = fcRNN(feature_size, 128, feature_size, 2).to(device)
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print("Model has a total of {0} parameters.".format(pytorch_total_params))

    list_of_predictions = []
    train_window = 15

    # generating training data and labels for RNN
    data_train, labels = [], []
    for trajectory in training_data:
        d_t, lab = create_inout_sequences(trajectory, train_window)
        data_train.append(d_t)
        labels.append(lab)
    data_train, labels = torch.cat(data_train), torch.cat(labels)

    epochs = 10000
    pred_len = 490
    loss_arr = []
    loss_function = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

    # training loop
    for i in range(epochs + 1):
        if i > 1200:
            optimizer.param_groups[0]['lr'] = 1e-3
        optimizer.zero_grad()
        y_pred = model(data_train)
        # Labels carry a singleton step axis; drop exactly that axis.
        loss = loss_function(y_pred, labels.squeeze(1))
        loss.backward()
        optimizer.step()
        loss_arr.append(loss.item())
        if i % 100 == 0:
            print(f'epoch: {i:3} loss: {loss_arr[-1]:10.8f}')
            # NOTE(review): the source's indentation was lost; the forecast
            # snapshot is assumed to belong to this every-100-epochs branch.
            test_inputs = data_train[0].detach()
            pred_traj = model_predict(pred_len, model, test_inputs, train_window)
            list_of_predictions.append(pred_traj.cpu())
    print(f'epoch: {i:3} loss: {loss_arr[-1]:10.10f}')

    # animation and visualization
    # ---
    # Convert once up front so the per-frame callbacks only slice arrays.
    snapshots = [p.detach().numpy().reshape(-1, 3) for p in list_of_predictions]
    obs = true_lorenz.detach().cpu().numpy().reshape(-1, 3)

    # multicolored temporal evolution: one frame per stored forecast
    segment_lines = []
    fig = plt.figure(figsize=(5, 5))
    sns.set(style="ticks", context="paper")
    sns.set_palette("bright")
    plt.style.use("dark_background")
    ax = fig.add_subplot(1, 1, 1, projection='3d')
    ax.set_facecolor('#30363f')
    ax.set_title('Lorenz 63')
    first = snapshots[0]
    n_points = len(first)
    for k in range(n_points):
        # Overlapping 10-point segments so colour can encode time.
        line = ax.plot(first[k:k + 10, 0], first[k:k + 10, 1], first[k:k + 10, 2],
                       color=plt.cm.magma(k / n_points))
        segment_lines.append(line)
    ax.scatter(obs[:, 0], obs[:, 1], obs[:, 2],
               marker='.', color='k', alpha=0.5, linewidths=0, s=45)
    ax.set_xlim([-20, 20])
    ax.set_ylim([-20, 20])  # this shouldn't be necessary but the limits are usually enlarged per default
    ax.set_zlim([0, 40])  # this shouldn't be necessary but the limits are usually enlarged per default

    def update_segments(frame):
        z = snapshots[frame]
        for k, line in enumerate(segment_lines):
            line[0].set_data_3d(z[k:k + 10, 0], z[k:k + 10, 1], z[k:k + 10, 2])

    ani = animation.FuncAnimation(fig=fig, func=update_segments,
                                  frames=len(snapshots), interval=100)
    plt.show()
    ani.save(filename="./lorenz_rrn_temp.mp4", writer="ffmpeg",
             savefig_kwargs={'facecolor': '#30363f'})

    # spaghettiplot animation: every stored forecast grows point by point
    forecast_lines = []
    fig = plt.figure(figsize=(5, 5))
    sns.set(style="ticks", context="paper")
    sns.set_palette("bright")
    color_palette = sns.color_palette("magma", len(snapshots))
    plt.style.use("dark_background")
    ax = fig.add_subplot(1, 1, 1, projection='3d')
    ax.set_facecolor('#30363f')
    ax.set_title('Lorenz 63')
    for k, z in enumerate(snapshots):
        line = ax.plot(z[0:1, 0], z[0:1, 1], z[0:1, 2], color=color_palette[k])
        forecast_lines.append(line)
    ax.scatter(obs[:, 0], obs[:, 1], obs[:, 2],
               marker='.', color='k', alpha=0.5, linewidths=0, s=45)
    marker = ax.plot(obs[:1, 0], obs[:1, 1], obs[:1, 2],
                     marker='.', color='blue', linestyle="", alpha=1.0, markersize=5)
    ax.set_xlim([-20, 20])
    ax.set_ylim([-20, 20])  # this shouldn't be necessary but the limits are usually enlarged per default
    ax.set_zlim([0, 40])  # this shouldn't be necessary but the limits are usually enlarged per default

    last_obs = len(obs) - 1

    def update_spaghetti(frame):
        for line, z in zip(forecast_lines, snapshots):
            line[0].set_data_3d(z[:frame + 1, 0], z[:frame + 1, 1], z[:frame + 1, 2])
        # Forecasts (train_window + pred_len rows) are longer than the
        # reference trajectory, so hold the marker on the final observation
        # instead of indexing past the end.
        j = min(frame, last_obs)
        # Length-1 slices rather than scalars: recent Matplotlib releases
        # reject non-sequence line data.
        marker[0].set_data_3d(obs[j:j + 1, 0], obs[j:j + 1, 1], obs[j:j + 1, 2])

    ani = animation.FuncAnimation(fig=fig, func=update_spaghetti,
                                  frames=len(snapshots[0]), interval=30)
    plt.show()
    ani.save(filename="./lorenz_rrn_spaghetti_training.mp4", writer="ffmpeg",
             savefig_kwargs={'facecolor': '#30363f'})
# Run the experiment only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment