@jiweiqi
Created January 30, 2021 01:58
Fit a Flux neural network with LsqFit.jl (Levenberg–Marquardt)
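The script treats the network as a nonlinear least-squares model in its flattened parameter vector p, which is what `Flux.destructure` returns. LsqFit minimizes the sum of squares of the vector that `f` returns. The textbook form of that objective and of one Levenberg–Marquardt step is sketched below. Here \hat{y}(x; p) denotes the network output. The damping matrix D (I or diag(JᵀJ)) and the λ update rule are the standard textbook choices, and LsqFit's exact damping, scaling, and acceptance rules may differ.

```latex
r_i(p) = \hat{y}(x_i; p) - y_i, \qquad i = 1, \dots, 11
S(p) = \sum_{i=1}^{11} r_i(p)^2, \qquad J_{ij} = \frac{\partial r_i}{\partial p_j}
\left(J^\top J + \lambda D\right)\delta = -J^\top r, \qquad p \leftarrow p + \delta
```

After a step that lowers S(p), λ is decreased, which makes the step closer to Gauss–Newton. After a step that does not lower S(p), λ is increased, which makes the step closer to a short gradient-descent step.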
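using LsqFit
using ForwardDiff
using Plots
using Flux

# Training data: 11 samples of sin(x) on [0, 10].
# Flux layers expect (features × samples), so x is stored as a 1×11 Float32 matrix.
xdata = reshape(collect(Float32, 0:10), 1, :)
ydata = vec(sin.(xdata))

# Small MLP: 1 input -> two tanh hidden layers of width 5 -> 1 output.
nn = Chain(
    Dense(1, 5, tanh),
    Dense(5, 5, tanh),
    Dense(5, 1))

# Flatten all weights and biases into one vector x; re(p) rebuilds the network from a vector p.
x, re = Flux.destructure(nn)

# Residual vector for LsqFit, which minimizes sum(f(p).^2).
# Signed residuals are used: abs(...) gives the same sum of squares
# but is not differentiable wherever a residual is zero.
function f(p)
    ypred = vec(re(p)(xdata))
    return ypred .- ydata
end
f(x)

# Jacobian of the residuals with respect to the parameters, via forward-mode AD.
g(p) = ForwardDiff.jacobian(f, p)
g(x)

# Levenberg–Marquardt fit starting from the initial parameters x; Float64[] means no weights.
fit = LsqFit.lmfit(f, g, x, Float64[]; show_trace=true, maxIter=1000, x_tol=1e-6)
display(fit.param)

# Plot the data against the fitted network and save the figure as lmnn.png.
plt = scatter(vec(xdata), ydata, label="data");
plot!(plt, vec(xdata), vec(re(fit.param)(xdata)), label="pred");
png(plt, "lmnn")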
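To show roughly what the `lmfit` call does with a residual function and its Jacobian, here is a minimal hand-rolled Levenberg–Marquardt loop in plain Julia that uses only the LinearAlgebra standard library. This is a sketch, not LsqFit's implementation. The names `levenberg_marquardt`, `fd_jacobian`, and `residual`, along with the toy exponential model, are hypothetical. A central-difference Jacobian stands in for ForwardDiff, and simple λI damping is used.

```julia
using LinearAlgebra

# Central-difference Jacobian J[i, j] ≈ ∂r_i/∂p_j (stands in for ForwardDiff here).
function fd_jacobian(r, p; h=1e-6)
    m, n = length(r(p)), length(p)
    J = zeros(m, n)
    for j in 1:n
        e = zeros(n)
        e[j] = h
        J[:, j] = (r(p .+ e) .- r(p .- e)) ./ (2h)
    end
    return J
end

# Minimal Levenberg–Marquardt (hypothetical helper): solve (JᵀJ + λI) δ = -Jᵀr,
# accept the step only if the sum of squared residuals drops, and adapt λ.
function levenberg_marquardt(r, p0; λ=1e-3, maxiter=200, tol=1e-10)
    p = float.(p0)
    cost = sum(abs2, r(p))
    for _ in 1:maxiter
        res = r(p)
        J = fd_jacobian(r, p)
        grad = J' * res                      # gradient of cost/2
        norm(grad, Inf) < tol && break       # stationary point reached
        δ = -((J' * J + λ * I) \ grad)
        p_new = p .+ δ
        cost_new = sum(abs2, r(p_new))
        if cost_new < cost
            p, cost = p_new, cost_new        # accept: move toward Gauss–Newton
            λ /= 10
            norm(δ) < tol * (norm(p) + tol) && break
        else
            λ *= 10                          # reject: move toward gradient descent
        end
    end
    return p, cost
end

# Toy problem: recover a = 2, b = -0.5 in y = a * exp(b * x) from noise-free data.
x = collect(0.0:0.5:5.0)
y = 2.0 .* exp.(-0.5 .* x)
residual(p) = p[1] .* exp.(p[2] .* x) .- y

p_hat, final_cost = levenberg_marquardt(residual, [1.5, -0.3])
println(p_hat)
```

The same structure carries over to the script above. The residual plays the role of `f`, and the Jacobian plays the role of `g`. LsqFit's solver uses its own damping and stopping rules, so iteration counts and results will not match this sketch exactly.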