Last active
September 16, 2022 00:25
-
-
Save slwu89/50f81e6bbc296261e29851d3585ae554 to your computer and use it in GitHub Desktop.
pcls (mgcv's penalized constrained least squares) in Julia with JuMP
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using JuMP
using HiGHS
using LinearAlgebra
using RCall
using Plots

# Reproduce the monotone-spline example from the mgcv `pcls` help page:
# https://stat.ethz.ch/R-manual/R-devel/library/mgcv/html/pcls.html
# R simulates the data, fits an unconstrained GAM (whose smoothing parameter is
# reused), builds the monotonicity constraints with `mono.con`, and solves the
# constrained problem with `pcls`. The Julia code below re-solves the same
# problem with JuMP/HiGHS so the two solutions can be compared.
R"""
library(mgcv)
set.seed(1)  # fixed seed so the simulated data (and the comparison) are reproducible
x <- runif(100)*4-1;x <- sort(x);
f <- exp(4*x)/(1+exp(4*x)); y <- f+rnorm(100)*0.1;
dat <- data.frame(x=x,y=y)
# Show regular spline fit (and save fitted object)
f_ug <- gam(y~s(x,k=10,bs="cr"));
y_ug <- fitted(f_ug)
# Create Design matrix, constraints etc. for monotonic spline....
sm <- smoothCon(s(x,k=10,bs="cr"),dat,knots=NULL)[[1]]
F <- mono.con(sm$xp); # get constraints
G <- list(X=sm$X,C=matrix(0,0,0),sp=f_ug$sp,p=sm$xp,y=y,w=y*0+1)
G$Ain <- F$A;G$bin <- F$b;G$S <- sm$S;G$off <- 0
pR <- pcls(G); # fit spline (using s.p. from unconstrained fit)
"""
# Copy the R objects into Julia. `G` (the pcls problem list) and `sm` (the
# smooth) are indexed by Symbol keys below; `pR` is R's pcls solution, kept
# for comparison; `x` and `y_ug` are reused for plotting.
@rget G
@rget sm
@rget F
@rget pR
@rget x
@rget y_ug

# Unpack the pcls problem (field meanings follow the mgcv `pcls` help page).
W = Diagonal(G[:w])   # observation weights (all 1 here); use diagm(G[:w]) for a dense Matrix
X = G[:X]             # design matrix of the smooth
p0 = G[:p]            # starting values (sm$xp) that R's pcls requires; the JuMP model does not use them
y = G[:y]             # response
lambda = G[:sp]       # smoothing parameter borrowed from the unconstrained gam fit
S = G[:S][1]          # G$S is a list holding the single penalty matrix
A = G[:Ain]           # inequality constraints A * p >= b (from mono.con)
b = G[:bin]
C = sm[:C]            # the smooth's constraint matrix; not imposed, matching G$C = matrix(0,0,0) in R
# Re-solve the penalized, constrained least-squares problem with JuMP:
#     minimise (X*p - y)' W (X*p - y) + lambda * p' S p   subject to   A*p >= b
# The objective is quadratic in p, so HiGHS solves it as a QP.
pcls_model = Model(HiGHS.Optimizer)
# One coefficient per basis function (column of X), rather than a hard-coded 10.
@variable(pcls_model, p[1:size(X, 2)])
@constraint(pcls_model, A * p .>= b)  # linear inequality (monotonicity) constraints
# No equality constraints: the R example passes G$C = matrix(0,0,0). To impose
# linear equality constraints C*p = 0, add `@constraint(pcls_model, C * p .== 0)`.
resid = X * p - y
@objective(pcls_model, Min, transpose(resid) * W * resid + lambda * transpose(p) * S * p)
optimize!(pcls_model)

# Fail loudly rather than reading values from an unsolved model.
termination_status(pcls_model) == OPTIMAL ||
    error("pcls QP did not solve to optimality: $(termination_status(pcls_model))")

# JuMP coefficients (left column) next to R's pcls solution (right column);
# `display` so the comparison is shown when run as a script, not only at the REPL.
display(hcat(value.(p), pR))
# Evaluate the smooth's basis at the data points so each coefficient vector can
# be turned into a fitted curve.
R"predX <- Predict.matrix(sm,data.frame(x=x))"
@rget predX

# Base-R plot: data, unconstrained fit (black), and R's pcls fit (red).
R"""
plot(x,y)
lines(x,y_ug)
fv <- predX %*% pR
lines(x,fv,col='red')
"""

# The same plot in Julia; the red curve comes from the JuMP solution instead.
# Build it in a named figure and `display` it so it also renders outside the REPL.
plt = plot(x, y, seriestype = :scatter, legend = false)
plot!(plt, x, y_ug, linecolor = :black)
plot!(plt, x, predX * value.(p), linecolor = :red)
display(plt)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.