Last active
May 6, 2016 09:27
-
-
Save apetrushin/e0ec98ca36cb3a668b081718ba47c1e0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using PyPlot | |
using Optim | |
# Extract the single element of a one-element collection.
# Throws ArgumentError when `x` does not hold exactly one element.
# (The original threw a bare String, which is not an Exception type.)
function scalar(x)
    length(x) == 1 || throw(ArgumentError("error, not a scalar but $x"))
    return x[1]
end;
# Preparing data.
# Training set for logistic regression (Andrew Ng ML course, ex2data1):
# each row is [exam-1 score, exam-2 score, admission label (1 = admitted, 0 = not)].
data = [34.62365962451697 78.0246928153624 0;
        30.28671076822607 43.89499752400101 0;
        35.84740876993872 72.90219802708364 0;
        60.18259938620976 86.30855209546826 1;
        79.0327360507101 75.3443764369103 1;
        45.08327747668339 56.3163717815305 0;
        61.10666453684766 96.51142588489624 1;
        75.02474556738889 46.55401354116538 1;
        76.09878670226257 87.42056971926803 1;
        84.43281996120035 43.53339331072109 1;
        95.86155507093572 38.22527805795094 0;
        75.01365838958247 30.60326323428011 0;
        82.30705337399482 76.48196330235604 1;
        69.36458875970939 97.71869196188608 1;
        39.53833914367223 76.03681085115882 0;
        53.9710521485623 89.20735013750205 1;
        69.07014406283025 52.74046973016765 1;
        67.94685547711617 46.67857410673128 0;
        70.66150955499435 92.92713789364831 1;
        76.97878372747498 47.57596364975532 1;
        67.37202754570876 42.83843832029179 0;
        89.67677575072079 65.79936592745237 1;
        50.534788289883 48.85581152764205 0;
        34.21206097786789 44.20952859866288 0;
        77.9240914545704 68.9723599933059 1;
        62.27101367004632 69.95445795447587 1;
        80.1901807509566 44.82162893218353 1;
        93.114388797442 38.80067033713209 0;
        61.83020602312595 50.25610789244621 0;
        38.78580379679423 64.99568095539578 0;
        61.379289447425 72.80788731317097 1;
        85.40451939411645 57.05198397627122 1;
        52.10797973193984 63.12762376881715 0;
        52.04540476831827 69.43286012045222 1;
        40.23689373545111 71.16774802184875 0;
        54.63510555424817 52.21388588061123 0;
        33.91550010906887 98.86943574220611 0;
        64.17698887494485 80.90806058670817 1;
        74.78925295941542 41.57341522824434 0;
        34.1836400264419 75.2377203360134 0;
        83.90239366249155 56.30804621605327 1;
        51.54772026906181 46.85629026349976 0;
        94.44336776917852 65.56892160559052 1;
        82.36875375713919 40.61825515970618 0;
        51.04775177128865 45.82270145776001 0;
        62.22267576120188 52.06099194836679 0;
        77.19303492601364 70.45820000180959 1;
        97.77159928000232 86.7278223300282 1;
        62.07306379667647 96.76882412413983 1;
        91.56497449807442 88.69629254546599 1;
        79.94481794066932 74.16311935043758 1;
        99.2725269292572 60.99903099844988 1;
        90.54671411399852 43.39060180650027 1;
        34.52451385320009 60.39634245837173 0;
        50.2864961189907 49.80453881323059 0;
        49.58667721632031 59.80895099453265 0;
        97.64563396007767 68.86157272420604 1;
        32.57720016809309 95.59854761387875 0;
        74.24869136721598 69.82457122657193 1;
        71.79646205863379 78.45356224515052 1;
        75.3956114656803 85.75993667331619 1;
        35.28611281526193 47.02051394723416 0;
        56.25381749711624 39.26147251058019 0;
        30.05882244669796 49.59297386723685 0;
        44.66826172480893 66.45008614558913 0;
        66.56089447242954 41.09209807936973 0;
        40.45755098375164 97.53518548909936 1;
        49.07256321908844 51.88321182073966 0;
        80.27957401466998 92.11606081344084 1;
        66.74671856944039 60.99139402740988 1;
        32.72283304060323 43.30717306430063 0;
        64.0393204150601 78.03168802018232 1;
        72.34649422579923 96.22759296761404 1;
        60.45788573918959 73.09499809758037 1;
        58.84095621726802 75.85844831279042 1;
        99.82785779692128 72.36925193383885 1;
        47.26426910848174 88.47586499559782 1;
        50.45815980285988 75.80985952982456 1;
        60.45555629271532 42.50840943572217 0;
        82.22666157785568 42.71987853716458 0;
        88.9138964166533 69.80378889835472 1;
        94.83450672430196 45.69430680250754 1;
        67.31925746917527 66.58935317747915 1;
        57.23870631569862 59.51428198012956 1;
        80.36675600171273 90.96014789746954 1;
        68.46852178591112 85.59430710452014 1;
        42.0754545384731 78.84478600148043 0;
        75.47770200533905 90.42453899753964 1;
        78.63542434898018 96.64742716885644 1;
        52.34800398794107 60.76950525602592 0;
        94.09433112516793 77.15910509073893 1;
        90.44855097096364 87.50879176484702 1;
        55.48216114069585 35.57070347228866 0;
        74.49269241843041 84.84513684930135 1;
        89.84580670720979 45.35828361091658 1;
        83.48916274498238 48.38028579728175 1;
        42.2617008099817 87.10385094025457 1;
        99.31500880510394 68.77540947206617 1;
        55.34001756003703 64.9319380069486 1;
        74.77589300092767 89.52981289513276 1;]
