Created September 7, 2021 22:02
Julia prototype implementation of Sub-linear Deep Learning Engine (SLIDE); missing advanced hash table updating and other optimizations. The core idea: instead of computing every neuron in a layer, use locality-sensitive hashing to select a small set of likely-high-activation neurons per input and run the forward/backward pass only over that active set.
using Flux
using NNlib #the model code below calls NNlib.relu / NNlib.softmax directly
using Zygote
using MLDatasets
using LSHFunctions
using DataStructures
using Plots
using Profile
using StatProfilerHTML
using LinearAlgebra
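#For orientation (illustrative snippet, not part of the original gist): the
#core primitive here is LSHFunctions' SimHash, which maps a vector to k bits
#such that cosine-similar vectors are likely to collide on the same code.
demo_hash = LSHFunction(cossim, 7) #a 7-bit SimHash, matching k below
demo_code = demo_hash(randn(784))  #=> 7-element BitVector, e.g. Bool[1,0,1,1,0,0,1]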
#Parameters
const batch_size = 16 #defined but unused; training below is one sample at a time
const k = 7  #bits per hash code
const L = 10; #number of hash tables per layer
#const sample_rate = 0.1 # proportion of nodes to sample from the matrix (unused)
#One-hot encode a digit label y ∈ 0:9 into a length-10 vector
function onehotencode(y)
    t = zeros(Float64,10)
    t[y+1] = 1.0
    return t
end
#Load MNIST and flatten each 28×28 image into a 784-element row
train_x, train_y = MNIST.traindata();
train_x = permutedims(train_x,(3,2,1))
x_train = reshape(train_x,(size(train_x)[1],prod(size(train_x)[2:end])));
y_train = [transpose(onehotencode(y)) for y in train_y]
y_train = vcat(y_train...);
test_x, test_y = MNIST.testdata();
test_x = permutedims(test_x,(3,2,1))
x_test = reshape(test_x,(size(test_x)[1],prod(size(test_x)[2:end])));
y_test = [transpose(onehotencode(y)) for y in test_y]
y_test = vcat(y_test...);
#Moving average with window n, used to smooth the training loss curve
moving_average(vs,n) = [sum(@view vs[i:(i+n-1)])/n for i in 1:(length(vs)-(n-1))]
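#Quick check (illustrative, not in the original):
@assert moving_average([1.0,2,3,4], 2) == [1.5, 2.5, 3.5]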
#Enumerate all 2^len possible binary hash codes of length `len`
function comb(len)
    Iterators.product(fill(BitArray([0,1]), len)...) |> collect |> vec
end
all_hash_codes = comb(k);
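#Sanity check (illustrative): with k = 7 there are 2^7 = 128 possible codes,
#so each hash table below is pre-seeded with 128 buckets.
@assert length(all_hash_codes) == 2^k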
mutable struct Layer
    theta::Matrix #weights, in_dim × out_dim (one column per node)
    bias::Matrix  #1 × out_dim
    hash_funs::Vector{SimHash} #one SimHash per hash table
    hash_tables::Vector{Dict{Tuple,CircularBuffer{Integer}}} #L tables: code => node ids
end
#The `Layer` type constructor: initialize weights, then hash every column
#(node) of `theta` into each of the L tables.
#Note: the pre-seeded bucket keys come from the global `all_hash_codes`,
#so the `k` passed here should match the global `k`.
function Layer(in_dim::Integer,out_dim::Integer,k=6,L=6,bin_size=10)
    theta = randn(in_dim,out_dim) / 2
    bias = randn(1,out_dim) / 5
    cols = size(theta)[2] #number of columns/nodes
    hash_funs = [LSHFunction(cossim, k) for i in 1:L]
    hash_tables = Vector{Dict{Tuple,CircularBuffer{Integer}}}()
    for i in 1:L #create L hash tables, each pre-seeded with every possible code
        ht_l = Dict{Tuple,CircularBuffer{Integer}}((x) => CircularBuffer{Integer}(bin_size) for x in all_hash_codes)
        push!(hash_tables,ht_l)
        for j in 1:cols #hash each node's weight vector into its bucket
            hash_lj = Tuple(hash_funs[i](theta[:,j]))
            push!(hash_tables[i][hash_lj],j)
        end
    end
    Layer(theta,bias,hash_funs,hash_tables)
end
Flux.trainable(a::Layer) = (a.theta,a.bias) #only weights and bias are trained, not the hash state
#Define the function that runs a layer: `cols` restricts the output to active
#nodes, `rows` restricts the input dimensions; views avoid copying slices
function (m::Layer)(X::Matrix, cols::Vector, rows::Vector)
    e1 = isempty(cols)
    e2 = isempty(rows)
    if e1 && e2 # neither rows nor cols given: dense forward pass
        y = X * m.theta .+ m.bias
    elseif e2 # cols given alone
        y = X * (@view m.theta[:,cols]) .+ (@view m.bias[:,cols]);
    elseif e1 # rows given alone
        y = X * (@view m.theta[rows,:]) .+ m.bias;
    else # rows and cols given
        y = X * (@view m.theta[rows,cols]) .+ (@view m.bias[:,cols]);
    end
    return y, cols, rows
end
#no activation function applied here; activations are applied in `model`
layer1 = Layer(784,1000, k, L) #501ms for 500 cols, 917 ms for 1000 cols
layer2 = Layer(1000,10, k, L)
layers = [layer1,layer2];
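#Illustrative (not in the original): a sparse forward pass through layer 1
#multiplies X by only the selected columns of theta, so cost scales with the
#size of the active set rather than with out_dim.
y_demo, _, _ = layer1(randn(1,784), [1,5,9], Integer[]) #y_demo is 1×3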
#Query the hash tables to get candidate active nodes for a given input
function sample_nodes(query::Vector, layer::Layer)
    #`query` is the input vector for this layer
    S = Set{Int64}()
    for i in 1:L
        # compute the hash of the query with each table's hash function
        q_hash = layer.hash_funs[i](query) |> Tuple
        matches = layer.hash_tables[i][q_hash]
        union!(S,matches)
    end
    if isempty(S) #no collisions in any table: fall back to one random node
        push!(S,rand(1:size(layer.theta)[2]))
    end
    return S |> collect
end
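#Illustrative usage (not in the original): the active set for one image is
#typically a small fraction of layer 1's 1000 nodes.
S_demo = sample_nodes(vec(float(x_train[1,:])), layer1)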
function model(X::Matrix,layers::Vector{Layer},S::Vector)
    #layer 1: compute only the active nodes in S
    X = Flux.normalise(X;dims=ndims(X), ϵ=1e-5)
    A1, cols1, rows1 = layers[1](X,S,Integer[])
    A1 = NNlib.relu.(A1)
    #layer 2: the active nodes of layer 1 become the active input rows here
    A1 = Flux.normalise(A1;dims=ndims(A1), ϵ=1e-5)
    A2, cols2, rows2 = layers[2](A1,Integer[],cols1)
    A2 = NNlib.softmax(A2,dims=2)
    return A2
end
#Smoke test: run the model on a random input with a hand-picked active set
model(randn(1,784), layers,[1,50,90,112,145,240,300,301,500])
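#Quick check (illustrative, not in the original): the softmax output over the
#10 classes sums to 1.
@assert isapprox(sum(model(randn(1,784), layers, [1,2,3])), 1.0; atol=1e-8)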
#Re-hash every node into the tables after weight updates. Note: old entries
#are not evicted, only aged out of the fixed-size circular buffers, so stale
#node ids can linger (the "missing advanced hash table updating" noted above).
function update_htables(layer::Layer)
    num_ht = length(layer.hash_funs)
    cols = size(layer.theta)[2]
    for i in 1:num_ht #iterate tables
        for j in 1:cols
            hash_lj = Tuple(layer.hash_funs[i](layer.theta[:,j]))
            push!(layer.hash_tables[i][hash_lj],j)
        end
    end
end
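#A heavier but cleaner alternative (illustrative sketch, not in the original):
#reset every bucket before re-hashing, so stale node ids are actually evicted.
function rebuild_htables!(layer::Layer, bin_size=10)
    for i in 1:length(layer.hash_funs)
        for key in keys(layer.hash_tables[i]) #empty every bucket
            layer.hash_tables[i][key] = CircularBuffer{Integer}(bin_size)
        end
        for j in 1:size(layer.theta)[2] #re-insert each node under its current hash
            hash_lj = Tuple(layer.hash_funs[i](layer.theta[:,j]))
            push!(layer.hash_tables[i][hash_lj], j)
        end
    end
end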
#Cross-entropy loss for a one-hot target: reduces to -log(ŷ[true class])
lossfn(ŷ::Vector,y::Vector) = -1.0 * LinearAlgebra.dot(log.(ŷ),y)
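#Quick check (illustrative): a confident correct prediction gives a small loss
lossfn([0.9,0.05,0.05], [1.0,0,0]) # ≈ 0.105, i.e. -log(0.9)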
function train(x_train,y_train,epochs=200000)
    opt = Descent(0.001)
    lossarr = []
    for i in 1:epochs
        #sample a single random training example
        rid = rand(1:size(x_train)[1])
        x = float(x_train[rid,:])
        x = reshape(x,(1,784))
        #pick layer 1's active nodes for this input via LSH
        S = sample_nodes(vec(x),layers[1])
        yt = y_train[rid,:]
        ps = Flux.params(layers)
        gs = gradient(ps) do
            ŷ = vec(model(x, layers, S));
            l = lossfn(ŷ,yt)
        end
        Flux.update!(opt, ps, gs)
        if i % 50 == 0 #periodically re-hash layer 1 so tables track the updated weights
            update_htables(layers[1])
        end
        ŷ = vec(model(x, layers, S)); #recompute the loss after the update, for logging
        l = lossfn(ŷ,yt)
        push!(lossarr,l)
    end
    return lossarr
end
#@profilehtml train(x_train,y_train,1000)
lossarr = train(x_train,y_train,100000)
plot(moving_average(lossarr,1000)) #in the original run: y-axis tops out above 3, x-axis runs to 2 x 10^5
#Evaluate accuracy on the test set; an empty active set means the full
#(dense) forward pass is used at test time
function test_acc(xs,ys)
    ncorr = 0
    ntot = size(ys)[1]
    for i in 1:ntot
        x = float(xs[i,:])
        x = reshape(x,(1,784))
        yt = ys[i,:]
        ŷ = vec(model(x, layers, []));
        if argmax(ŷ) == argmax(yt)
            ncorr += 1
        end
    end
    println(100 * (ncorr/ntot))
end
test_acc(x_test,y_test)
#~92% accuracy with 200k iterations