Last active November 23, 2020 16:08
# uses Flux v0.11
using Plots, Flux

# Apply one optimiser step to a single parameter array, in place.
function update!(opt, x, x̄)
  x .-= apply!(opt, x, x̄)
end

# Apply one optimiser step to every parameter that has a gradient.
function update!(opt, xs::Flux.Params, gs)
  for x in xs
    (gs[x] === nothing) && continue
    update!(opt, x, gs[x])
  end
end

# η: learning rate, λ: L1 penalty, β: threshold multiplier.
# The Float32 field types are assumed from the constructor defaults.
mutable struct βLASSO
  η::Float32
  λ::Float32
  β::Float32
end
# Keyword defaults, so the generated positional constructor is not overwritten.
βLASSO(; η = 0.01f0, λ = 0.009f0, β = 50.0f0) = βLASSO(η, λ, β)

function apply!(o::βLASSO, x, Δ)
  # Add the λ .* sign.(Δ) term to the gradient and scale by the learning rate
  Δ = o.η .* ( Δ .+ ( o.λ .* sign.( Δ ) ) )
  # Zero the update for weights whose magnitude is at most β * λ
  Δ = Δ .* Float32.( abs.( x ) .> ( o.β * o.λ ) )
  return Δ
end

# Mean squared error of the global model
function loss(x, y)
  sum(abs2, ( model( x ) .- y ) ) / length(y)
end

function train_me!(loss, ps, data, opt)
  ps = Flux.Params(ps)
  gs = Flux.gradient(ps) do
    loss(data...)
  end
  update!(opt, ps, gs)
end

# faux data
X = rand(100, 10) .- 0.5
# make a property value with some normally distributed noise
y = rand(100) .+ randn(100) / 100
# make the 5th feature proportional to the property value
X[:,5] = 0.5 * y
X = convert.(Float32, X)
y = convert.(Float32, y)

model = Flux.Dense( 10, 1, identity )

# Classic LASSO via projected SGD
opt = βLASSO( Float32(0.03), Float32(0.005), Float32(1.0) )
# Increase the β parameter to 2.0, if you want to
#opt = βLASSO( Float32(0.03), Float32(0.005), Float32(2.0) )
losses = Float32[]
plot() # cue up plots
anim = @animate for i ∈ 1:1500
  global model
  train_me!(loss, Flux.params( model ), ( X', y'), opt)
  if (i % 50) == 0
    push!(losses, loss(X', y'))
    l = @layout [ a b ]
    p1 = bar(model.W', legend = false, title = "βLASSO weights")
    p2 = plot(losses, legend = false, title = "Loss")
    display( plot(p1, p2, layout = l) )
  end
end
mkpath("plots") # make sure the output directory exists
gif(anim, "plots/BLASSO wts.gif", fps = 60)

Hi, this is a nice example. I get an error involving '.data', probably during the update step. If I remove '.data' everywhere, your code runs, but the algorithm does not converge to your solution.

ERROR: LoadError: type Array has no field data
 [1] getproperty(::Array{Float32,2}, ::Symbol) at ./Base.jl:33
 [2] update!(::βLASSO, ::Array{Float32,2}, ::Array{Float32,2}) at ./REPL[336]:2
 [3] update!(::βLASSO, ::Zygote.Params, ::Zygote.Grads) at ./REPL[337]:4
 [4] train_me!(::Function, ::Zygote.Params, ::Tuple{LinearAlgebra.Adjoint{Float32,Array{Float32,2}},LinearAlgebra.Adjoint{Float32,Array{Float32,1}}}, ::βLASSO) at /Volumes/LUBOS64/phd/DisAltr/examples/blasso_gif.jl:37
 [5] macro expansion at /Volumes/LUBOS64/phd/DisAltr/examples/blasso_gif.jl:56 [inlined]
 [6] top-level scope at /Users/ies602/.julia/packages/Plots/vsE7b/src/animation.jl:183
 [7] include_string(::Function, ::Module, ::String, ::String) at ./loading.jl:1091
 [8] invokelatest(::Any, ::Any, ::Vararg{Any,N} where N; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}}) at ./essentials.jl:710
 [9] invokelatest(::Any, ::Any, ::Vararg{Any,N} where N) at ./essentials.jl:709
 [10] inlineeval(::Module, ::String, ::Int64, ::Int64, ::String; softscope::Bool) at /Users/ies602/.vscode/extensions/julialang.language-julia-1.0.10/scripts/packages/VSCodeServer/src/eval.jl:185
 [11] (::VSCodeServer.var"#61#65"{String,Int64,Int64,String,Module,Bool,VSCodeServer.ReplRunCodeRequestParams})() at /Users/ies602/.vscode/extensions/julialang.language-julia-1.0.10/scripts/packages/VSCodeServer/src/eval.jl:144
 [12] withpath(::VSCodeServer.var"#61#65"{String,Int64,Int64,String,Module,Bool,VSCodeServer.ReplRunCodeRequestParams}, ::String) at /Users/ies602/.vscode/extensions/julialang.language-julia-1.0.10/scripts/packages/VSCodeServer/src/repl.jl:124
 [13] (::VSCodeServer.var"#60#64"{String,Int64,Int64,String,Module,Bool,Bool,VSCodeServer.ReplRunCodeRequestParams})() at /Users/ies602/.vscode/extensions/julialang.language-julia-1.0.10/scripts/packages/VSCodeServer/src/eval.jl:142
 [14] hideprompt(::VSCodeServer.var"#60#64"{String,Int64,Int64,String,Module,Bool,Bool,VSCodeServer.ReplRunCodeRequestParams}) at /Users/ies602/.vscode/extensions/julialang.language-julia-1.0.10/scripts/packages/VSCodeServer/src/repl.jl:36
 [15] (::VSCodeServer.var"#59#63"{String,Int64,Int64,String,Module,Bool,Bool,VSCodeServer.ReplRunCodeRequestParams})() at /Users/ies602/.vscode/extensions/julialang.language-julia-1.0.10/scripts/packages/VSCodeServer/src/eval.jl:110
 [16] with_logstate(::Function, ::Any) at ./logging.jl:408
 [17] with_logger at ./logging.jl:514 [inlined]
 [18] (::VSCodeServer.var"#58#62"{VSCodeServer.ReplRunCodeRequestParams})() at /Users/ies602/.vscode/extensions/julialang.language-julia-1.0.10/scripts/packages/VSCodeServer/src/eval.jl:109
 [19] #invokelatest#1 at ./essentials.jl:710 [inlined]
 [20] invokelatest(::Any) at ./essentials.jl:709
 [21] macro expansion at /Users/ies602/.vscode/extensions/julialang.language-julia-1.0.10/scripts/packages/VSCodeServer/src/eval.jl:27 [inlined]
 [22] (::VSCodeServer.var"#56#57")() at ./task.jl:356
in expression starting at /Volumes/LUBOS64/phd/DisAltr/examples/blasso_gif.jl:54

Here is an example run with '.data' removed. The result changes with the initialisation, but it should converge anyway.
[Image: Screen Shot 2020-11-23 at 15 14 57]


Probably an old version of Flux? Not sure...


Can you tell me which version you have? I tried the newest and 0.10.4. Neither worked.


What I meant to say was that I was probably using an old version of Flux. I can't say exactly which version was used to write that code, because it was done on a previous computer in an environment that has since been lost :(. I know that for a while the latest Flux was too buggy for me to get anything done, so I was working off an old version.

But you are able to get the code to run without the .data statements in the latest Flux? If so, I'll probably change the gist to support the latest version. Otherwise, I am knee-deep in startup transients right now and it'll have to wait a few days.


luboshanus commented Nov 23, 2020

Ok. Thanks for the message.

Yes, without .data it works in the latest Flux. The figure I posted shows the results using the latest Flux. However, I think there may be an error in how x is updated, I guess in the update!() function, but I did not dig deeper and wrote to you first :)). Or maybe not... It just did not converge to the 5th parameter being equal to 2.0.
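
As a quick sanity check on that expected value (a sketch, not part of the gist): because the script sets X[:,5] = 0.5 * y, the target is exactly 2 * X[:,5], so an unpenalised least-squares fit should put a weight of about 2.0 on feature 5 and about 0 on the others. The seed, and the use of `X \ y` in place of the Flux model, are assumptions made only for illustration.

```julia
using LinearAlgebra, Random

Random.seed!(1)                      # arbitrary seed, only for reproducibility

# Same faux data as the gist, kept in Float64 for the direct solve
X = rand(100, 10) .- 0.5
y = rand(100) .+ randn(100) ./ 100
X[:, 5] = 0.5 .* y                   # 5th feature proportional to the target

w = X \ y                            # ordinary least squares, no bias, no penalty

println("weight on feature 5: ", w[5])
println("largest other weight: ", maximum(abs, w[[1:4; 6:10]]))
```

The gist's optimiser adds a penalty term and a threshold on top of plain gradient descent, so the weight it reaches may differ somewhat from this unpenalised value.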


Try the following: opt = βLASSO( Float32(0.03), Float32(0.005), Float32(1.0) )
This should be equivalent to classic LASSO barring unluckiness. The authors felt parameterizing the third parameter was useful - I'm not especially convinced - but for large datasets it may be.
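
To make the role of the third parameter concrete, here is a small sketch of the masking step as written in the gist's apply!: the update is zeroed for any weight with |x| ≤ β * λ, so raising β from 1.0 to 2.0 doubles that cutoff. The helper `βlasso_step` and the sample weights and gradients are hypothetical, chosen only for illustration.

```julia
# Hypothetical stand-alone version of the gist's apply! step (no Flux needed)
βlasso_step(x, Δ; η, λ, β) = (η .* (Δ .+ λ .* sign.(Δ))) .* (abs.(x) .> β * λ)

x = Float32[0.001, 0.004, 0.008, 0.02, -0.5]   # example weights
Δ = Float32[1, 1, 1, 1, -1]                    # example gradients

# Same η and λ as in the thread; only β changes
step_β1 = βlasso_step(x, Δ; η = 0.03f0, λ = 0.005f0, β = 1.0f0)  # cutoff 0.005
step_β2 = βlasso_step(x, Δ; η = 0.03f0, λ = 0.005f0, β = 2.0f0)  # cutoff 0.01

println("updates zeroed with β = 1: ", count(iszero, step_β1))
println("updates zeroed with β = 2: ", count(iszero, step_β2))
```

With these sample values, the weight 0.008 sits between the two cutoffs, so it still receives an update at β = 1.0 but not at β = 2.0.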


Yes, it works! Thanks!

Code without .data.

# edit: Flux v0.11
using Plots, Flux

# Apply one optimiser step to a single parameter array, in place.
function update!(opt, x, x̄)
  x .-= apply!(opt, x, x̄)
end

# Apply one optimiser step to every parameter that has a gradient.
function update!(opt, xs::Flux.Params, gs)
  for x in xs
    (gs[x] === nothing) && continue
    update!(opt, x, gs[x])
  end
end

# η: learning rate, λ: L1 penalty, β: threshold multiplier.
# The Float32 field types are assumed from the constructor defaults.
mutable struct βLASSO
  η::Float32
  λ::Float32
  β::Float32
end

# Keyword defaults, so the generated positional constructor is not overwritten.
βLASSO(; η = 0.01f0, λ = 0.009f0, β = 50.0f0) = βLASSO(η, λ, β)

function apply!(o::βLASSO, x, Δ)
  # Add the λ .* sign.(Δ) term to the gradient and scale by the learning rate
  Δ = o.η .* ( Δ .+ ( o.λ .* sign.( Δ ) ) )
  # Zero the update for weights whose magnitude is at most β * λ
  Δ = Δ .* Float32.( abs.( x ) .> ( o.β * o.λ ) )
  return Δ
end

# Mean squared error of the global model
function loss(x, y)
  sum(abs2, ( model( x ) .- y ) ) / length(y)
end

function train_me!(loss, ps, data, opt)
  ps = Flux.Params(ps)
  gs = Flux.gradient(ps) do
    loss(data...)
  end
  update!(opt, ps, gs)
end

# faux data
X = rand(100, 10) .- 0.5
# make a property value with some normally distributed noise
y = rand(100) .+ randn(100) / 100
# make the 5th feature proportional to the property value
X[:,5] = 0.5 * y
X = convert.(Float32, X)
y = convert.(Float32, y)

model = Flux.Dense( 10, 1, identity )

opt = βLASSO( Float32(0.03), Float32(0.005), Float32(1.0) ) # this setup works
losses = Float32[]
plot() # cue up plots
anim = @animate for i ∈ 1:1500
  global model
  train_me!(loss, Flux.params( model ), ( X', y'), opt)
  if (i % 50) == 0
    push!(losses, loss(X', y'))
    l = @layout [ a b ]
    p1 = bar(model.W', legend = false, title = "βLASSO weights")
    p2 = plot(losses, legend = false, title = "Loss")
    display( plot(p1, p2, layout = l) )
  end
end
# gif(anim, "plots/BLASSO wts.gif", fps = 60)


caseykneale commented Nov 23, 2020

Awesome! I'll update the gist with your code.

If you read the post I linked (10 min paper review) I kind of discuss that beta parameter.