# Build the design matrix `x` (intercept column of ones followed by the two
# exam-score features) and the 0/1 label vector `y`.
(m, _) = size(data)
x = Array{Float64}(undef, m, 3)  # `Array(Float64, dims...)` was removed in Julia 1.0
x[:, 1] .= 1                     # scalar fill needs the broadcast dot in Julia 1.0+
x[:, 2:3] = float(data[:, 1:2])
y = round.(Int, data[:, 3])      # element-wise rounding needs the broadcast dot
# Plotting success and fails.
x_success = x[y .== 1, 2:3]
x_fails = x[y .== 0, 2:3]
# Sigmoid (logistic) function: maps any real number into (0, 1).
# Uses `exp(-z)` because the bare global `e` was removed in Julia 1.0
# (Euler's constant is now spelled `ℯ` or `MathConstants.e`).
sigmoid(z) = 1 / (1 + exp(-z))
# Defining the hypotesis, cost and gradient functions J(Theta)
# Logistic-regression hypothesis: h_theta(x) = sigmoid(theta' * x),
# reduced to a plain number via `scalar`.
function hypotesis(theta, x)
    z = theta' * x
    return sigmoid(scalar(z))
end
# Gradient of the logistic-regression cost for each parameter j:
#   dJ/dtheta_j = (1/m) * sum_i (h_i - y_i) * x[i, j]
# The original omitted the 1/m factor, so its scale did not match `cost_j`
# (which divides by m) — the likely reason plain gradient descent misbehaved.
function gradient(theta, x, y)
    (m, n) = size(x)
    # x[i, :] is already a column vector in Julia 1.0+, no adjoint needed.
    h = [hypotesis(theta, x[i, :]) for i in 1:m]
    g = Array{Float64}(undef, n, 1)  # n×1, matching the shape of theta
    for j in 1:n
        g[j] = sum((h[i] - y[i]) * x[i, j] for i in 1:m) / m
    end
    return g
end
# Logistic-regression cost:
#   J(theta) = (1/m) * sum_i [ -y_i*log(h_i) - (1 - y_i)*log(1 - h_i) ]
function cost_j(theta, x, y)
    m = length(y)
    h = [hypotesis(theta, x[i, :]) for i in 1:m]
    # `log` and scalar-vector subtraction must be broadcast (dotted) in Julia 1.0+.
    return sum((-y)' * log.(h) - (1 .- y)' * log.(1 .- h)) / m
end
# Batch gradient descent: repeatedly step `theta` against the cost gradient,
# recording the cost at every iteration so convergence can be inspected.
# NOTE(review): with unnormalized exam-score features a very small `alpha`
# is required for this to converge.
function gradient_descent(x, y, theta, alpha, n_iterations)
    j_history = Array{Float64}(undef, n_iterations)  # `Array(Float64, n)` removed in 1.0
    for i = 1:n_iterations
        j_history[i] = cost_j(theta, x, y)
        theta = theta - alpha * gradient(theta, x, y)
    end
    return theta, j_history
end;
# Optimize with plain gradient descent (slow: features are not normalized).
optimal_theta = gradient_descent(x, y, [0 0 0]', 0.0001, 400)[1]
println("Calculated with Gradient Descent (the wrong one): ", cost_j(optimal_theta, x, y))
# Optimize with the Optim package (Nelder-Mead by default for a plain objective).
cost_j_with_xy(theta) = cost_j(theta, x, y)
result = optimize(cost_j_with_xy, [0.0, 0.0, 0.0])
# `.minimum` was the argmin field only in ancient Optim versions; in current
# Optim the argmin is obtained with `Optim.minimizer` (`.minimum` is the cost value).
optimal_theta = Optim.minimizer(result)
println("Calculated with Optim package (the right one): ", cost_j(optimal_theta, x, y))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment