Skip to content

Instantly share code, notes, and snippets.

@masaponto
Last active October 18, 2015 10:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save masaponto/da88aef6a5c60a247eb8 to your computer and use it in GitHub Desktop.
#Pkg.add("RDatasets")
using Base.Test
using RDatasets
# Parameters of a trained extreme learning machine (see `elm`).
type Model
    # input -> hidden weights: one row per input feature plus a final row for the
    # bias column appended by `add_bias`, one column per hidden unit
    a_vs::Matrix{Float64}
    # hidden -> output weights fitted by least squares, one per hidden unit
    beta_vs::Vector{Float64}
end
"""
    sigmoid(x)

Logistic function `1 / (1 + exp(-x))`. It maps a real `x` into (0, 1); in floating
point it saturates to exactly `0.0` / `1.0` for large `|x|`.
"""
sigmoid(x) = 1 / (1 + exp(-x))
"""
    add_bias(x)

Return a new matrix equal to `x` with a column of ones appended on the right, so
a bias term can be folded into the following matrix product with the weights.
"""
function add_bias(x::Array{Float64, 2})
    nrows = size(x, 1)
    return hcat(x, ones(nrows))
end
@test add_bias([1.0 2.0; 3.0 4.0]) == [1.0 2.0 1.0; 3.0 4.0 1.0]
"""
    elm(X, y, hid_num; rng = nothing)

Train an extreme learning machine with one hidden layer of `hid_num` sigmoid units.

The input -> hidden weights are drawn uniformly from [-1, 1) and are not trained.
Only the output weights are fitted, as `pinv(H) * y` (the minimum-norm
least-squares solution), where `H = map(sigmoid, add_bias(X) * a_vs)`.

# Arguments
- `X`: one sample per row, one feature per column.
- `y`: one target value per row of `X`.
- `hid_num`: number of hidden units; must be positive.
- `rng`: optional random number generator for the hidden weights (e.g. a seeded
  `MersenneTwister`) to make training reproducible. When `nothing` (the default),
  the global RNG is used, exactly as before.

# Returns
A `Model` holding the hidden weights `a_vs` and the output weights `beta_vs`.

# Throws
- `ArgumentError` if `hid_num < 1`.
- `DimensionMismatch` if `size(X, 1) != length(y)`.
"""
function elm(X::Array{Float64, 2}, y::Array{Float64, 1}, hid_num::Int; rng = nothing)
    hid_num > 0 || throw(ArgumentError("hid_num must be positive, got $hid_num"))
    size(X, 1) == length(y) || throw(DimensionMismatch(
        "X has $(size(X, 1)) rows but y has $(length(y)) elements"))

    x_vs = add_bias(X)
    n_in = size(x_vs, 2)

    # Uniform weights in [-1, 1). The dotted operators are elementwise; undotted
    # `matrix - scalar` was deprecated in Julia 0.6 and removed in 1.0.
    u = rng === nothing ? rand(n_in, hid_num) : rand(rng, n_in, hid_num)
    a_vs = 2 .* u .- 1

    h = map(sigmoid, x_vs * a_vs)
    # The Moore-Penrose pseudoinverse gives the minimum-norm least-squares fit even
    # when H is wide (fewer samples than hidden units), as in the XOR demo.
    beta_vs = pinv(h) * y
    return Model(a_vs, beta_vs)
end
"""
    predict(model, x)

Classify each row of `x` with a trained ELM `model`.

Each row gets a bias column and is passed through the hidden layer
(`map(sigmoid, add_bias(x) * model.a_vs)`). The hidden activations are combined
with `model.beta_vs`, and the sign of each resulting score is returned: `1.0`,
`-1.0`, or zero for a score of exactly zero.

# Throws
- `DimensionMismatch` if `x` does not have exactly one column fewer than
  `model.a_vs` has rows (i.e. the wrong number of features for this model).
"""
function predict(model::Model, x::Array{Float64, 2})
    # the last row of a_vs weights the bias column appended by add_bias
    n_features = size(model.a_vs, 1) - 1
    size(x, 2) == n_features || throw(DimensionMismatch(
        "model expects $n_features features, got $(size(x, 2))"))

    hidden = map(sigmoid, add_bias(x) * model.a_vs)
    scores = hidden * model.beta_vs
    # `map(sign, ...)` rather than `sign(array)`: the array method was removed in Julia 1.0
    return map(sign, scores)
end
"""
    xor()

Demo: train an ELM with 10 hidden units on the four corners of the unit square,
then print the target labels and the model's predictions on those same inputs.
"""
function xor()
    # one sample per row: (1,1), (1,0), (0,1), (0,0)
    inputs = Float64[1 1; 1 0; 0 1; 0 0]
    # +1 where the two inputs are equal, -1 where they differ.
    # NOTE(review): under a "+1 = true" reading this is XNOR rather than XOR;
    # confirm the intended label encoding.
    targets = Float64[1, -1, -1, 1]

    trained = elm(inputs, targets, 10)
    predicted = predict(trained, inputs)

    println(targets)
    println(predicted)
end
#function iris()
# iris = dataset("datasets", "iris")
# println(iris)
# X = matrix(iris([:, 1:4]))'
# print(X)
#end
"""
    main()

Entry point: runs the XOR demo and returns whatever `xor()` returns.
"""
main() = xor()   # iris() is not called here; its definition is commented out

main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.