Skip to content

Instantly share code, notes, and snippets.

@feyderm
Last active February 23, 2017 17:51
Show Gist options
  • Save feyderm/b415454761a825285653913a9975c935 to your computer and use it in GitHub Desktop.
Save feyderm/b415454761a825285653913a9975c935 to your computer and use it in GitHub Desktop.
Exploring Gradient Descent Parameters for Logistic Regression

Update of an earlier bl.ock that adds user-supplied parameters and draggable data points.

<!DOCTYPE html>
<meta charset="utf-8">
<style>
/* Parameter form */
form {
  font-size: 11px;
  font-family: sans-serif;
}
input {
  margin-right: 10px;
  margin-bottom: 10px;
}
#button {
  margin-left: 30px;
}
/* Chart text (axis labels) */
text {
  font-family: sans-serif;
  fill: #000000;
}
/* Data points: group "1" is drawn blue, everything else red */
.pts {
  stroke: #595959;
}
.group1 {
  fill: steelblue;
}
.group2 {
  fill: red;
}
/* Decision boundary line drawn during gradient descent */
#dec_boundary {
  fill: none;
  stroke: #000000;
  stroke-width: 2px;
  opacity: 0.6;
}
</style>
<body>
<form action="">
  Number of Iterations: <input type="text" name="iterationNumber" value="400" size="4">
  Learning Rate: <input type="text" name="alpha" value="0.0004" size="7"><br>
  Theta 0: <input type="text" name="theta0" value="-24.0" size="6">
  Theta 1: <input type="text" name="theta1" value="0.5" size="6">
  Theta 2: <input type="text" name="theta2" value="0.2" size="6">
  <input id="button" type="button" value="Submit" onclick="updateParams(this.form)"><br>
</form>
<!--viz-->
<div id="chart"></div>
<script src="https://d3js.org/d3.v4.min.js"></script>
<!-- Loaded over HTTPS: a plain-http script is blocked as mixed content when the page itself is served over HTTPS. -->
<script src="https://feyderm.github.io/math/math.js"></script>
<script type="text/javascript">
// Chart geometry: the SVG is svg_dx by svg_dy pixels and the plot extents
// are what remain after subtracting the margins.
const margin = { top: 20, right: 0, bottom: 50, left: 85 };
const svg_dx = 500;
const svg_dy = 400;
const plot_dx = svg_dx - margin.right - margin.left;
const plot_dy = svg_dy - margin.top - margin.bottom;

// Pixel scales for the two features; their domains are filled in once the
// CSV has loaded. The y range runs from plot_dy up to margin.top so that
// larger values are drawn higher.
const xPos = d3.scaleLinear().range([margin.left, plot_dx]);
const yPos = d3.scaleLinear().range([plot_dy, margin.top]);

// Root SVG element that the axes, points and decision boundary are added to.
const svg = d3
  .select("#chart")
  .append("svg")
  .attr("width", svg_dx)
  .attr("height", svg_dy);
// Load the training data, set the scale domains from it, draw the axes and
// the draggable points, then start an initial gradient-descent run.
//
// With a one-argument callback, d3 v4's d3.csv passes the parsed rows on
// success and null on failure, so bail out on null instead of throwing
// later from d3.extent.
d3.csv("logistic_reg_grad_decent.csv", rows => {
  if (rows == null) {
    console.error("Could not load logistic_reg_grad_decent.csv");
    return;
  }

  // CSV values arrive as strings; coerce before taking the extents.
  xPos.domain(d3.extent(rows, row => +row.x));
  yPos.domain(d3.extent(rows, row => +row.y));

  svg.append("g")
    .attr("id", "axis_x")
    .attr("transform", `translate(0,${plot_dy + margin.bottom / 2})`)
    .call(d3.axisBottom(xPos));

  svg.append("g")
    .attr("id", "axis_y")
    .attr("transform", `translate(${margin.left / 2}, 0)`)
    .call(d3.axisLeft(yPos));

  // Group "1" is drawn as circles with class group1; all others as crosses
  // with class group2.
  const isGroup1 = row => row.group === "1";

  svg.append("g")
    .selectAll("path")
    .data(rows)
    .enter()
    .append("path")
    .attr("class", row => (isGroup1(row) ? "pts group1" : "pts group2"))
    .attr("d", d3.symbol().type(row => (isGroup1(row) ? d3.symbolCircle : d3.symbolCross)))
    .attr("transform", row => `translate(${xPos(row.x)},${yPos(row.y)})`)
    .call(d3.drag()
      .on("start", dragstarted)
      .on("drag", dragged));

  // Initial run; these arguments mirror the form's default input values.
  runGradientDescent(400, 0.0004, -24.0, 0.5, 0.2);
});
/**
 * Drag "start" handler. Raises the grabbed point (d3 `selection.raise`) so
 * it is painted above the other points while it is being dragged.
 * d3.drag invokes it with `this` bound to the dragged element.
 */
function dragstarted() {
  const grabbed = d3.select(this);
  grabbed.raise();
}
/**
 * Drag handler: moves the dragged point so that it follows the pointer.
 *
 * The position comes from d3.mouse relative to the point's parent <g>, which
 * is the coordinate system its `transform` is interpreted in. The previous
 * version used sourceEvent.offsetX/offsetY. Those are measured from the
 * padding edge of the event's target element, and the target changes as the
 * pointer moves over the point, the <svg>, the form or the page body. As a
 * result, points jumped whenever the pointer left the element it started on.
 *
 * Invoked by d3.drag with `this` bound to the dragged element.
 *
 * @param {Object} d - datum bound to the dragged point (not used)
 */
function dragged(d) {
  const [px, py] = d3.mouse(this.parentNode);
  d3.select(this).attr("transform", `translate(${px},${py})`);
}
/**
 * Logistic (sigmoid) function 1 / (1 + e^-z).
 *
 * Uses Math.exp rather than Math.pow(Math.E, -z). Math.exp computes e^x
 * directly, whereas the pow form first rounds e to a double and raises that
 * approximation to the power.
 *
 * For very negative z, Math.exp(-z) overflows to Infinity and the result is 0.
 * For very positive z, it underflows to 0 and the result is 1.
 *
 * @param {number} z - linear predictor theta' * x
 * @returns {number} value in [0, 1]
 */
function sigmoid(z) {
  return 1 / (1 + Math.exp(-z));
}
/**
 * Gradient of the logistic-regression cost with respect to theta.
 *
 * Port of the Octave expression `grad = (1 / m) * (h - y)' * X`.
 *
 * @param {number} m - number of training examples
 * @param {Matrix} y - math.js vector of 0/1 labels (length m)
 * @param {Matrix} h - math.js vector of predictions sigmoid(X * theta) (length m)
 * @param {Matrix} X - math.js design matrix with m rows
 * @returns {Matrix} gradient vector with one entry per column of X
 */
function computeGradient(m, y, h, X) {
  // Residuals h - y, element-wise. The previous version looked up y[i] with
  // the index argument of Matrix.map (a position array for matrices, not a
  // plain number). That made the result depend on how math.subset and
  // implicit coercion treated that array.
  const residuals = math.subtract(h, y);
  return math.multiply(math.multiply(residuals, X), 1 / m);
}
/**
 * Submit handler for the parameter form. Reads the iteration count, the
 * learning rate and the initial theta values, removes the previously drawn
 * decision boundary and starts a new gradient-descent run.
 *
 * If any field does not parse as a finite number, nothing is changed and a
 * warning is logged. Without this check, a NaN iteration count never
 * satisfies the run's stop condition, so its timer would run forever, and a
 * NaN theta or alpha yields NaN line coordinates.
 *
 * @param {HTMLFormElement} form - form holding the iterationNumber, alpha,
 *   theta0, theta1 and theta2 inputs
 */
function updateParams(form) {
  const iterationNumber = +form.iterationNumber.value;
  const alpha = +form.alpha.value;
  const theta0 = +form.theta0.value;
  const theta1 = +form.theta1.value;
  const theta2 = +form.theta2.value;

  const values = [iterationNumber, alpha, theta0, theta1, theta2];
  if (!values.every(Number.isFinite)) {
    console.warn("updateParams: all parameters must be finite numbers", values);
    return;
  }

  // remove previous decision boundary
  d3.select("#dec_boundary").remove();
  runGradientDescent(iterationNumber, alpha, theta0, theta1, theta2);
}
/**
 * Fits a logistic-regression decision boundary to the points currently on
 * the chart using batch gradient descent. The boundary line is redrawn after
 * every step so the fit can be watched.
 *
 * Point positions are read back from each point's `transform` attribute,
 * because dragging updates only the transform, not the bound datum. Group
 * labels come from the bound datum.
 *
 * Starting a new run stops any run that is still animating, so repeated
 * submissions do not leave old timers computing in the background.
 *
 * @param {number} iterationNumber - number of gradient steps to take (a
 *   non-positive value still takes one step)
 * @param {number} alpha - learning rate
 * @param {number} theta0 - initial intercept term
 * @param {number} theta1 - initial weight of the x feature
 * @param {number} theta2 - initial weight of the y feature
 */
function runGradientDescent(iterationNumber, alpha, theta0, theta1, theta2) {
  // Matches every number in a transform such as "translate(123.4,-5e-3)".
  // The old pattern /\d+.?\d+,\d+.?\d+/ required two or more digits per
  // coordinate, dropped minus signs, and let "." match any character.
  const numberRe = /[-+]?(?:\d+\.?\d*|\.\d+)(?:e[-+]?\d+)?/gi;

  const running = svg.property("__gradientTimer__");
  if (running) {
    running.stop();
  }

  const pts = [];
  d3.selectAll(".pts").each(function () {
    const pt = d3.select(this);
    const [px, py] = pt.attr("transform").match(numberRe).map(Number);
    pts.push({
      group: pt.datum().group,
      x: xPos.invert(px),
      y: yPos.invert(py),
    });
  });

  const extentX = d3.extent(pts, (p) => p.x);
  const X = math.matrix(pts.map((p) => [1, p.x, p.y]));
  const y = math.matrix(pts.map((p) => +p.group));
  const m = pts.length;
  let theta = math.matrix([theta0, theta1, theta2]);
  let iteration = 0;

  // y on the boundary theta0 + theta1 * x + theta2 * y = 0 for a given x.
  const boundaryY = (x, t0, t1, t2) => (-1 / t2) * (t1 * x + t0);

  const decBoundary = svg.append("line").attr("id", "dec_boundary");

  const timer = d3.timer(() => {
    const h = math.multiply(X, theta).map((z) => sigmoid(z));
    const grad = computeGradient(m, y, h, X);

    // theta := theta - alpha * grad
    theta = math.subtract(theta, math.multiply(grad, alpha));

    const t0 = math.subset(theta, math.index(0));
    const t1 = math.subset(theta, math.index(1));
    const t2 = math.subset(theta, math.index(2));

    // Each endpoint's y is evaluated at the x where that endpoint is drawn.
    // Previously y2 was evaluated at 0.95 * max x but drawn at max x, which
    // tilted the line away from the actual boundary.
    decBoundary
      .attr("x1", xPos(extentX[0]))
      .attr("y1", yPos(boundaryY(extentX[0], t0, t1, t2)))
      .attr("x2", xPos(extentX[1]))
      .attr("y2", yPos(boundaryY(extentX[1], t0, t1, t2)));

    // Stop after exactly iterationNumber steps (previously N + 2 were taken).
    iteration += 1;
    if (iteration >= iterationNumber) {
      timer.stop();
    }
  }, 200);

  svg.property("__gradientTimer__", timer);
}
</script>
</body>
x y group
34.62365962451697 78.0246928153624 0
30.28671076822607 43.89499752400101 0
35.84740876993872 72.90219802708364 0
60.18259938620976 86.30855209546826 1
79.0327360507101 75.3443764369103 1
45.08327747668339 56.3163717815305 0
61.10666453684766 96.51142588489624 1
75.02474556738889 46.55401354116538 1
76.09878670226257 87.42056971926803 1
84.43281996120035 43.53339331072109 1
95.86155507093572 38.22527805795094 0
75.01365838958247 30.60326323428011 0
82.30705337399482 76.48196330235604 1
69.36458875970939 97.71869196188608 1
39.53833914367223 76.03681085115882 0
53.9710521485623 89.20735013750205 1
69.07014406283025 52.74046973016765 1
67.94685547711617 46.67857410673128 0
70.66150955499435 92.92713789364831 1
76.97878372747498 47.57596364975532 1
67.37202754570876 42.83843832029179 0
89.67677575072079 65.79936592745237 1
50.534788289883 48.85581152764205 0
34.21206097786789 44.20952859866288 0
77.9240914545704 68.9723599933059 1
62.27101367004632 69.95445795447587 1
80.1901807509566 44.82162893218353 1
93.114388797442 38.80067033713209 0
61.83020602312595 50.25610789244621 0
38.78580379679423 64.99568095539578 0
61.379289447425 72.80788731317097 1
85.40451939411645 57.05198397627122 1
52.10797973193984 63.12762376881715 0
52.04540476831827 69.43286012045222 1
40.23689373545111 71.16774802184875 0
54.63510555424817 52.21388588061123 0
33.91550010906887 98.86943574220611 0
64.17698887494485 80.90806058670817 1
74.78925295941542 41.57341522824434 0
34.1836400264419 75.2377203360134 0
83.90239366249155 56.30804621605327 1
51.54772026906181 46.85629026349976 0
94.44336776917852 65.56892160559052 1
82.36875375713919 40.61825515970618 0
51.04775177128865 45.82270145776001 0
62.22267576120188 52.06099194836679 0
77.19303492601364 70.45820000180959 1
97.77159928000232 86.7278223300282 1
62.07306379667647 96.76882412413983 1
91.56497449807442 88.69629254546599 1
79.94481794066932 74.16311935043758 1
99.2725269292572 60.99903099844988 1
90.54671411399852 43.39060180650027 1
34.52451385320009 60.39634245837173 0
50.2864961189907 49.80453881323059 0
49.58667721632031 59.80895099453265 0
97.64563396007767 68.86157272420604 1
32.57720016809309 95.59854761387875 0
74.24869136721598 69.82457122657193 1
71.79646205863379 78.45356224515052 1
75.3956114656803 85.75993667331619 1
35.28611281526193 47.02051394723416 0
56.25381749711624 39.26147251058019 0
30.05882244669796 49.59297386723685 0
44.66826172480893 66.45008614558913 0
66.56089447242954 41.09209807936973 0
40.45755098375164 97.53518548909936 1
49.07256321908844 51.88321182073966 0
80.27957401466998 92.11606081344084 1
66.74671856944039 60.99139402740988 1
32.72283304060323 43.30717306430063 0
64.0393204150601 78.03168802018232 1
72.34649422579923 96.22759296761404 1
60.45788573918959 73.09499809758037 1
58.84095621726802 75.85844831279042 1
99.82785779692128 72.36925193383885 1
47.26426910848174 88.47586499559782 1
50.45815980285988 75.80985952982456 1
60.45555629271532 42.50840943572217 0
82.22666157785568 42.71987853716458 0
88.9138964166533 69.80378889835472 1
94.83450672430196 45.69430680250754 1
67.31925746917527 66.58935317747915 1
57.23870631569862 59.51428198012956 1
80.36675600171273 90.96014789746954 1
68.46852178591112 85.59430710452014 1
42.0754545384731 78.84478600148043 0
75.47770200533905 90.42453899753964 1
78.63542434898018 96.64742716885644 1
52.34800398794107 60.76950525602592 0
94.09433112516793 77.15910509073893 1
90.44855097096364 87.50879176484702 1
55.48216114069585 35.57070347228866 0
74.49269241843041 84.84513684930135 1
89.84580670720979 45.35828361091658 1
83.48916274498238 48.38028579728175 1
42.2617008099817 87.10385094025457 1
99.31500880510394 68.77540947206617 1
55.34001756003703 64.9319380069486 1
74.77589300092767 89.52981289513276 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment