Last active
November 16, 2018 01:15
-
-
Save duhaime/e0058c62cf734b60a392de18c1e94905 to your computer and use it in GitHub Desktop.
Visualizing L1 Loss
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
license: MIT | |
height: 690 | |
scrolling: no | |
border: yes |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!DOCTYPE html> | |
<html> | |
<head> | |
<meta charset='UTF-8'> | |
<title>L1 Error with Random Linear Fits</title> | |
<style>
  /* fixed-width wrapper; also the positioning context for #loss */
  .container {
    width: 700px;
    position: relative;
    margin: 0 20px;
  }
  /* keep every descendant (svg, MathJax output, text) inside the wrapper */
  .container * {
    max-width: 100%;
  }
  /* the "Run Model" button, centred under the chart */
  button {
    margin: 0 auto;
    display: block;
    padding: 10px;
    border-radius: 3px;
    font-size: 1em;
    background: coral;
    color: #fff;
  }
  /* NOTE(review): no element with this id appears in the visible markup */
  div#error-label {
    position: absolute;
    top: 20px;
    left: 20px;
    font-size: 30px;
    /* generic fallback keeps the text monospaced when Courier is missing */
    font-family: courier, monospace;
    color: green;
  }
  /* "Loss" heading and slider, overlaid on the top-left corner of the chart */
  #loss {
    position: absolute;
    top: 3px;
    left: 25px;
    /* generic fallback keeps the text monospaced when Courier is missing */
    font-family: courier, monospace;
  }
  h2 {
    font-size: 30px;
    color: purple;
    margin: 0;
  }
  /* caption centred above the slider */
  label {
    width: 100%;
    text-align: center;
    display: block;
  }
  /* slider block, placed below the heading inside #loss */
  #accuracy {
    position: absolute;
    top: 50px;
  }
</style>
</head> | |
<body> | |
<div class='container'> | |
<!-- heading plus the slider that drawEstimate() reads via #acc -->
<div id='loss'>
  <h2>Loss</h2>
  <div id='accuracy'>
    <!-- `for` ties the caption to the slider. The caption reads "Accuracy"
         because a higher slider value makes drawEstimate() add a smaller
         offset to the true slope and intercept. -->
    <label for='acc'>Accuracy</label>
    <input id='acc' type='range' min='1' max='100' value='1'>
  </div>
</div>
<svg/> | |
<button>Run Model</button> | |
<p>A <b><a target='_blank' href='https://en.wikipedia.org/wiki/Loss_function'>loss function</a></b> quantifies the error of a machine learning model by measuring the difference between the values the model predicts and the corresponding 'ground truth' (or known data) values. One of the simplest loss functions is known as the <b>L1 loss</b>:</p>
<div>$$L1\ Loss = \sum_{i=1}^n |\ y_i-\hat{y}_i\ |$$</div> | |
<p>This equation says that to compute the L1 loss of a machine learning model, one must examine each input data point, find the difference between that data point's true y value and the model's prediction ŷ, take the absolute value of that difference, then sum those absolute differences. As the model grows more accurate, that resulting sum will shrink.</p> | |
</div> | |
<script type='text/javascript' src='https://d3js.org/d3.v5.min.js'></script> | |
<script type='text/javascript' src='https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/MathJax.js?config=TeX-MML-AM_CHTML'></script> | |
<script> | |
// ---------------------------------------------------------------------------
// Global configuration and shared state for the L1 loss visualization.
// Everything declared here is read by the drawing functions defined below.
// ---------------------------------------------------------------------------

// store the estimate of the line globally; drawEstimate() fills in m and b
var estimate = {m: null, b: null};

// specify data params
var data = [], // fake observations, each {x, y}
    m = 2, // true slope
    b = 10, // true y intercept
    n = 15, // n observations
    noise = 4; // bound passed to rand() for the jitter added to each y

// generate fake data: y = m*x + b plus jitter.
// A callback parameter is used instead of a `var x` loop counter so the
// counter no longer shares a binding with the global `x` scale declared below.
d3.range(n).forEach(function(xValue) {
  data.push({
    x: xValue,
    y: (m * xValue) + b + rand(noise),
  });
});

// set chart sizes (pixels); margins leave room above the plot for the
// row of error terms that drawL1Errors() animates into place
var width = 700,
    height = 400,
    margin = {left: 130, top: 130, right: 120, bottom: 20};

// set duration of transitions (milliseconds)
var duration = 1000, // length of each transition
    delay = 40, // per-observation stagger
    errorDelay = n * delay * 2.5; // wait before the error terms move into the equation row

// [min, max] of the generated data along each axis
var domains = {
  x: d3.extent(data, function(d) { return d.x }),
  y: d3.extent(data, function(d) { return d.y }),
};

// data -> pixel scales; the y range is flipped because SVG y grows downward
var x = d3.scaleLinear()
  .domain(domains.x)
  .range([margin.left, width-margin.right]);

var y = d3.scaleLinear()
  .domain(domains.y)
  .range([height-margin.bottom, margin.top]);

var xAxis = d3.axisBottom().scale(x).ticks(5),
    yAxis = d3.axisLeft().scale(y).ticks(5);

// the single <svg> element in the page hosts the whole chart
var svg = d3.select('svg')
  .attr('width', width)
  .attr('height', height);

// specify destination for error term transitions (y pixel of the equation row)
var errorY = 30;
// draw the observations as circles and (re)draw both axes.
// Runs once on load and again from drawEstimate() on every click, so it has
// to clean up whatever a previous call left behind.
function drawPoints() {
  // drawEstimate() removes circles, lines and text, which strips the tick
  // marks and labels but leaves each axis <g> wrapper and its domain path in
  // place; drop those stale groups so axes do not pile up click after click
  svg.selectAll('g.axis').remove();

  // one circle per observation
  svg.selectAll('circle').data(data).enter()
    .append('circle')
    .attr('cx', function(d) { return x(d.x) })
    .attr('cy', function(d) { return y(d.y) })
    .attr('r', 3)
    .attr('fill', '#4682B4'); // blue

  // x axis along the bottom edge of the plot area
  svg.append('g')
    .attr('class', 'x axis')
    .attr('transform', 'translate(0,' +
      (height - margin.bottom) + ')')
    .call(xAxis);

  // y axis along the left edge of the plot area
  svg.append('g')
    .attr('class', 'y axis')
    .attr('transform', 'translate(' + margin.left + ', 0)')
    .call(yAxis);
}
// choose a candidate line whose slope and intercept are shifted away from
// the true values by an amount taken from the #acc slider, animate that line
// across the plot, then show the l1 errors it produces
function drawEstimate() {
  // clear everything drawn so far and start again from the bare scatter plot
  ['circle', 'line', 'text'].forEach(function(tag) {
    d3.selectAll(tag).remove();
  });
  drawPoints();

  // the slider runs from 1 to 100; the shift applied to m and b shrinks
  // toward zero as the slider value approaches 100
  var sliderValue = document.querySelector('#acc').value,
      shift = ((100 - sliderValue) / 100) / 2;
  estimate.m = m + shift;
  estimate.b = b + shift;

  // pixel coordinates where the estimated line starts
  var first = data[0],
      last = data[n-1],
      startX = x(first.x),
      startY = y(estimate.m * first.x + estimate.b);

  // the line begins collapsed onto its starting point...
  var line = svg.append('line')
    .attr('x1', startX)
    .attr('x2', startX)
    .attr('y1', startY)
    .attr('y2', startY)
    .attr('fill', 'none')
    .attr('stroke', 'red');

  // ...and then stretches out to the last observation's x position
  line.transition()
    .duration(duration)
    .attr('x2', x(last.x))
    .attr('y2', y(estimate.m * last.x + estimate.b));

  // draw the l1 errors for each observation
  drawL1Errors();
}
// for each observation, draw a dashed vertical line between that observation
// and the estimated line and fade in the absolute error next to the point;
// then slide the error values up into a single row joined by '+' signs and
// closed by '=', followed by their sum (the L1 loss)
function drawL1Errors() {
  // absolute vertical distance between an observation and the estimate
  function absError(d) {
    return Math.abs(d.y - getEstimate(d.x));
  }

  // zero-length dashed segments anchored on each observation...
  var errorLines = svg.selectAll('.error-line').data(data).enter()
    .append('line')
    .attr('class', 'error-line')
    .attr('x1', function(d) { return x(d.x) })
    .attr('x2', function(d) { return x(d.x) })
    .attr('y1', function(d) { return y(d.y) })
    .attr('y2', function(d) { return y(d.y) })
    .attr('stroke', 'purple')
    .attr('stroke-dasharray', '4');
  // ...that grow, staggered by index, until they reach the estimated line
  errorLines.transition()
    .duration(duration)
    .delay(function(d, idx) { return idx * delay })
    .attr('y2', function(d) { return y(getEstimate(d.x)) });

  // draw the error terms for each observation. The class is set so the
  // appended elements actually match the selector used for the data join.
  var errorTexts = svg.selectAll('.error-nums').data(data).enter()
    .append('text')
    .attr('class', 'error-nums')
    .attr('x', function(d) { return x(d.x)+4 })
    .attr('y', function(d) { return y(d.y)-5 })
    .attr('font-size', '11px')
    .text(function(d) { return round(absError(d)) })
    .style('opacity', 0);
  // fade each term in alongside its error line
  errorTexts.transition()
    .duration(duration)
    .delay(function(d, idx) { return idx * delay })
    .style('opacity', 1);
  // transition the texts into an equation row at y = errorY
  errorTexts.transition()
    .duration(duration)
    .delay(errorDelay)
    .attr('y', errorY)
    .attr('x', function(d) { return x(d.x) - 7 });

  // add plus signs between the equation values ('=' after the last one);
  // they start above the visible area and drop into the row. As above, the
  // class is set so the elements match the join selector.
  var plusses = svg.selectAll('.plus')
    .data(data).enter()
    .append('text')
    .attr('class', 'plus')
    .attr('x', function(d) { return x(d.x) + 13 })
    .attr('y', errorY - 100)
    .text(function(d, idx) {
      return idx+1 == data.length
        ? '='
        : '+'
    })
    .attr('font-size', '8px')
    .style('opacity', 0);
  plusses.transition()
    .duration(duration)
    .delay(function(d, idx) { return errorDelay + 500 + 50 * idx })
    .style('opacity', 1)
    .attr('y', errorY - 2);

  // compute the final loss sum over every observation actually in `data`
  var loss = data.reduce(function(total, d) {
    return total + absError(d);
  }, 0);

  // the total starts above the visible area and drops in at the right-hand
  // end of the equation row
  var l1Loss = svg.append('text')
    .attr('x', 610)
    .attr('y', -100)
    .attr('font-size', '30px')
    .attr('stroke', 'purple')
    .attr('font-family', 'courier')
    .text(round(loss));
  l1Loss.transition()
    .duration(duration)
    .delay(errorDelay + 800)
    .attr('y', 32);
}
// get a random value between -v and v, computed as the difference of two
// independent draws that are each scaled by v
function rand(v) {
  var first = Math.random() * v,
      second = Math.random() * v;
  return first - second;
}
// evaluate the currently estimated line (global `estimate`) at the given x
function getEstimate(x) {
  var slope = estimate.m,
      intercept = estimate.b;
  return slope * x + intercept;
}
// return `float` as a string with exactly one decimal place,
// e.g. 3.14159 -> '3.1' and 12 -> '12.0'.
// toFixed replaces the old substring approach, which kept only the first
// character whenever the string form had no '.' (integers such as 12, or
// exponent notation such as 1e-7). Note that toFixed rounds to the nearest
// tenth instead of truncating.
function round(float) {
  return Number(float).toFixed(1);
}
// main: run the animation whenever the button is clicked, and show the bare
// scatter plot straight away
d3.select('button')
  .on('click', drawEstimate);
drawPoints();
</script> | |
</body> | |
</html> |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment