Skip to content

Instantly share code, notes, and snippets.

@bicycle1885
Created March 21, 2020 14:08
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bicycle1885/105d39c5573a0bc2687ae1bdb414b522 to your computer and use it in GitHub Desktop.
Save bicycle1885/105d39c5573a0bc2687ae1bdb414b522 to your computer and use it in GitHub Desktop.
using LinearAlgebra
"""
    main(file_path="./ml-100k/u.data")

Read ratings from `file_path` with `read_data` and fit a matrix-factorisation
model to them with `fit`.

User and item IDs are used directly as column indices into the factor
matrices, so the number of users/items is taken as the *largest* ID seen rather
than the number of distinct IDs. This keeps indexing in bounds even when the
ID ranges contain gaps (for contiguous IDs `1:n` the two coincide).

Returns the fitted factor matrices `(P, Q)`.
"""
function main(file_path="./ml-100k/u.data")
    # input data
    train_data = read_data(file_path)
    # IDs index columns of P/Q directly, so size by the max ID, not by the
    # count of distinct IDs (those only agree when IDs are exactly 1:n).
    n_user = isempty(train_data) ? 0 : maximum(t -> t[1], train_data)
    n_item = isempty(train_data) ? 0 : maximum(t -> t[2], train_data)
    # parameters
    P, Q = fit(n_user, n_item, train_data)
    return P, Q
end
"""
    read_data(file_path)

Read a whitespace-separated ratings file and return its records as a
`Vector{Tuple{Int,Int,Float32}}` of `(user, item, rating)` triples.

Only the first three fields of each line are used; any further fields (such as
a trailing timestamp column) are ignored. Blank lines are skipped. The rating
is parsed as a floating-point number, so both integer (`"4"`) and fractional
(`"3.5"`) ratings are accepted.

Throws an `ArgumentError` if a non-blank line has fewer than three fields, or
if a field cannot be parsed as a number.
"""
function read_data(file_path)
    train_data = Tuple{Int,Int,Float32}[]
    # `open ... do` guarantees the file is closed even if parsing fails midway.
    open(file_path) do io
        for (lineno, l) in enumerate(eachline(io))
            fields = split(l)
            isempty(fields) && continue  # tolerate blank / trailing lines
            length(fields) >= 3 || throw(ArgumentError(
                "line $lineno of $file_path: expected at least 3 fields, " *
                "got $(length(fields))"))
            u = parse(Int, fields[1])
            i = parse(Int, fields[2])
            r = parse(Float32, fields[3])
            push!(train_data, (u, i, r))
        end
    end
    return train_data
end
"""
    fit(n_user, n_item, train_data; n_itr=50, n_fac=5, γ=0.07f0, λ=0.01f0,
        rng=Random.default_rng(), verbose=true)

Fit a matrix-factorisation model `r ≈ P[:, u] ⋅ Q[:, i]` to `train_data` using
plain stochastic gradient descent with L2 regularisation.

# Arguments
- `n_user`, `n_item`: number of columns of `P` and `Q`. Each `u`/`i` in
  `train_data` is used directly as a column index, so it must lie in
  `1:n_user` / `1:n_item` (otherwise a `BoundsError` is thrown).
- `train_data`: iterable of `(u, i, r)` triples.

# Keywords
- `n_itr`: number of passes (epochs) over `train_data`.
- `n_fac`: number of latent factors (rows of `P` and `Q`).
- `γ`: learning rate.
- `λ`: L2 regularisation strength.
- `rng`: random number generator used for the `randn` initialisation of `P`
  and `Q`; defaults to the global RNG.
- `verbose`: if `true`, print `"<epoch>: <loss>"` after every epoch.

# Returns
`(P, Q)`: `Float32` matrices of size `(n_fac, n_user)` and `(n_fac, n_item)`.
"""
function fit(n_user, n_item, train_data; n_itr=50, n_fac=5, γ=0.07f0, λ=0.01f0,
             rng=Random.default_rng(), verbose=true)
    # init parameters
    P = randn(rng, Float32, n_fac, n_user)
    Q = randn(rng, Float32, n_fac, n_item)
    # Scratch buffers for the pre-update factor vectors, reused for every
    # rating instead of allocating two fresh column copies each time.
    pu = Vector{Float32}(undef, n_fac)
    qi = Vector{Float32}(undef, n_fac)
    # optimize: SGD
    for itr in 1:n_itr
        loss = 0f0
        for (u, i, r) in train_data
            Pu = view(P, :, u)
            Qi = view(Q, :, i)
            # Snapshot the old values so both updates below are computed from
            # the pre-update vectors (a simultaneous gradient step).
            copyto!(pu, Pu)
            copyto!(qi, Qi)
            # calc error
            e = r - pu ⋅ qi
            @. Qi += γ * (e * pu - λ * qi)
            @. Pu += γ * (e * qi - λ * pu)
            # Running loss: squared error measured before this step plus the
            # L2 penalty of the just-updated vectors.
            loss += e * e + λ * (Pu ⋅ Pu + Qi ⋅ Qi)
        end
        verbose && println("$itr: $loss")
    end
    return P, Q
end
using Random
# Seed the global RNG so the `randn` initialisation of the factor matrices in
# `fit` (and therefore the printed per-epoch losses) is identical on every run.
Random.seed!(1234)
# Script entry point: load the ratings, train, and print the loss per epoch.
main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment