Skip to content

Instantly share code, notes, and snippets.

@apetrushin
Last active May 6, 2016 09:27
Show Gist options
  • Save apetrushin/e0ec98ca36cb3a668b081718ba47c1e0 to your computer and use it in GitHub Desktop.
Save apetrushin/e0ec98ca36cb3a668b081718ba47c1e0 to your computer and use it in GitHub Desktop.
using PyPlot
using Optim
# Unwrap a one-element collection and return the single value it holds.
# Throws (a plain string) when `x` does not contain exactly one element.
function scalar(x)
    length(x) == 1 || throw("error, not a scalar but $x")
    return x[1]
end;
# Preparing data.
# `data` is a numeric matrix literal with one sample per row and three columns.
# The code below reads columns 1-2 as the features and column 3 (0 or 1) as
# the label.
data = [34.62365962451697 78.0246928153624 0;
30.28671076822607 43.89499752400101 0;
35.84740876993872 72.90219802708364 0;
60.18259938620976 86.30855209546826 1;
79.0327360507101 75.3443764369103 1;
45.08327747668339 56.3163717815305 0;
61.10666453684766 96.51142588489624 1;
75.02474556738889 46.55401354116538 1;
76.09878670226257 87.42056971926803 1;
84.43281996120035 43.53339331072109 1;
95.86155507093572 38.22527805795094 0;
75.01365838958247 30.60326323428011 0;
82.30705337399482 76.48196330235604 1;
69.36458875970939 97.71869196188608 1;
39.53833914367223 76.03681085115882 0;
53.9710521485623 89.20735013750205 1;
69.07014406283025 52.74046973016765 1;
67.94685547711617 46.67857410673128 0;
70.66150955499435 92.92713789364831 1;
76.97878372747498 47.57596364975532 1;
67.37202754570876 42.83843832029179 0;
89.67677575072079 65.79936592745237 1;
50.534788289883 48.85581152764205 0;
34.21206097786789 44.20952859866288 0;
77.9240914545704 68.9723599933059 1;
62.27101367004632 69.95445795447587 1;
80.1901807509566 44.82162893218353 1;
93.114388797442 38.80067033713209 0;
61.83020602312595 50.25610789244621 0;
38.78580379679423 64.99568095539578 0;
61.379289447425 72.80788731317097 1;
85.40451939411645 57.05198397627122 1;
52.10797973193984 63.12762376881715 0;
52.04540476831827 69.43286012045222 1;
40.23689373545111 71.16774802184875 0;
54.63510555424817 52.21388588061123 0;
33.91550010906887 98.86943574220611 0;
64.17698887494485 80.90806058670817 1;
74.78925295941542 41.57341522824434 0;
34.1836400264419 75.2377203360134 0;
83.90239366249155 56.30804621605327 1;
51.54772026906181 46.85629026349976 0;
94.44336776917852 65.56892160559052 1;
82.36875375713919 40.61825515970618 0;
51.04775177128865 45.82270145776001 0;
62.22267576120188 52.06099194836679 0;
77.19303492601364 70.45820000180959 1;
97.77159928000232 86.7278223300282 1;
62.07306379667647 96.76882412413983 1;
91.56497449807442 88.69629254546599 1;
79.94481794066932 74.16311935043758 1;
99.2725269292572 60.99903099844988 1;
90.54671411399852 43.39060180650027 1;
34.52451385320009 60.39634245837173 0;
50.2864961189907 49.80453881323059 0;
49.58667721632031 59.80895099453265 0;
97.64563396007767 68.86157272420604 1;
32.57720016809309 95.59854761387875 0;
74.24869136721598 69.82457122657193 1;
71.79646205863379 78.45356224515052 1;
75.3956114656803 85.75993667331619 1;
35.28611281526193 47.02051394723416 0;
56.25381749711624 39.26147251058019 0;
30.05882244669796 49.59297386723685 0;
44.66826172480893 66.45008614558913 0;
66.56089447242954 41.09209807936973 0;
40.45755098375164 97.53518548909936 1;
49.07256321908844 51.88321182073966 0;
80.27957401466998 92.11606081344084 1;
66.74671856944039 60.99139402740988 1;
32.72283304060323 43.30717306430063 0;
64.0393204150601 78.03168802018232 1;
72.34649422579923 96.22759296761404 1;
60.45788573918959 73.09499809758037 1;
58.84095621726802 75.85844831279042 1;
99.82785779692128 72.36925193383885 1;
47.26426910848174 88.47586499559782 1;
50.45815980285988 75.80985952982456 1;
60.45555629271532 42.50840943572217 0;
82.22666157785568 42.71987853716458 0;
88.9138964166533 69.80378889835472 1;
94.83450672430196 45.69430680250754 1;
67.31925746917527 66.58935317747915 1;
57.23870631569862 59.51428198012956 1;
80.36675600171273 90.96014789746954 1;
68.46852178591112 85.59430710452014 1;
42.0754545384731 78.84478600148043 0;
75.47770200533905 90.42453899753964 1;
78.63542434898018 96.64742716885644 1;
52.34800398794107 60.76950525602592 0;
94.09433112516793 77.15910509073893 1;
90.44855097096364 87.50879176484702 1;
55.48216114069585 35.57070347228866 0;
74.49269241843041 84.84513684930135 1;
89.84580670720979 45.35828361091658 1;
83.48916274498238 48.38028579728175 1;
42.2617008099817 87.10385094025457 1;
99.31500880510394 68.77540947206617 1;
55.34001756003703 64.9319380069486 1;
74.77589300092767 89.52981289513276 1;]
# Design matrix: a leading column of ones followed by the two feature
# columns of `data`, all as Float64.
m, _ = size(data)
x = hcat(ones(m), float(data[:, 1:2]))
# Labels: the third column of `data`, rounded to integers.
y = Int[round(Int, label) for label in data[:, 3]]
# Feature rows split by label (1 -> success, 0 -> fail). These look intended
# for plotting, but no plot call appears in this script.
x_success = x[y .== 1, 2:3]
x_fails = x[y .== 0, 2:3]
# Sigmoid function: returns 1 / (1 + exp(-z)), a value between 0 and 1.
# `exp(-z)` is used instead of `e ^ -z` so the definition does not depend on
# the name `e` being in scope (Base stopped exporting it in Julia 0.7); the
# computed value is the same.
sigmoid(z) = 1 / (1 + exp(-z))
# Hypothesis h(theta, x): the sigmoid of the product theta' * x.
# The product is unwrapped with `scalar` before it is passed to `sigmoid`,
# so it must contain exactly one element.
function hypotesis(theta, x)
    activation = scalar(theta' * x)
    return sigmoid(activation)
end
# Gradient of the averaged cost `cost_j` with respect to `theta`.
#
# Arguments:
#   theta - parameters, n elements (a vector or an n-by-1 matrix)
#   x     - m-by-n matrix, one sample per row
#   y     - m labels, indexed in step with the rows of `x`
#
# Returns an n-by-1 Float64 matrix whose j-th entry is
#   (1 / m) * sum over i of (hypotesis(theta, x_i) - y[i]) * x[i, j].
#
# The 1 / m factor matches the division by m in `cost_j`. Without it the
# result is m times the gradient of that cost, which silently multiplies the
# step size of any descent loop built on top of it.
function gradient(theta, x, y)
    (m, n) = size(x)
    # Prediction error of every sample. `vec` turns row i into a plain vector
    # regardless of whether `x[i, :]` yields a 1-by-n matrix or a vector.
    residual = Float64[hypotesis(theta, vec(x[i, :])) - y[i] for i in 1:m]
    g = zeros(Float64, n, 1)
    for j in 1:n
        g[j] = sum(residual .* x[:, j]) / m
    end
    return g
end
# Cost J(theta): the average over all samples of
#   -y[i] * log(h_i) - (1 - y[i]) * log(1 - h_i),
# where h_i = hypotesis(theta, row i of x).
#
# Arguments:
#   theta - parameters (a vector or an n-by-1 matrix)
#   x     - matrix with one sample per row
#   y     - labels, one per row of `x`
#
# Returns the cost as a Float64.
#
# A term whose weight is exactly zero is skipped rather than multiplied out:
# when the hypothesis saturates to 0 or 1, `0 * log(0)` evaluates to
# `0 * -Inf = NaN` and would turn the whole cost into NaN.
function cost_j(theta, x, y)
    m = length(y)
    total = 0.0
    for i in 1:m
        # `vec` makes row i a plain vector whether `x[i, :]` is a 1-by-n
        # matrix or already a vector.
        h = hypotesis(theta, vec(x[i, :]))
        if y[i] != 0
            total -= y[i] * log(h)
        end
        if y[i] != 1
            total -= (1 - y[i]) * log(1 - h)
        end
    end
    return total / m
end
# Fixed-step gradient descent on `cost_j`.
#
# Arguments:
#   x, y         - samples and labels, passed through to `cost_j` / `gradient`
#   theta        - starting parameters
#   alpha        - step size; each iteration applies
#                  theta = theta - alpha * gradient(theta, x, y)
#   n_iterations - number of update steps
#
# Returns the tuple (theta, j_history), where j_history[i] is the cost
# measured just before the i-th update.
#
# NOTE(review): the original author reported that this does not reach the
# optimum found by Optim. Convergence depends on `alpha` together with the
# scale of the columns of `x` and of whatever `gradient` returns; confirm
# those before trusting the result.
function gradient_descent(x, y, theta, alpha, n_iterations)
    j_history = zeros(Float64, n_iterations)
    for iteration in 1:n_iterations
        j_history[iteration] = cost_j(theta, x, y)
        theta = theta - alpha * gradient(theta, x, y)
    end
    return theta, j_history
end;
# Gradient descent from an all-zero 3-by-1 start, step size 0.0001, 400 steps;
# only the final theta (first element of the returned tuple) is kept.
optimal_theta = first(gradient_descent(x, y, [0 0 0]', 0.0001, 400))
println("Calculated with Gradient Descent (the wrong one): ", cost_j(optimal_theta, x, y))
# Minimizing the same cost with the Optim package, starting from zeros.
# `cost_j_with_xy` fixes `x` and `y` so the optimizer sees a function of theta only.
# NOTE(review): `.minimum` is used here as the minimizing parameter vector (it
# is passed to `cost_j` as theta); confirm that matches the installed Optim version.
function cost_j_with_xy(theta)
    return cost_j(theta, x, y)
end
optimal_theta = optimize(cost_j_with_xy, [0.0, 0.0, 0.0]).minimum
println("Calculated with Optim package (the right one): ", cost_j(optimal_theta, x, y))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment