@caseykneale
Last active November 23, 2020 16:08
# uses Flux.v11
using Plots, Flux
function update!(opt, x, x̄)
x[:] .-= apply!(opt, x, x̄)[:]
end
function update!(opt, xs::Flux.Params, gs)
for x in xs
(gs[x] === nothing) && continue
update!(opt, x, gs[x])
end
end
mutable struct βLASSO
η::Float32
λ::Float32
β::Float32
end
βLASSO(η = Float32(0.01), λ = Float32(0.009), β = Float32(50.0)) = βLASSO(η, λ, β)
function apply!(o::βLASSO, x, Δ)
Δ = o.η .* ( Δ .+ ( o.λ .* sign.( Δ ) ) )
Δ = Δ .* Float32.( abs.( x ) .> ( o.β * o.λ ) ) # threshold against this optimiser's own fields, not a global `opt`
return Δ
end
function loss(x, y)
sum(abs2, ( model( x ) .- y ) ) / length(y)
end
function train_me!(loss, ps, data, opt)
ps = Flux.Params(ps)
gs = Flux.gradient(ps) do
loss(data...)
end
update!(opt, ps, gs)
end
#faux data
X = rand(100, 10) .- 0.5
#make a property value with some normally distributed noise
y = rand(100) .+ randn(100)/100
#Make 5th feature proportional to the property value
X[:,5] = 0.5 * y
X = convert.(Float32, X)
y = convert.(Float32, y)
model = Flux.Dense( 10, 1, identity )
#Classic LASSO via proj SGD
opt = βLASSO( Float32(0.03), Float32(0.005), Float32(1.0) )
#Increase the β parameter to 2.0 if you want to
#opt = βLASSO( Float32(0.03), Float32(0.005), Float32(2.0) )
losses = []
plot()#cue up plots
anim = @animate for i ∈ 1:1500
global model
train_me!(loss, Flux.params( model ), ( X', y'), opt)
if (i % 50) == 0
push!(losses, loss(X', y'))
l = @layout [ a b ]
p1 = bar(model.W', legend = false, title = "βLASSO weights")
p2 = plot(losses, legend = false, title = "Loss")
display( plot(p1, p2, layout = l) )
end
end
gif(anim, "plots/BLASSO wts.gif", fps = 60)
@luboshanus

Hi, this is a nice example. I see an error with '.data', probably during updating. If I remove '.data' everywhere, your code runs, but the algorithm does not converge to your solution.

ERROR: LoadError: type Array has no field data
Stacktrace:
 [1] getproperty(::Array{Float32,2}, ::Symbol) at ./Base.jl:33
 [2] update!(::βLASSO, ::Array{Float32,2}, ::Array{Float32,2}) at ./REPL[336]:2
 [3] update!(::βLASSO, ::Zygote.Params, ::Zygote.Grads) at ./REPL[337]:4
 [4] train_me!(::Function, ::Zygote.Params, ::Tuple{LinearAlgebra.Adjoint{Float32,Array{Float32,2}},LinearAlgebra.Adjoint{Float32,Array{Float32,1}}}, ::βLASSO) at /Volumes/LUBOS64/phd/DisAltr/examples/blasso_gif.jl:37
 [5] macro expansion at /Volumes/LUBOS64/phd/DisAltr/examples/blasso_gif.jl:56 [inlined]
 [6] top-level scope at /Users/ies602/.julia/packages/Plots/vsE7b/src/animation.jl:183
 [7] include_string(::Function, ::Module, ::String, ::String) at ./loading.jl:1091
 [8] invokelatest(::Any, ::Any, ::Vararg{Any,N} where N; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}}) at ./essentials.jl:710
 [9] invokelatest(::Any, ::Any, ::Vararg{Any,N} where N) at ./essentials.jl:709
 [10] inlineeval(::Module, ::String, ::Int64, ::Int64, ::String; softscope::Bool) at /Users/ies602/.vscode/extensions/julialang.language-julia-1.0.10/scripts/packages/VSCodeServer/src/eval.jl:185
 [11] (::VSCodeServer.var"#61#65"{String,Int64,Int64,String,Module,Bool,VSCodeServer.ReplRunCodeRequestParams})() at /Users/ies602/.vscode/extensions/julialang.language-julia-1.0.10/scripts/packages/VSCodeServer/src/eval.jl:144
 [12] withpath(::VSCodeServer.var"#61#65"{String,Int64,Int64,String,Module,Bool,VSCodeServer.ReplRunCodeRequestParams}, ::String) at /Users/ies602/.vscode/extensions/julialang.language-julia-1.0.10/scripts/packages/VSCodeServer/src/repl.jl:124
 [13] (::VSCodeServer.var"#60#64"{String,Int64,Int64,String,Module,Bool,Bool,VSCodeServer.ReplRunCodeRequestParams})() at /Users/ies602/.vscode/extensions/julialang.language-julia-1.0.10/scripts/packages/VSCodeServer/src/eval.jl:142
 [14] hideprompt(::VSCodeServer.var"#60#64"{String,Int64,Int64,String,Module,Bool,Bool,VSCodeServer.ReplRunCodeRequestParams}) at /Users/ies602/.vscode/extensions/julialang.language-julia-1.0.10/scripts/packages/VSCodeServer/src/repl.jl:36
 [15] (::VSCodeServer.var"#59#63"{String,Int64,Int64,String,Module,Bool,Bool,VSCodeServer.ReplRunCodeRequestParams})() at /Users/ies602/.vscode/extensions/julialang.language-julia-1.0.10/scripts/packages/VSCodeServer/src/eval.jl:110
 [16] with_logstate(::Function, ::Any) at ./logging.jl:408
 [17] with_logger at ./logging.jl:514 [inlined]
 [18] (::VSCodeServer.var"#58#62"{VSCodeServer.ReplRunCodeRequestParams})() at /Users/ies602/.vscode/extensions/julialang.language-julia-1.0.10/scripts/packages/VSCodeServer/src/eval.jl:109
 [19] #invokelatest#1 at ./essentials.jl:710 [inlined]
 [20] invokelatest(::Any) at ./essentials.jl:709
 [21] macro expansion at /Users/ies602/.vscode/extensions/julialang.language-julia-1.0.10/scripts/packages/VSCodeServer/src/eval.jl:27 [inlined]
 [22] (::VSCodeServer.var"#56#57")() at ./task.jl:356
in expression starting at /Volumes/LUBOS64/phd/DisAltr/examples/blasso_gif.jl:54

Example here with removed '.data'. It changes with initialisation but it should converge anyway.
[Screenshot, 2020-11-23: βLASSO weights and loss plots after removing '.data']

@caseykneale
Author

Probably an old version of Flux? Not sure...

@luboshanus

Can you tell me which version you have? I tried the newest and 0.10.4. Neither worked.

@caseykneale
Author

What I meant to say was that I was probably using an old version of Flux. I can't state exactly which version was used to write that code because it was done on a previous computer, in an environment that has since been lost :(. I know that for a while the latest Flux was too buggy for me to get anything done, so I was working off an old version.

But you are able to get the code to run without the .data statements in the latest Flux? If so, I'll probably change the gist to support the latest version. Otherwise, I am knee-deep in startup transients right now and it'll have to wait a few days.

@luboshanus

luboshanus commented Nov 23, 2020

Ok. Thanks for the message.

Yes, without .data it works in the latest Flux. The figure I posted shows the results using the latest Flux. However, I think there may be an error in how x is updated in the update!() function, but I did not dig deeper and wrote you first :)). Or maybe not... it just did not converge so that the 5th parameter equals 2.0.

@caseykneale
Author

Try the following: opt = βLASSO( Float32(0.03), Float32(0.005), Float32(1.0) )
This should be equivalent to classic LASSO, barring unluckiness. The authors felt parameterizing the third (β) parameter was useful. I'm not especially convinced, but for large datasets it may be.
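For intuition, the update rule in apply! can be traced by hand on a toy vector (a standalone sketch, not part of the gist; the weight and gradient values below are made up). With β = 1 the hard threshold sits at β·λ = λ, so the step is zeroed for any weight whose magnitude is already below λ, which is what makes this setting behave like classic LASSO via projected SGD:

```julia
# Standalone sketch of the βLASSO step from apply! (toy values, Float32 throughout).
η, λ, β = 0.03f0, 0.005f0, 1.0f0

x = Float32[0.002, 0.1, -0.5]   # weights; the first is below the β*λ = 0.005 threshold
Δ = Float32[0.01, 0.02, -0.03]  # raw gradients

step = η .* (Δ .+ λ .* sign.(Δ))           # gradient step with the L1 (sign) penalty added
step = step .* Float32.(abs.(x) .> β * λ)  # hard gate: zero the step for near-threshold weights

# step[1] is exactly 0, so x[1] stays pinned near zero; the larger weights still move.
```

Raising β widens the dead zone around zero, so more weights get frozen out and the solution gets sparser.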

@luboshanus

Yes, it works! Thanks!

Code without .data.

# edit Flux.v11
using Plots, Flux

function update!(opt, x, x̄)
  x[:] .-= apply!(opt, x, x̄)[:]
end

function update!(opt, xs::Flux.Params, gs)
  for x in xs
    (gs[x] === nothing) && continue
    update!(opt, x, gs[x])
  end
end

mutable struct βLASSO
  η::Float32
  λ::Float32
  β::Float32
end

βLASSO(η = Float32(0.01), λ = Float32(0.009), β = Float32(50.0)) = βLASSO(η, λ, β)

function apply!(o::βLASSO, x, Δ)
  Δ = o.η .* ( Δ .+ ( o.λ .* sign.( Δ ) ) )
  Δ = Δ .* Float32.( abs.( x ) .> ( o.β * o.λ ) ) # threshold against this optimiser's own fields, not a global `opt`
  return Δ
end

function loss(x, y)
  sum(abs2, ( model( x ) .- y ) ) / length(y)
end

function train_me!(loss, ps, data, opt)
  ps = Flux.Params(ps)
  gs = Flux.gradient(ps) do
    loss(data...)
  end
  update!(opt, ps, gs)
end

#faux data
X = rand(100, 10) .- 0.5
#make a property value with some normally distributed noise
y = rand(100) .+ randn(100)/100
#Make 5th feature proportional to the property value
X[:,5]  = 0.5 * y
X = convert.(Float32, X)
y = convert.(Float32, y)

model = Flux.Dense( 10, 1, identity )

opt = βLASSO( Float32(0.03), Float32(0.005), Float32(1.0) ) # this setup works
losses = []
plot()#cue up plots
anim = @animate for i ∈ 1:1500
  global model
  train_me!(loss, Flux.params( model ), ( X', y'), opt)
  if (i % 50) == 0
    push!(losses, loss(X', y'))
    l = @layout [ a b ]
    p1 = bar(model.W', legend = false, title = "βLASSO weights")
    p2 = plot(losses, legend = false, title = "Loss")
    display( plot(p1, p2, layout = l) )
  end
end
# gif(anim, "plots/BLASSO wts.gif", fps = 60)

@caseykneale
Author

caseykneale commented Nov 23, 2020

Awesome! I'll update the gist with your code.

If you read the post I linked (10 min paper review) I kind of discuss that beta parameter.
