@outlace
Created September 7, 2021 22:02
Julia prototype implementation of Sub-linear Deep Learning Engine (SLIDE); missing advanced hash table updating and other optimizations
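# SLIDE in brief: rather than computing every node of a wide layer, hash the
# layer input with L locality-sensitive hash (SimHash) functions and evaluate
# only the nodes whose weight columns share a bucket with the input. That is
# what `sample_nodes` and the sparse `Layer` forward pass below implement.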
using Flux
using Zygote
using MLDatasets
using LSHFunctions
using DataStructures
using Plots
using Profile
using StatProfilerHTML
using LinearAlgebra
#Parameters
const batch_size = 16 # unused below; the training loop updates on single samples
const k = 7           # hash bits per table (2^k buckets per table)
const L = 10          # number of hash tables
#const sample_rate = 0.1 # proportion of nodes to sample from matrix (unused)
function onehotencode(y)
    t = zeros(Float64, 10)
    t[y+1] = 1.0 # MNIST labels are 0-9
    return t
end
train_x, train_y = MNIST.traindata();
train_x = permutedims(train_x, (3, 2, 1)) # samples-first: (60000, 28, 28)
x_train = reshape(train_x, (size(train_x)[1], prod(size(train_x)[2:end]))); # (60000, 784)
y_train = [transpose(onehotencode(y)) for y in train_y]
y_train = vcat(y_train...); # (60000, 10) one-hot matrix
test_x, test_y = MNIST.testdata();
test_x = permutedims(test_x, (3, 2, 1)) # (10000, 28, 28)
x_test = reshape(test_x, (size(test_x)[1], prod(size(test_x)[2:end]))); # (10000, 784)
y_test = [transpose(onehotencode(y)) for y in test_y]
y_test = vcat(y_test...); # (10000, 10) one-hot matrix
moving_average(vs,n) = [sum(@view vs[i:(i+n-1)])/n for i in 1:(length(vs)-(n-1))]
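# e.g. moving_average([1, 2, 3, 4], 2) == [1.5, 2.5, 3.5]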
function comb(len)
    # enumerate all 2^len binary codes of length `len`
    Iterators.product(fill(BitArray([0, 1]), len)...) |> collect |> vec
end
all_hash_codes = comb(k); # 2^7 = 128 possible bucket keys
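# A small illustration (not in the original) of the LSHFunctions API used
# throughout: a SimHash maps a vector to k sign bits, so vectors with high
# cosine similarity tend to land in the same bucket.
demo_hash = LSHFunction(cossim, k)
demo_hash(randn(784)) |> Tuple # e.g. (true, false, true, true, false, true, false)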
mutable struct Layer
    theta::Matrix
    bias::Matrix
    hash_funs::Vector{SimHash}
    hash_tables::Vector{Dict{Tuple,CircularBuffer{Integer}}}
end
#The `Layer` type constructor
function Layer(in_dim::Integer, out_dim::Integer, k=6, L=6, bin_size=10)
    # NB: `all_hash_codes` was built from the global `k`, so pass the same `k`
    # here; the default of 6 would produce keys missing from the tables.
    theta = randn(in_dim, out_dim) / 2
    bias = randn(1, out_dim) / 5
    cols = size(theta)[2] # number of columns/nodes
    hash_funs = [LSHFunction(cossim, k) for i in 1:L]
    hash_tables = Vector{Dict{Tuple,CircularBuffer{Integer}}}()
    for i in 1:L # create L hash tables, one per hash function
        ht_l = Dict{Tuple,CircularBuffer{Integer}}(x => CircularBuffer{Integer}(bin_size) for x in all_hash_codes)
        push!(hash_tables, ht_l)
        for j in 1:cols # hash every weight column into table i
            hash_lj = Tuple(hash_funs[i](theta[:, j]))
            push!(hash_tables[i][hash_lj], j)
        end
    end
    Layer(theta, bias, hash_funs, hash_tables)
end
Flux.trainable(a::Layer) = (a.theta, a.bias) # only weights/biases get gradients; hash state is excluded
#Define the function that runs a layer
function (m::Layer)(X::Matrix, cols::Vector, rows::Vector)
    e1 = isempty(cols)
    e2 = isempty(rows)
    if e1 && e2 # neither rows nor cols given: dense forward pass
        y = X * m.theta .+ m.bias
    elseif e2 # cols given alone: evaluate only the sampled output nodes
        y = X * (@view m.theta[:, cols]) .+ (@view m.bias[:, cols])
    elseif e1 # rows given alone: input restricted to the sampled rows
        y = X * (@view m.theta[rows, :]) .+ m.bias
    else # rows and cols both given
        y = X * (@view m.theta[rows, cols]) .+ (@view m.bias[:, cols])
    end
    return y, cols, rows
end
#no activation function applied yet
layer1 = Layer(784,1000, k, L) #501ms for 500 cols, 917 ms for 1000 cols
layer2 = Layer(1000,10, k, L)
layers = [layer1,layer2];
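# A hedged usage sketch (not in the original gist): with cols = [2, 5], only
# those two output nodes are computed, so y_demo is 1x2 rather than 1x1000.
y_demo, _, _ = layer1(randn(1, 784), [2, 5], Integer[])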
function sample_nodes(query::Vector, layer::Layer)
    #`query` is the input vector for this layer
    S = Set{Int64}()
    for i in 1:L
        # compute hash of query using each hash function
        q_hash = layer.hash_funs[i](query) |> Tuple
        matches = layer.hash_tables[i][q_hash] # nodes whose weights share this bucket
        union!(S, matches)
    end
    if isempty(S) # fall back to one random node so the layer is never empty
        push!(S, rand(1:size(layer.theta)[2]))
    end
    return S |> collect
end
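# A quick demo of the sampler (not in the original): candidate layer-1 nodes
# for the first training image, same call pattern the training loop uses below.
S_demo = sample_nodes(vec(float(x_train[1, :])), layer1)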
function model(X::Matrix, layers::Vector{Layer}, S::Vector)
    #layer 1: evaluate only the LSH-sampled columns S
    X = Flux.normalise(X; dims=ndims(X), ϵ=1e-5)
    A1, cols1, rows1 = layers[1](X, S, Integer[])
    A1 = relu.(A1)
    #layer 2: restrict input rows to the columns layer 1 actually produced
    A1 = Flux.normalise(A1; dims=ndims(A1), ϵ=1e-5)
    A2, cols2, rows2 = layers[2](A1, Integer[], cols1)
    A2 = softmax(A2, dims=2)
    return A2
end
model(randn(1, 784), layers, [1, 50, 90, 112, 145, 240, 300, 301, 500]) # smoke test with a hand-picked node set
function update_htables(layer::Layer)
    # Rehash every weight column into the circular buckets; stale entries are
    # only evicted by buffer overflow (the "advanced updating" noted as missing above).
    num_ht = length(layer.hash_funs)
    cols = size(layer.theta)[2]
    for i in 1:num_ht # iterate tables
        for j in 1:cols
            hash_lj = Tuple(layer.hash_funs[i](layer.theta[:, j]))
            push!(layer.hash_tables[i][hash_lj], j)
        end
    end
end
lossfn(ŷ::Vector, y::Vector) = -1.0 * LinearAlgebra.dot(log.(ŷ), y) # cross-entropy against a one-hot target
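# e.g. lossfn([0.7, 0.2, 0.1], [1.0, 0.0, 0.0]) == -log(0.7) ≈ 0.357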
#S = [1,50,90,112,145,240,300,301,500,505,506,511];
#g = Zygote.gradient(w -> lossfn(vec(model(randn(1,784),w,S)),[1.0,0,0,0,0,0,0,0,0,0]),layers)
#g = Zygote.gradient((ypred,ytrue) -> lossfn(ypred,ytrue),(A2,batch_y))
#println(g);
#g[1][1][][:theta]
function train(x_train, y_train, epochs=200000)
    opt = Descent(0.001)
    lossarr = Float64[]
    for i in 1:epochs
        # draw one random training example (single-sample SGD)
        rid = rand(1:60000)
        x = float(x_train[rid, :])
        x = reshape(x, (1, 784))
        S = sample_nodes(vec(x), layers[1]) # LSH-sample layer-1 nodes for this input
        yt = y_train[rid, :]
        ps = Flux.params(layers)
        gs = gradient(ps) do
            ŷ = vec(model(x, layers, S))
            lossfn(ŷ, yt)
        end
        Flux.update!(opt, ps, gs)
        if i % 50 == 0 # periodically rehash layer 1 as its weights drift
            update_htables(layers[1])
        end
        ŷ = vec(model(x, layers, S))
        push!(lossarr, lossfn(ŷ, yt))
    end
    return lossarr # returned so the loss curve can be plotted below
end
#@profilehtml train(x_train,y_train,1000)
lossarr = train(x_train, y_train, 100000)
plot(moving_average(lossarr, 1000)) # y-axis tops out above 3; x-axis runs to ~2 x 10^5
function test_acc(xs, ys)
    ncorr = 0
    ntot = size(ys)[1]
    for i in 1:ntot
        x = float(xs[i, :])
        x = reshape(x, (1, 784))
        yt = ys[i, :]
        ŷ = vec(model(x, layers, [])) # empty S => full dense forward pass at test time
        if argmax(ŷ) == argmax(yt)
            ncorr += 1
        end
    end
    println(100 * (ncorr / ntot)) # accuracy in percent
end
test_acc(x_test,y_test)
#~92% accuracy with 200k iterations