Skip to content

Instantly share code, notes, and snippets.

@caseykneale
Last active November 23, 2020 16:08
Show Gist options
Save caseykneale/b21c4c6cf5119c58d4f933baac16136b to your computer and use it in GitHub Desktop.
# Uses Flux v0.11.
using Plots, Flux
"""
    update!(opt, x, x̄)

Apply one optimiser step to the array `x` in place.

The step is obtained from `apply!(opt, x, x̄)` and subtracted from `x` element-wise.
Both sides are flattened with `[:]`, so the step only needs as many elements as `x`.
"""
function update!(opt, x, x̄)
    step = apply!(opt, x, x̄)
    x[:] .= x[:] .- step[:]
end
"""
    update!(opt, xs::Flux.Params, gs)

Step every parameter array in `xs` with optimiser `opt`, using its gradient `gs[param]`.
Parameters whose gradient entry is `nothing` are skipped.
"""
function update!(opt, xs::Flux.Params, gs)
    for param in xs
        grad = gs[param]
        grad === nothing || update!(opt, param, grad)
    end
end
"""
    βLASSO(η = 0.01f0, λ = 0.009f0, β = 50.0f0)

Optimiser state for β-LASSO style training.

Holds a learning rate `η`, an L1 penalty strength `λ` and a threshold multiplier `β`.
`apply!` masks the step for any weight whose magnitude is at most `β * λ`.
Hyper-parameters are stored as `Float32`; any `Real` arguments are converted.
"""
mutable struct βLASSO
    η::Float32  # learning rate
    λ::Float32  # L1 penalty strength
    β::Float32  # threshold multiplier: steps are masked where |x| ≤ β * λ
end

# Arguments are typed `Real` (not left untyped) so this method does not overwrite the
# implicit converting constructor `βLASSO(η, λ, β)`. Overwriting it would make the
# constructor call itself forever for non-Float32 input such as `βLASSO(0.03, 0.005, 1.0)`.
βLASSO(η::Real = 0.01f0, λ::Real = 0.009f0, β::Real = 50.0f0) =
    βLASSO(Float32(η), Float32(λ), Float32(β))

"""
    apply!(o::βLASSO, x, Δ)

Return the step to subtract from the weights `x`, given their gradient `Δ`.

The step is `η * (Δ + λ * sign(x))`, i.e. the loss gradient plus the L1 subgradient.
It is zeroed wherever `|x| ≤ β * λ`, so those weights are left unchanged.
Neither `x` nor `Δ` is modified.
"""
function apply!(o::βLASSO, x, Δ)
    # The subgradient of λ‖x‖₁ is λ·sign(x): it depends on the weight, not on the gradient.
    step = o.η .* (Δ .+ o.λ .* sign.(x))
    # Threshold from the optimiser actually passed in, never from a global.
    active = Float32.(abs.(x) .> o.β * o.λ)
    # NOTE(review): the β-LASSO paper appears to zero weights below β*λ after the step.
    # This masks the step instead, freezing such weights. Confirm which is intended.
    return step .* active
end
"""
    loss(x, y, m = model)

Return the sum of squared residuals `m(x) .- y`, divided by `length(y)`.

This is the mean squared error when `m(x)` has the same shape as `y`.
`m` defaults to the global `model`, so existing two-argument calls behave as before.
Pass a model explicitly to avoid depending on that global.
"""
function loss(x, y, m = model)
    residuals = m(x) .- y
    return sum(abs2, residuals) / length(y)
end
"""
    train_me!(loss, ps, data, opt)

Run a single optimisation step.

Differentiates `loss(data...)` with respect to the parameters `ps` (wrapped in
`Flux.Params`) and passes the resulting gradients to `update!` with optimiser `opt`.
"""
function train_me!(loss, ps, data, opt)
    θ = Flux.Params(ps)
    grads = Flux.gradient(() -> loss(data...), θ)
    return update!(opt, θ, grads)
end
# Faux data: 100 samples of 10 uniform features centred on zero.
X = rand(100, 10) .- 0.5
# Target property values with a little normally distributed noise.
y = rand(100) .+ randn(100) / 100
# Make the 5th feature proportional to the target, so it is the only informative one.
X[:, 5] = 0.5 * y
X = convert.(Float32, X)
y = convert.(Float32, y)

model = Flux.Dense(10, 1, identity)

# Classic LASSO via projected SGD (β = 1).
opt = βLASSO(Float32(0.03), Float32(0.005), Float32(1.0))
# Increase the β parameter to 2.0 if you want to:
#opt = βLASSO(Float32(0.03), Float32(0.005), Float32(2.0))

# Dense(10, 1) takes 10-row input, so train and evaluate on the transposed data.
Xt, yt = X', y'
losses = Float32[]        # loss recorded every 50 iterations
panels = @layout [a b]    # weights bar chart | loss curve

plot()  # start from an empty current plot; @animate captures a frame every iteration
anim = @animate for i ∈ 1:1500
    train_me!(loss, Flux.params(model), (Xt, yt), opt)
    if i % 50 == 0
        push!(losses, loss(Xt, yt))
        p1 = bar(model.W', legend = false, title = "βLASSO weights")
        p2 = plot(losses, legend = false, title = "Loss")
        display(plot(p1, p2, layout = panels))
    end
end

gif_path = "plots/BLASSO wts.gif"
mkpath(dirname(gif_path))  # gif() cannot write into a directory that does not exist
gif(anim, gif_path, fps = 60)
@caseykneale
Copy link
Author

caseykneale commented Nov 23, 2020

Awesome! I'll update the gist with your code.

If you read the post I linked (a 10-minute paper review), I discuss that β parameter a bit there.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment