Skip to content

Instantly share code, notes, and snippets.

@Algomancer
Created December 12, 2023 10:11
Show Gist options
  • Save Algomancer/bdbb50c993fdfe36fbce29c288a2782c to your computer and use it in GitHub Desktop.
Save Algomancer/bdbb50c993fdfe36fbce29c288a2782c to your computer and use it in GitHub Desktop.
import torch
from torchdiffeq import odeint
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import tqdm
import imageio
import os
# Load the target image and preprocess it.
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# The context manager guarantees the file handle is released, even if decoding fails.
with Image.open('x.png') as img:
    # Grayscale conversion is intentionally left off (use img.convert('L') to
    # train on a grayscale image), so RGB/RGBA images keep their channels.
    # Other modes such as palette ('P') images would otherwise yield palette
    # indices rather than intensities, so they are converted to 8-bit RGB first.
    if img.mode not in ('L', 'RGB', 'RGBA'):
        img = img.convert('RGB')
    target_image = img.resize((32, 32))  # Resize the image
target_image = torch.from_numpy(np.array(target_image)).float() / 255.0  # Normalize to [0, 1]
target_image = target_image.to(device)
# One oscillator per pixel value (per channel for colour images).
num_oscillators = target_image.numel()
class KuramotoLayer(torch.nn.Module):
    """Vector field of an all-to-all coupled Kuramoto oscillator network.

    For every oscillator i it computes

        dθ_i/dt = ω_i + (K / N) * Σ_j sin(θ_j − θ_i)

    where ω (``natural_frequencies``) is learnable, K is ``coupling_strength``
    and N is ``num_oscillators``. ``forward`` uses the ``(t, state)`` call
    signature expected by ``torchdiffeq.odeint``.

    Args:
        num_oscillators: Number of oscillators N; the last dimension of the
            phase tensor passed to ``forward`` must have this size.
        coupling_strength: Global coupling constant K.
    """

    def __init__(self, num_oscillators, coupling_strength):
        super().__init__()
        self.num_oscillators = num_oscillators
        self.coupling_strength = coupling_strength
        # One learnable natural frequency per oscillator, drawn from a standard
        # normal. (Author's note: learned coupling such as attention weights
        # could later replace the uniform all-to-all interaction.)
        self.natural_frequencies = torch.nn.Parameter(torch.randn(num_oscillators))

    def forward(self, t, phase):
        """Return dθ/dt for ``phase`` of shape ``(..., num_oscillators)``.

        ``t`` is unused (the system is autonomous) but is required by the ODE
        solver's call signature. Leading dimensions of ``phase`` are treated
        as independent batch entries.
        """
        # With uniform all-to-all coupling the pairwise sum collapses exactly:
        #   Σ_j sin(θ_j − θ_i) = cos θ_i · Σ_j sin θ_j − sin θ_i · Σ_j cos θ_j
        # This is O(N) in time and memory instead of building the N×N matrix of
        # phase differences, which autograd would keep alive for every solver
        # evaluation when backpropagating through odeint. A non-uniform
        # coupling matrix would require the pairwise form again.
        sin_phase = torch.sin(phase)
        cos_phase = torch.cos(phase)
        interaction_terms = (
            cos_phase * sin_phase.sum(dim=-1, keepdim=True)
            - sin_phase * cos_phase.sum(dim=-1, keepdim=True)
        )
        return self.natural_frequencies + self.coupling_strength / self.num_oscillators * interaction_terms
def _figure_to_array(fig):
    """Draw ``fig`` and return its rendered pixels as an ``(H, W, 3)`` uint8 array.

    Uses ``buffer_rgba`` rather than ``tostring_rgb``. The latter was
    deprecated in Matplotlib 3.8 and removed in 3.10. Reshaping it with the
    logical ``get_width_height()`` size also breaks on HiDPI canvases, whereas
    the RGBA buffer already carries its true pixel shape.
    """
    fig.canvas.draw()  # draw the canvas, cache the renderer
    rgba = np.asarray(fig.canvas.buffer_rgba())
    return rgba[..., :3].copy()  # drop alpha; copy so the frame owns its memory


# Keep every tensor on the same device as the target image (GPU or CPU).
device = target_image.device
# Create the model
model = KuramotoLayer(num_oscillators, coupling_strength=1).to(device)
# Initial conditions and time span
initial_phase = torch.rand(num_oscillators).to(device)  # Initial phase, uniform in [0, 1)
# Time grid lives on the state's device so the solver does not have to move it.
t = torch.linspace(0, 10, 100, device=device)  # Time span
# Define an optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1.0)
# Frames of the training animation, one RGB array per epoch.
frames = []
for epoch in tqdm.tqdm(range(1000)):
    optimizer.zero_grad()
    # Integrate the phases over t; only the state at the final time point is
    # used as the predicted image.
    phase = odeint(model, initial_phase, t)
    output_image = phase[-1].view(target_image.shape)  # Reshape phase to match target image
    # Compute the loss
    loss = torch.nn.functional.mse_loss(output_image, target_image)
    # Backpropagate through the solver and update the natural frequencies.
    loss.backward()
    optimizer.step()
    # Generate and store frame
    fig, ax = plt.subplots()
    ax.imshow(output_image.cpu().detach().numpy(), cmap='gray')
    ax.set_title(f'Epoch {epoch}, Loss {loss.item()}, num_occ={num_oscillators}')
    frames.append(_figure_to_array(fig))
    plt.close(fig)
# Create gif from frames.
# NOTE(review): recent imageio releases write GIFs through the Pillow plugin,
# which interprets `duration` in milliseconds (i.e. 0.5 ms per frame); older
# releases used seconds. Confirm against the installed version (use 500 there).
imageio.mimsave('training_process.gif', frames, duration=0.5)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment