-
-
Save mdnestor/2c508c26bfcd7ceaddc685f8f4d2d7f5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import dataclass

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
@dataclass(eq=False)
class IDEResult:
    """Container for the output of an integrodifference-equation solve.

    Values are stored exactly as passed; nothing is copied or validated.
    ``eq=False`` keeps identity-based equality and hashing, so comparing two
    results never triggers element-wise array comparison.
    """

    t: np.ndarray  # time labels supplied by the solver for the saved states
    x: np.ndarray  # spatial grid on which the states in ``u`` are sampled
    u: np.ndarray  # saved solution states (array-like, as supplied by the solver)
def _count_steps(span, step):
    """Return floor(span / step), tolerating floating-point error in the ratio."""
    ratio = span / step
    # Relative slack so e.g. 19.999999999999996 counts as 20 whole steps.
    return int(np.floor(ratio + 1e-9 * max(1.0, abs(ratio))))


def _uniform_grid(start, stop, step):
    """Return ``start, start + step, ...`` for every point not beyond ``stop``.

    Unlike ``np.arange(start, stop + step, step)``, the point count does not
    depend on floating-point rounding of the end point.
    """
    n_points = _count_steps(stop - start, step) + 1
    return start + step * np.arange(max(n_points, 0))


def solve_ide_pytorch(f, k, u0, xmin, xmax, step_size, disp_radius, n_iterations,
                      t_eval=None, normalize_kernel=True):
    """Solve a 1-D integrodifference equation by repeated discrete convolution.

    Each iteration applies

        u_{t+1}(x) = integral of k(x - y) * f(u_t(y)) dy

    on a uniform grid. The integral is approximated by a rectangle-rule sum
    over the dispersal window ``[-disp_radius, disp_radius]``. Values outside
    ``[xmin, xmax]`` are treated as zero (``padding="same"`` zero-pads), so
    anything dispersed past the domain edge is lost.

    Args:
        f: growth map; called on the current state as a float64 torch tensor
            of shape ``(1, len(x))`` and must return a tensor of that shape.
        k: dispersal kernel; called on a 1-D numpy array of displacements.
        u0: initial condition; called on the numpy spatial grid ``x``.
        xmin, xmax: domain endpoints. ``xmax`` is included when
            ``xmax - xmin`` is a whole number of steps.
        step_size: grid spacing; must be positive.
        disp_radius: half-width of the kernel window, truncated to a whole
            number of steps so the kernel grid is symmetric about 0.
        n_iterations: number of iterations to run.
        t_eval: 0-based iteration indices after which the state is recorded.
            Defaults to every iteration.
        normalize_kernel: if True, rescale the discretized kernel weights to
            sum to 1. Otherwise use ``k(x) * step_size`` as quadrature weights.

    Returns:
        IDEResult. ``t`` is the time step of each row of ``u`` (0 for the
        initial condition, ``i + 1`` for the state after iteration ``i``),
        ``x`` is the spatial grid, and ``u`` has shape ``(len(t), len(x))``.

    Raises:
        ValueError: if ``step_size`` is not positive, or if
            ``normalize_kernel`` is set and the kernel sums to zero.
    """
    if step_size <= 0:
        raise ValueError("step_size must be positive")
    if t_eval is None:
        t_eval = np.arange(n_iterations)
    # Build the set once so the per-iteration membership test is O(1).
    save_after = set(np.ravel(t_eval).tolist())

    x = _uniform_grid(xmin, xmax, step_size)
    # Odd-length kernel grid centred exactly on 0: padding="same" then does
    # not shift the solution.
    half_width = _count_steps(disp_radius, step_size)
    x_conv = step_size * np.arange(-half_width, half_width + 1)

    kx = np.asarray(k(x_conv), dtype=np.float64)
    if normalize_kernel:
        total = np.sum(kx)
        if total == 0:
            raise ValueError(
                "kernel sums to zero on the dispersal grid; cannot normalize")
        kx = kx / total  # new array: never mutate what k returned
    else:
        kx = kx * step_size  # rectangle-rule quadrature weight dy
    # F.conv1d computes a cross-correlation. Flipping the kernel turns it
    # into the convolution sum_j k(x - y_j) g(y_j) that the equation needs.
    weight = torch.as_tensor(kx[::-1].copy()).reshape(1, 1, -1)

    state = torch.as_tensor(np.asarray(u0(x), dtype=np.float64)).unsqueeze(0)
    frames = [state]
    times = [0]
    for it in range(n_iterations):
        # Always advance from the latest computed state, saved or not.
        state = F.conv1d(f(state), weight, padding="same")
        if it in save_after:
            frames.append(state)
            times.append(it + 1)

    u = np.stack([frame.squeeze(0).numpy() for frame in frames])
    return IDEResult(np.asarray(times), x, u)
# Example usage: logistic map f(u) = 2u(1 - u) with a standard-normal
# dispersal kernel (scipy's default loc=0, scale=1), starting from height 0.5
# on [-1, 1] and zero elsewhere. Guarded so importing this module does not run
# the simulation or open a plot window.
if __name__ == "__main__":
    from scipy import stats

    result = solve_ide_pytorch(
        f=lambda u: 2.0 * u * (1.0 - u),
        k=stats.norm.pdf,
        u0=lambda x: 0.5 * np.heaviside(1.0 - np.abs(x), 1),
        xmin=-10,
        xmax=10,
        step_size=0.1,
        disp_radius=5,
        n_iterations=100,
    )
    x, u = result.x, result.u
    # One curve per saved state.
    for ut in u:
        plt.plot(x, ut)
    plt.xlabel("x")
    plt.ylabel("u")
    plt.show()  # needed to display the figure when run as a script
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment