Skip to content

Instantly share code, notes, and snippets.

@duhaime
Last active November 16, 2018 01:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save duhaime/e0058c62cf734b60a392de18c1e94905 to your computer and use it in GitHub Desktop.
Save duhaime/e0058c62cf734b60a392de18c1e94905 to your computer and use it in GitHub Desktop.
Visualizing L1 Loss
license: MIT
height: 690
scrolling: no
border: yes
<!DOCTYPE html>
<html>
<head>
<meta charset='UTF-8'>
<title>L1 Error with Random Linear Fits</title>
<style>
/* Fixed-width wrapper; position: relative makes it the positioning context
   for the absolutely positioned #loss overlay below. */
.container {
width: 700px;
position: relative;
margin: 0 20px;
}
/* Keep every descendant from overflowing the wrapper's width. */
.container * {
max-width: 100%;
}
/* Centered block-level button (the "Run Model" button in the markup). */
button {
margin: 0 auto;
display: block;
padding: 10px;
border-radius: 3px;
font-size: 1em;
background: coral;
color: #fff;
}
/* NOTE(review): no element with id "error-label" appears in the markup
   shown in this file; this rule looks unused. Confirm before removing. */
div#error-label {
position: absolute;
top: 20px;
left: 20px;
font-size: 30px;
font-family: courier;
color: green;
}
/* Overlay in the top-left of the container holding the "Loss" heading
   and the slider. */
#loss {
position: absolute;
top: 3px;
left: 25px;
font-family: courier;
}
/* The "Loss" heading. */
h2 {
font-size: 30px;
color: purple;
margin: 0;
}
/* Slider caption: full-width block with centered text. */
label {
width: 100%;
text-align: center;
display: block;
}
/* Slider wrapper, placed 50px below the top of #loss (under the heading). */
#accuracy {
position: absolute;
top: 50px;
}
</style>
</head>
<body>
<div class='container'>
<div id='loss'>
<h2>Loss</h2>
<div id='accuracy'>
<label>Accuracy</label>
<input id='acc' type='range' min='1' max='100' value='1'>
</div>
</div>
<svg/>
<button>Run Model</button>
<p>A <b><a target='_blank' href='https://en.wikipedia.org/wiki/Loss_function'>loss function</a></b> quantifies the error of a machine learning model by measuring the difference between the values the model predicts and the corresponding 'ground truth' (or known data) values. One of the simplest loss functions is known as the <b>L1 loss</b>:</p>
<div>$$L1\ Loss = \sum_{i=1}^n |\ y_i-\hat{y}_i\ |$$</div>
<p>This equation says that to compute the L1 loss of a machine learning model, one must examine each input data point, find the difference between that data point's true y value and the model's prediction ŷ, take the absolute value of that difference, then sum those absolute differences. As the model grows more accurate, that resulting sum will shrink.</p>
</div>
<script type='text/javascript' src='https://d3js.org/d3.v5.min.js'></script>
<script type='text/javascript' src='https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/MathJax.js?config=TeX-MML-AM_CHTML'></script>
<script>
// Parameters of the line currently being evaluated; set by drawEstimate
var estimate = {m: null, b: null};

// Parameters of the synthetic data set
var m = 2; // true slope
var b = 10; // true y intercept
var n = 15; // number of observations
var noise = 4; // bound handed to rand() for the noise added to each y

// Synthetic observations: x runs over 0..n-1, y is the true line plus noise
var data = (function() {
  var points = [];
  for (var i = 0; i < n; i++) {
    points.push({
      x: i,
      y: (m * i) + b + rand(noise),
    });
  }
  return points;
})();

// Chart dimensions in pixels
var width = 700;
var height = 400;
var margin = {left: 130, top: 130, right: 120, bottom: 20};

// Animation timings
var duration = 1000; // length of each transition
var delay = 40; // stagger between consecutive observations
var errorDelay = n * delay * 2.5; // wait before the error terms move into the equation row

// Observed [min, max] of each coordinate, used as the scale domains
var domains = {
  x: d3.extent(data, function(d) { return d.x; }),
  y: d3.extent(data, function(d) { return d.y; }),
};

// Linear scales from data values to pixel positions inside the margins
var x = d3.scaleLinear()
  .domain(domains.x)
  .range([margin.left, width - margin.right]);
// the y range runs bottom-to-top so larger values are drawn higher up
var y = d3.scaleLinear()
  .domain(domains.y)
  .range([height - margin.bottom, margin.top]);

var xAxis = d3.axisBottom(x).ticks(5);
var yAxis = d3.axisLeft(y).ticks(5);

// Size the drawing surface
var svg = d3.select('svg');
svg.attr('width', width).attr('height', height);

// Vertical pixel position of the row the error terms animate into
var errorY = 30;
// Draw the scatter plot: one circle per observation plus the x and y axes.
// Safe to call repeatedly: axis groups left by an earlier call are removed
// first, so re-rendering never stacks duplicate axes in the svg.
function drawPoints() {
  // add a circle for every observation that has no circle yet
  svg.selectAll('circle').data(data).enter()
    .append('circle')
    .attr('cx', function(d) { return x(d.x); })
    .attr('cy', function(d) { return y(d.y); })
    .attr('r', 3)
    .attr('fill', '#4682B4'); // blue

  // callers clear circles/lines/text but not the axis <g> wrappers, so drop
  // any existing axes here before appending fresh ones
  svg.selectAll('.axis').remove();

  // x axis along the bottom edge of the plotting area
  svg.append('g')
    .attr('class', 'x axis')
    .attr('transform', 'translate(0,' + (height - margin.bottom) + ')')
    .call(xAxis);

  // y axis along the left edge of the plotting area
  svg.append('g')
    .attr('class', 'y axis')
    .attr('transform', 'translate(' + margin.left + ', 0)')
    .call(yAxis);
}
// Redraw the chart with a candidate line whose slope and intercept are
// offset from the true values by an amount read from the #acc slider,
// animate that line across the plot, then visualise its L1 errors.
function drawEstimate() {
  // wipe everything drawn so far and start from a fresh scatter plot
  ['circle', 'line', 'text'].forEach(function(tag) {
    d3.selectAll(tag).remove();
  });
  drawPoints();

  // slider value 1..100 maps to an offset of 0.495..0, applied to both
  // the slope and the intercept
  var sliderValue = document.querySelector('#acc').value;
  var offset = ((100 - sliderValue) / 100) / 2;
  estimate.m = m + offset;
  estimate.b = b + offset;

  // pixel coordinates of the estimated line at the first and last observation
  var first = data[0];
  var last = data[n - 1];
  var startX = x(first.x);
  var startY = y(estimate.m * first.x + estimate.b);
  var endX = x(last.x);
  var endY = y(estimate.m * last.x + estimate.b);

  // start as a zero-length line at the left end, then grow to the right end
  var line = svg.append('line');
  line
    .attr('x1', startX)
    .attr('x2', startX)
    .attr('y1', startY)
    .attr('y2', startY)
    .attr('fill', 'none')
    .attr('stroke', 'red');
  line.transition()
    .duration(duration)
    .attr('x2', endX)
    .attr('y2', endY);

  // show the per-observation errors and their sum
  drawL1Errors();
}
// Visualise the L1 loss of the current estimate. For every observation:
// draw a dashed vertical line from the point to the estimated line and show
// the absolute error next to it. The error terms then slide up into a single
// row, '+' / '=' signs appear between them, and the summed loss drops in.
function drawL1Errors() {
  // dashed residual lines, starting with zero length at each observation
  var errorLines = svg.selectAll('.error-line').data(data).enter()
    .append('line')
    .attr('class', 'error-line')
    .attr('x1', function(d) { return x(d.x); })
    .attr('x2', function(d) { return x(d.x); })
    .attr('y1', function(d) { return y(d.y); })
    .attr('y2', function(d) { return y(d.y); })
    .attr('stroke', 'purple')
    .attr('stroke-dasharray', '4');

  // grow each residual line down/up to the estimated line, one after another
  errorLines.transition()
    .duration(duration)
    .delay(function(d, idx) { return idx * delay; })
    .attr('y2', function(d) { return y(getEstimate(d.x)); });

  // absolute error of each observation, initially invisible beside its point;
  // the class matches the selector used for the data join above
  var errorTexts = svg.selectAll('.error-nums').data(data).enter()
    .append('text')
    .attr('class', 'error-nums')
    .attr('x', function(d) { return x(d.x) + 4; })
    .attr('y', function(d) { return y(d.y) - 5; })
    .attr('font-size', '11px')
    .text(function(d) { return round(Math.abs(d.y - getEstimate(d.x))); })
    .style('opacity', 0);

  // The fade-in and the move are scheduled on the same elements. They carry
  // different transition names so the move, which starts at errorDelay, does
  // not interrupt a fade that is still running on the last few labels.
  errorTexts.transition('fade-in')
    .duration(duration)
    .delay(function(d, idx) { return idx * delay; })
    .style('opacity', 1);

  // slide the error terms up into the equation row
  errorTexts.transition('to-equation')
    .duration(duration)
    .delay(errorDelay)
    .attr('y', errorY)
    .attr('x', function(d) { return x(d.x) - 7; });

  // '+' between consecutive terms and '=' after the last one; they start
  // above the visible area and drop into the equation row
  var plusses = svg.selectAll('.plus')
    .data(data).enter()
    .append('text')
    .attr('class', 'plus')
    .attr('x', function(d) { return x(d.x) + 13; })
    .attr('y', errorY - 100)
    .text(function(d, idx) {
      return idx + 1 === data.length ? '=' : '+';
    })
    .attr('font-size', '8px')
    .style('opacity', 0);

  plusses.transition()
    .duration(duration)
    .delay(function(d, idx) { return errorDelay + 500 + 50 * idx; })
    .style('opacity', 1)
    .attr('y', errorY - 2);

  // L1 loss: sum of the absolute errors over all observations
  var loss = data.reduce(function(sum, d) {
    return sum + Math.abs(d.y - getEstimate(d.x));
  }, 0);

  // the total starts above the visible area and drops in after the signs
  var l1Loss = svg.append('text')
    .attr('x', 610)
    .attr('y', -100)
    .attr('font-size', '30px')
    .attr('stroke', 'purple')
    .attr('font-family', 'courier')
    .text(round(loss));

  l1Loss.transition()
    .duration(duration)
    .delay(errorDelay + 800)
    .attr('y', 32);
}
// Sample a noise value strictly between -v and v, computed as the
// difference of two independent uniform draws scaled by v.
function rand(v) {
  var first = Math.random() * v;
  var second = Math.random() * v;
  return first - second;
}
// Evaluate the currently stored line (global `estimate`) at the given x
// and return the predicted y value.
function getEstimate(x) {
  var slope = estimate.m;
  var intercept = estimate.b;
  return slope * x + intercept;
}
// Return `float` as a string with exactly one decimal place, rounded to the
// nearest tenth. Unlike slicing the default string form, this also works for
// values printed without a decimal point (12 -> '12.0') and for values
// printed in exponent notation (1e-8 -> '0.0').
function round(float) {
  return Number(float).toFixed(1);
}
// main: render the initial scatter plot and re-run the model animation
// each time the button is clicked
drawPoints();
d3.select('button').on('click', drawEstimate);
</script>
</body>
</html>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment