Skip to content

Instantly share code, notes, and snippets.

@mdnestor
Created December 14, 2022 08:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mdnestor/2c508c26bfcd7ceaddc685f8f4d2d7f5 to your computer and use it in GitHub Desktop.
Save mdnestor/2c508c26bfcd7ceaddc685f8f4d2d7f5 to your computer and use it in GitHub Desktop.
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
class IDEResult:
    """Plain result container returned by ``solve_ide_pytorch``.

    The three constructor arguments are stored unchanged as the attributes
    ``t``, ``x`` and ``u``; no copying or validation is performed.
    """

    def __init__(self, t, x, u):
        # solve_ide_pytorch passes (t_eval, spatial grid, stacked states).
        self.t, self.x, self.u = t, x, u
def solve_ide_pytorch(f, k, u0, xmin, xmax, step_size, disp_radius, n_iterations,
                      t_eval=None, normalize_kernel=True):
    """Iterate a 1-D integrodifference equation (IDE) on a uniform grid.

    Each iteration computes

        u_{t+1}(x) = integral of k(x - y) * f(u_t(y)) dy

    as a discrete convolution over the grid ``x``. The state is treated as
    zero outside the grid (``padding="same"`` zero-pads), so anything
    dispersed past the boundary is lost.

    Args:
        f: Growth map applied to the current state. It receives a float64
            ``torch.Tensor`` of shape ``(1, len(x))`` and must return a
            tensor of the same shape and dtype.
        k: Dispersal kernel, called once with a NumPy array of offsets
            ``step_size * [-m, ..., m]`` (odd length, centred on 0) where
            ``m * step_size <= disp_radius``.
        u0: Initial condition, called once with the NumPy grid ``x``; a
            scalar result is broadcast over the grid.
        xmin, xmax: Domain endpoints; ``x`` starts at ``xmin`` and advances
            in steps of ``step_size`` without passing ``xmax``.
        step_size: Grid spacing; must be positive.
        disp_radius: Truncation radius of the kernel.
        n_iterations: Number of times the map is applied.
        t_eval: 0-based loop indices whose results are stored. Defaults to
            every index in ``range(n_iterations)``; indices outside that
            range are never stored.
        normalize_kernel: If True, rescale the discrete kernel to sum to 1.
            Otherwise use ``k(offsets) * step_size`` as quadrature weights so
            the sum approximates the integral.

    Returns:
        IDEResult with ``t = t_eval`` (as given or defaulted), ``x`` the grid,
        and ``u`` a NumPy array of shape ``(1 + n_stored, len(x))``.
        ``u[0]`` is the initial condition; each following row is the state
        produced at a stored loop index, in increasing index order (loop
        index ``t`` yields the state after ``t + 1`` applications of the map).

    Raises:
        ValueError: If ``step_size`` is not positive, or if
            ``normalize_kernel`` is True and the kernel sums to zero.
    """
    if step_size <= 0:
        raise ValueError("step_size must be positive")
    if t_eval is None:
        t_eval = np.arange(n_iterations)
    # Membership set built once instead of scanning t_eval every iteration.
    stored = set(np.asarray(t_eval).ravel().tolist())

    # Build both grids from integer counts: np.arange with a float step has a
    # rounding-dependent length, which could add an extra point past xmax or
    # give the kernel an even length / off-centre support (and "same" padding
    # would then shift the solution). The tolerance absorbs division rounding.
    n_cells = int(np.floor((xmax - xmin) / step_size + 1e-9))
    x = xmin + step_size * np.arange(n_cells + 1)
    half_width = int(np.floor(disp_radius / step_size + 1e-9))
    offsets = step_size * np.arange(-half_width, half_width + 1)

    kx = np.asarray(k(offsets), dtype=np.float64)
    if normalize_kernel:
        total = kx.sum()
        if total == 0:
            raise ValueError("kernel sums to zero and cannot be normalized")
        # Not in place: np.asarray may hand back the very array k returned.
        kx = kx / total
    else:
        # Riemann-sum weight dy, so sum(kx) approximates the kernel integral.
        kx = kx * step_size

    # F.conv1d computes a cross-correlation, out[i] = sum_j w[j] * g[i + j - m].
    # Feeding it the reversed kernel gives the convolution
    # sum_j k(x_i - y_j) * g(y_j) the IDE requires (no-op for symmetric k).
    # Weight shape is (out_channels, in_channels, width) = (1, 1, 2m + 1).
    weight = torch.tensor(kx[::-1].copy()).reshape(1, 1, -1)

    # (1, len(x)): a single-channel, unbatched signal for conv1d. Cast to
    # float64 so the dtype always matches the kernel weights.
    state = torch.tensor(
        np.array(np.broadcast_to(u0(x), x.shape), dtype=np.float64)
    ).unsqueeze(0)

    frames = [state]
    for t in range(n_iterations):
        # Always advance from the latest computed state, stored or not.
        state = F.conv1d(f(state), weight, padding="same")
        if t in stored:
            frames.append(state)

    u = np.stack([frame.numpy() for frame in frames]).squeeze(1)
    return IDEResult(t_eval, x, u)
# Example usage (executes at import time): logistic growth map
# f(u) = 2u(1 - u) with a standard-normal dispersal kernel, starting from
# height 0.5 on |x| <= 1 and zero elsewhere, over 100 iterations on [-10, 10].
from scipy import stats

result = solve_ide_pytorch(
    xmin=-10,
    xmax=10,
    step_size=0.1,
    disp_radius=5,
    n_iterations=100,
    f=lambda u: 2.0 * u * (1.0 - u),
    k=stats.norm.pdf,
    u0=lambda x: 0.5 * np.heaviside(1.0 - np.abs(x), 1),
)
x = result.x
u = result.u

# One curve per stored state (each row of u) on the current axes.
# NOTE(review): nothing here calls plt.show() or savefig, so a
# non-interactive run builds the figure without displaying it; consider an
# `if __name__ == "__main__":` guard with plt.show().
for ut in u:
    plt.plot(x, ut)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment