Skip to content

Instantly share code, notes, and snippets.

@IvanYashchuk
Last active April 29, 2021 21:29
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save IvanYashchuk/1b5643126037bb41064a00a969b805ac to your computer and use it in GitHub Desktop.
Save IvanYashchuk/1b5643126037bb41064a00a969b805ac to your computer and use it in GitHub Desktop.
FEniCS solver + Zygote.jl + Turing.jl
using PyFenicsAD
using Zygote
using PyCall
using Turing
import LinearAlgebra: norm
using Random, Distributions
# Load the Python packages through PyCall; they are used below via attribute access.
fenics = pyimport("fenics")
# Restrict FEniCS log output to the ERROR level.
fenics.set_log_level(fenics.LogLevel.ERROR)
fa = pyimport("fenics_adjoint")
ufl = pyimport("ufl")
# n is passed as both arguments of UnitSquareMesh
# (presumably the number of cells per side — confirm against the FEniCS docs).
n = 25
mesh = fa.UnitSquareMesh(n ,n)
# Function space on the mesh; "P" and 1 are the element family and degree arguments.
# V is a global read by solve_fenics.
V = fenics.FunctionSpace(mesh, "P", 1)
# Define FEniCS function using Python's fenics, fenics_adjoint and ufl
# Solve the FEniCS problem on the global function space `V` for the given
# coefficients and return the computed FEniCS function.
#
# `kappa0` multiplies the gradient term and `kappa1` multiplies the source term
# of the functional assembled below. The state is fixed to 0.0 on the boundary.
function solve_fenics(kappa0, kappa1)
    # Source term given as a FEniCS expression string.
    source = fa.Expression(
        "10*exp(-(pow(x[0] - 0.5, 2) + pow(x[1] - 0.5, 2)) / 0.02)", degree=2
    )
    state = fa.Function(V)
    boundary_conditions = [fa.DirichletBC(V, fa.Constant(0.0), "on_boundary")]

    # Functional whose derivative with respect to the state defines the equation.
    gradient_term = 0.5 * ufl.inner(kappa0 * ufl.grad(state), ufl.grad(state)) * ufl.dx
    energy = gradient_term - kappa1 * source * state * ufl.dx

    # Derivative of the functional w.r.t. the state in the direction of a test
    # function; the solve drives this form to zero.
    test_function = fenics.TestFunction(V)
    residual = fenics.derivative(energy, state, test_function)
    fa.solve(residual == 0, state, bcs=boundary_conditions)
    return state
end
# This is boilerplate code for registering Python's FEniCS function in Zygote
# Only solve_fenics and templates need to be modified from code to code
# zygote_solve_fenics is a differentiable wrapper function that calls solve_fenics
# One FEniCS Constant per argument of solve_fenics (kappa0, kappa1); passed to
# fem_eval together with the raw inputs (presumably used as templates for
# converting the inputs to FEniCS objects — confirm against PyFenicsAD).
templates = (fa.Constant(0.0), fa.Constant(0.0))

# Evaluate solve_fenics through fem_eval and return the first entry of its result.
function zygote_solve_fenics(inputs...)
    outputs = fem_eval(solve_fenics, templates, inputs...)
    return outputs[1]
end
# Custom Zygote adjoint for zygote_solve_fenics: runs fem_eval once, returns the
# primal result together with a pullback that delegates to vjp_fem_eval.
Zygote.@adjoint function zygote_solve_fenics(inputs...)
    # Keep the whole result as an unconverted PyObject so that its entries can be
    # handed back to Python unchanged.
    py_result = pycall(fem_eval, PyObject, solve_fenics, templates, inputs...)
    # Entries 0..3 of the result, kept as PyObjects. Entry 0 is not needed in raw
    # form; entries 1..3 are presumably the FEniCS output, the FEniCS inputs and
    # the tape — confirm against PyFenicsAD's fem_eval.
    _, py_output, py_inputs, py_tape = map(k -> get(py_result, PyObject, k), 0:3)
    # Pullback: maps the output cotangent g to whatever vjp_fem_eval returns.
    pullback = g -> vjp_fem_eval(g, py_output, py_inputs, py_tape)
    # The primal value is entry 0, fetched with PyCall's default conversion.
    return get(py_result, 0), pullback
end
# Reference solution computed with known coefficients; used to build the data.
true_kappa0 = [1.25]
true_kappa1 = [0.55]
true_solution = zygote_solve_fenics(true_kappa0, true_kappa1)
# perturb state solution and create synthetic measurements
noise_level = 0.05
# Scale the noise by the largest absolute entry of the solution (max norm) so that
# noise_level is relative to the signal amplitude. The default `norm(x)` is the
# 2-norm, which grows with the number of entries and would inflate the noise far
# beyond the intended relative level.
MAX = norm(true_solution, Inf)
noise = rand(Normal(0, noise_level * MAX), size(true_solution))
noisy_solution = true_solution + noise
# fenics_noisy_solution = numpy_to_fenics(noisy_solution, fenics.Function(V))
# Select Zygote as Turing's AD backend and switch the progress setting on.
Turing.setadbackend(:zygote)
Turing.turnprogress(true)
# Turing model relating the two coefficients to observed solution values `data`.
@model function fit_diffusion(data)
    # Priors: observation noise scale and the two coefficients, the latter
    # truncated to the interval [1e-5, 2].
    σ ~ InverseGamma(3, 0.5)
    kappa0 ~ truncated(Normal(1.0, 0.5), 1e-5, 2)
    kappa1 ~ truncated(Normal(0.7, 0.5), 1e-5, 2)

    # Forward solve for the sampled coefficients (wrapped in 1-element vectors).
    forward_solution = zygote_solve_fenics([kappa0], [kappa1])

    # Likelihood: data distributed around the forward solution with parameter σ.
    data ~ MvNormal(forward_solution, σ)
end
# Instantiate the model on the noisy synthetic measurements.
model = fit_diffusion(noisy_solution)
# Run NUTS with 1000 as the sample-count argument; .65 is presumably the target
# acceptance rate — confirm against the Turing documentation.
chain = sample(model, NUTS(.65), 1000)
@IvanYashchuk
Copy link
Author

@ChrisRackauckas
Copy link

Awesome thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment