Skip to content

Instantly share code, notes, and snippets.

@johnbcoughlin
Last active August 29, 2015 13:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save johnbcoughlin/9733158 to your computer and use it in GitHub Desktop.
Save johnbcoughlin/9733158 to your computer and use it in GitHub Desktop.

Demonstration of Inverse Transform Sampling.

Relies on a numerical approximation to the probit function (the inverse of the standard normal CDF), Peter Acklam's rational approximation, which is linked in the source comments. Values are sampled uniformly from the interval [0, 1) and mapped through the probit function. If X is a Uniform(0, 1) random variable, then Y = probit(X) is a standard normal random variable.

<!doctype html>
<meta charset="utf-8">
<title>Transforming Probability Density Functions</title>
<style>
/* Base typography; axis tick labels inherit this size. */
body {
  font-family: sans-serif;
  font-size: 10px;
}
/* Axes: crisp, unfilled black lines. */
.axis {
  shape-rendering: crispEdges;
}
.axis line, .axis path {
  fill: none;
  stroke: #000;
}
/* The probit curve on the left panel (path with class "line"). */
path.line {
  fill: none;
  stroke: #666;
  stroke-width: 1px;
}
/* The filled density estimate on the right panel (class "pdf"). */
path.pdf {
  fill: #44b;
  stroke: none;
}
/* Guide lines tracing the most recent sample (class "indicator"). */
path.indicator {
  fill: none;
  stroke: #e64;
  stroke-width: 1px;
}
/* Reference normal density curve (class "normal"). */
path.normal {
  fill: none;
  stroke: #bbb;
  stroke-width: 1.5px;
}
</style>
<body>
<!-- Loaded over HTTPS: an http:// script is blocked as mixed content when the page itself is served over HTTPS. -->
<script src="https://d3js.org/d3.v3.min.js"></script>
<script>
// Page layout. `center` is the horizontal gap between the left panel (the
// probit curve) and the right panel (the density estimate). `width` is the
// combined width of both panels and `height` their height, once every margin
// has been removed from the 960x560 page.
var margin = {top: 60, right: 60, bottom: 60, left: 60, center: 60};
var width = 960 - margin.left - margin.right - margin.center;
var height = 560 - margin.top - margin.bottom;
// Root <svg> covering the whole page area, plus a <g> shifted by the top-left
// margins; all later elements are appended to this group.
var svg = d3.select("body")
  .append("svg")
    .attr("width", width + margin.left + margin.right + margin.center)
    .attr("height", height + margin.top + margin.bottom)
  .append("g")
    .attr("transform", "translate(" + margin.left + "," + margin.top + ")");
// Plotted interval of the standard normal variable. The normal density is
// below 0.001 outside [-3.5, 3.5], so this covers all but about 0.05% of its
// probability mass.
var range = [-3.5, 3.5];
// Left panel, horizontal: the uniform input u in [0, 1] spans half the width.
var xScale = d3.scale.linear()
    .domain([0, 1])
    .range([0, width / 2]);
// Vertical scale shared by both panels: values of the normal variable.
// `height` already excludes the top and bottom margins, and the root <g> is
// shifted down by margin.top, so the scale spans [height, 0]. Subtracting
// margin.bottom again here would leave an unused band beneath the plots.
var fScale = d3.scale.linear()
    .domain(range)
    .range([height, 0]);
var xAxis = d3.svg.axis()
    .orient("bottom")
    .scale(xScale);
var fAxis = d3.svg.axis()
    .orient("left")
    .scale(fScale);
// Horizontal axis of the left panel, drawn at the height where the normal
// variable equals 0.
svg.append("g")
    .attr("class", "x axis")
    .attr("transform", "translate(0, " + (fScale(0)) + ")")
    .call(xAxis);
// Vertical axis of the left panel.
svg.append("g")
    .attr("class", "f axis")
    .call(fAxis);
// Right panel: same vertical scale. Its axis is offset by half the width plus
// the central gap.
var yAxis = d3.svg.axis()
    .orient("left")
    .scale(fScale);
// Horizontal density scale for the right panel: density 1 maps to `width`.
var pScale = d3.scale.linear()
    .domain([0, 1])
    .range([0, width]);
svg.append("g")
    .attr("class", "y axis")
    .attr("transform", "translate(" + (width / 2 + margin.center) + ", 0)")
    .call(yAxis);
// Sample the probit curve at xPrecision evenly spaced interior points of
// (0, 1). The endpoints are skipped because the probit function diverges at
// 0 and 1.
var xPrecision = 1000;
var xPoints = [];
for (var i = 0; i < xPrecision; i++) {
  // Declared here instead of relying on a `var x` hoisted from later code.
  var x = (i + 1) / (xPrecision + 1);
  xPoints[i] = [x, probit(x)];
}
// Standard normal probability density: phi(x) = exp(-x^2 / 2) / sqrt(2 * pi).
function normal(x) {
  var coefficient = 1 / Math.sqrt(2 * Math.PI);
  var exponent = -x * x / 2;
  return coefficient * Math.exp(exponent);
}
// Line generator for the reference normal curve on the right panel. It is
// drawn sideways: the density d[1] runs horizontally (pScale) and the value
// d[0] runs vertically (fScale).
var normalGenerator = d3.svg.line()
    .x(function(d) { return pScale(d[1]); })
    .y(function(d) { return fScale(d[0]); });
// Exact normal density sampled on [range[0], range[1]], both endpoints included.
var normalData = [];
var normalPrecision = 1000;
for (var i = 0; i <= normalPrecision; i++) {
  var x = range[0] + i * (range[1] - range[0]) / normalPrecision;
  normalData[i] = [x, normal(x)];
}
// The single <path> element holding the reference curve, created on first use.
var normalPath = null;
// Draw the reference normal curve on the right panel. The first call creates
// the path. Later calls (sample() calls this after every update) move the same
// element to the end of its parent, so the curve is painted on top of the
// density estimate without adding a new element to the DOM each time.
function drawNormal() {
  if (normalPath === null) {
    normalPath = svg.append("path")
        .attr("class", "normal")
        .attr("transform", "translate(" + (width / 2 + margin.center) + ", 0)")
        .attr("d", normalGenerator(normalData));
    return;
  }
  // appendChild on a node that is already in the document moves it rather
  // than copying it.
  var node = normalPath.node();
  node.parentNode.appendChild(node);
}
drawNormal();
// Left panel: trace the probit curve itself through the precomputed
// [u, probit(u)] pairs in xPoints, using xScale horizontally and fScale
// vertically.
var fPath = d3.svg.line()
    .x(function(point) {
      return xScale(point[0]);
    })
    .y(function(point) {
      return fScale(point[1]);
    });
svg
    .append("path")
    .attr("class", "line")
    .attr("d", fPath(xPoints));
// Initialize density estimate data.
// yPoints[i] accumulates kernel weight for bucket i of `range`; dividing it by
// `samples` turns it into a density.
var yPrecision = 500;
var yPoints = [];
var yIndices = [];
var samples = 0;
for (var i = 0; i < yPrecision; i++) {
  yPoints[i] = 0;
  yIndices[i] = i;
}
// Maps a value to its bucket index in [0, yPrecision). With a two-element
// domain the quantile thresholds are evenly spaced, so all buckets have the
// same width. Values outside `range` land in the first or last bucket.
var bucket = d3.scale.quantile()
    .domain(range)
    .range(yIndices);
// Midpoint of every bucket. The buckets never change, so invertExtent runs
// once per bucket here instead of once per bucket on every redraw.
var yCenters = yIndices.map(function(index) {
  var extent = bucket.invertExtent(index);
  return (extent[0] + extent[1]) / 2;
});
// Area generator in unrotated coordinates: x follows the value axis (fScale)
// and y is the negated density, so after the path's rotate(90) the density
// extends rightwards from the right panel's axis. Each bucket is plotted at
// its midpoint rather than its lower edge, which keeps the curve centred on
// the samples that produced it and symmetric over `range`.
var densityEstimate = d3.svg.area()
    .x(function(d, i) { return fScale(yCenters[i]); })
    .y0(function(d) { return pScale(0); })
    .y1(function(d, i) { return -pScale(yPoints[i] / samples); })
    .interpolate("cardinal");
// Add the graph of the inferred distribution.
var pdf = svg.append("path")
    .attr("class", "pdf")
    .attr("transform", "translate(" + (width / 2 + margin.center) + ", 0)rotate(90)");
// Guide lines that trace the most recent sample. sample() feeds
// verticalIndicator the points [[X, 0], [X, Y]] (from the horizontal axis to
// the probit curve) and horizontalIndicator [[X, Y], [0, Y]] (from the curve
// across to the right panel's axis).
var verticalIndicator = d3.svg.line()
    .x(function(d) { return xScale(d[0]); })
    .y(function(d) { return fScale(d[1]); });
var horizontalIndicator = d3.svg.line()
    // Only the first point uses its own x; every later point is pinned to the
    // right panel's axis, so its d[0] is ignored.
    .x(function(d, i) { return i === 0 ? xScale(d[0]) : (width / 2 + margin.center); })
    .y(function(d) { return fScale(d[1]); });
var vLine = svg.append("path")
    .attr("class", "indicator");
var hLine = svg.append("path")
    .attr("class", "indicator");
// Sample from uniform([0, 1)) and map it through probit, estimating the
// density of probit(uniform), which is the standard normal density.
// Each call adds a parabolic kernel centred on the new value's bucket, redraws
// the estimate and guide lines, and schedules the next call.
function sample() {
  // Math.random() may return exactly 0, where probit has no finite value;
  // draw a fresh value in that (vanishingly rare) case.
  var X;
  do {
    X = Math.random();
  } while (X === 0);
  var Y = probit(X);
  samples++;
  var center = bucket(Y);
  // Kernel half-width in buckets; at least 1 so that s / radius is defined.
  var radius = Math.max(1, Math.round(yPrecision / 16));
  // Normalize each kernel to unit area. The weights 1 - (s / radius)^2 for
  // s = -radius..radius sum exactly to (4 radius^2 - 1) / (3 radius);
  // multiplying by the bucket width gives the kernel's area. This stays exact
  // even though the radius is rounded to a whole number of buckets.
  var bucketWidth = (range[1] - range[0]) / yPrecision;
  var kernelArea = bucketWidth * (4 * radius * radius - 1) / (3 * radius);
  for (var s = -radius; s <= radius; s++) {
    var index = center + s;
    // Kernel weight falling outside the plotted range is dropped.
    if (index >= 0 && index < yPrecision) {
      var u = s / radius;
      yPoints[index] += (1 - u * u) / kernelArea;
    }
  }
  // Starting at 1600ms and decaying exponentially towards 200ms.
  var timeout = 200 + Math.exp((1 - samples) / 8) * 1400;
  pdf.transition()
      .duration(timeout)
      .attr("d", densityEstimate(yPoints));
  vLine.attr("d", verticalIndicator([[X, 0], [X, Y]]));
  hLine.attr("d", horizontalIndicator([[X, Y], [0, Y]]));
  drawNormal();
  window.setTimeout(sample, timeout);
}
sample();
// Function to compute probit(z), the inverse of the cumulative distribution
// function of the standard normal distribution. Shamelessly adapted from the
// python implementation here: http://home.online.no/~pjacklam/notes/invnorm/impl/field/
// which is based on the algorithm given here: http://home.online.no/~pjacklam/notes/invnorm/
// (Peter Acklam's rational approximation; those pages may no longer be online.)
//
// @param {number} z - a probability.
// @returns {number} the quantile: -Infinity at z = 0, +Infinity at z = 1,
//   NaN when z is NaN or outside [0, 1].
function probit(z) {
  var a = [-3.969683028665376e+1, 2.209460984245205e+2,
           -2.759285104469687e+2, 1.383577518672690e+2,
           -3.066479806614716e+1, 2.506628277459239],
      b = [-5.447609879822406e+1, 1.615858368580409e+2,
           -1.556989798598866e+2, 6.680131188771972e+1,
           -1.328068155288572e+1],
      c = [-7.784894002430293e-3, -3.223964580411365e-1,
           -2.400758277161838, -2.549732539343734,
           4.374664141464968, 2.938163982698783],
      d = [7.784695709041462e-3, 3.224671290700398e-1,
           2.445134137142996, 3.754408661907416];
  // Break points between the lower tail, central and upper tail regions.
  var zlow = 0.02425,
      zhigh = 1 - zlow;

  // Rational approximation for the lower tail, p in (0, zlow). The upper tail
  // reuses it through the symmetry probit(z) = -probit(1 - z).
  function lowerTail(p) {
    var q = Math.sqrt(-2 * Math.log(p));
    return (((((c[0]*q+c[1])*q+c[2])*q+c[3])*q+c[4])*q+c[5]) /
           ((((d[0]*q+d[1])*q+d[2])*q+d[3])*q+1);
  }

  // Outside [0, 1] (or for NaN) the quantile is undefined. At the endpoints
  // the tail formula would evaluate Infinity / Infinity = NaN, so return the
  // exact limits instead.
  if (!(z >= 0 && z <= 1)) {
    return NaN;
  }
  if (z === 0) {
    return -Infinity;
  }
  if (z === 1) {
    return Infinity;
  }

  if (z < zlow) {
    return lowerTail(z);
  }
  if (z > zhigh) {
    return -lowerTail(1 - z);
  }
  // Central region: rational function in r = (z - 0.5)^2, scaled by z - 0.5.
  var q = z - 0.5;
  var r = q * q;
  return (((((a[0]*r+a[1])*r+a[2])*r+a[3])*r+a[4])*r+a[5])*q /
         (((((b[0]*r+b[1])*r+b[2])*r+b[3])*r+b[4])*r+1);
}
</script>
</body>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment