@micahstubbs
Last active April 21, 2017 20:57
latent value learning | scaled up
license: Apache-2.0
height: 500
border: no

This iteration draws a 384px by 200px SVG and scales it up to 960px wide by 500px tall with the SVG viewBox attribute.
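Here is a small JavaScript sketch of the scaling arithmetic. It computes the uniform scale factor an SVG viewer applies under the default `preserveAspectRatio="xMidYMid meet"`, which applies here because the block leaves its own `preserveAspectRatio` line commented out. The helper names `viewBoxScale` and `toViewport` are hypothetical. `toViewport` also assumes a viewBox origin of (0, 0) and matching aspect ratios, as in this block.

```javascript
// Uniform scale factor SVG uses to fit a viewBox into its viewport under the
// default preserveAspectRatio="xMidYMid meet": the smaller of the horizontal
// and vertical ratios, so the whole viewBox stays visible.
function viewBoxScale(viewportWidth, viewportHeight, viewBoxWidth, viewBoxHeight) {
  return Math.min(viewportWidth / viewBoxWidth, viewportHeight / viewBoxHeight);
}

// Map a point from viewBox (user) units to viewport pixels. Assumes the
// viewBox starts at (0, 0) and the aspect ratios match (no letterbox offset).
function toViewport(point, scale) {
  return { x: point.x * scale, y: point.y * scale };
}

const scale = viewBoxScale(960, 500, 384, 200);
console.log(scale); // 2.5
console.log(toViewport({ x: 384, y: 200 }, scale)); // { x: 960, y: 500 }
```

Because 960 / 384 and 500 / 200 both equal 2.5, nothing is letterboxed. A circle drawn with `r` = 3 in viewBox units therefore appears about 7.5px in radius on the page.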

An iteration on the block latent value learning | es2015 from @micahstubbs, originally inspired by the block Latent Value Learning from bricof.


This is a cleaned and simplified version of a simulation / animation of latent variable learning used in the Recommendation Systems section of the Stitch Fix Algorithms Tour.

Each circle is assumed to have some latent value along the horizontal axis: a true value for an attribute that we cannot observe directly. We can, however, try to estimate it from the feedback on attempted pair-matches, each involving one A element and one B element.

The algorithm used to estimate these values is as follows:

  • assign each entity a current estimated latent value, initialized at the center of the scale
  • select A-B pairs randomly: pick an A element uniformly, then pick a B element with probability weighted by the distance between their current estimated latent values (shorter distances produce higher probabilities of selection)
  • if the feedback from the pair attempt says their relative latent values differ from what our estimates suggest, move both current estimates in the direction of the feedback by a random step scaled by a learning rate (e.g. if A says B is smaller than A, but B's estimate sits at or to the right of A's, move A to the right and B to the left)
  • repeat
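The weighted pair selection in the second step is a form of roulette-wheel sampling. Below is a minimal standalone sketch. It assumes estimates live on the [0, 1] scale and uses `1 - distance` as a hypothetical closeness weight; any non-negative weight that shrinks with distance would fit the description. `sampleIndex` and `pickPartner` are illustrative names, not part of the block.

```javascript
// Roulette-wheel sampling: return an index with probability proportional to
// its non-negative weight. `rand` must return a number in [0, 1).
// Assumes `weights` is non-empty.
function sampleIndex(weights, rand = Math.random) {
  const total = weights.reduce((sum, w) => sum + w, 0);
  if (total <= 0) {
    // no usable weights: fall back to a uniform choice
    return Math.floor(rand() * weights.length);
  }
  let threshold = rand() * total;
  for (let i = 0; i < weights.length; i += 1) {
    threshold -= weights[i];
    if (threshold < 0) return i;
  }
  return weights.length - 1; // guard against floating-point round-off
}

// Pick a B partner for an A element: B estimates closer to A's estimate get
// larger weights (hypothetical weight: 1 - |distance| on the [0, 1] scale).
function pickPartner(aPosition, bPositions, rand = Math.random) {
  const weights = bPositions.map(p => 1 - Math.abs(p - aPosition));
  return sampleIndex(weights, rand);
}

console.log(pickPartner(0.2, [0.25, 0.9, 0.2]));
```

Passing a deterministic `rand` makes the choice reproducible. For example, `pickPartner(0.2, [0.25, 0.9, 0.2], () => 0.99)` returns 2. The weights are about 0.95, 0.3 and 1, so the running total first exceeds 0.99 × 2.25 at the third entry.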

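The underlying simulation, then, runs this algorithm over a set of entities while also simulating the entities themselves. Each entity has its own true latent value, and the feedback it provides when paired with another entity is based on the actual difference between their latent values, with some noise added for good measure. (In animated-learning.js, the feedback is the exact sign of that difference; the randomness comes from the random size of each update step.)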
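The simulation has two ingredients: feedback derived from the entities' true latent values, and the update that fires only when that feedback contradicts the current estimates. Both can be written as small pure functions. The sketch below assumes a hypothetical uniform noise model for the feedback (the `noise` parameter, 0 by default) and takes an injectable `rand` for reproducibility. `simulateFeedback`, `updateEstimates` and `clamp01` are illustrative names.

```javascript
// Hypothetical noisy feedback: 1 if A's latent value (plus noise) exceeds B's,
// -1 otherwise. Noise is uniform in [-noise, noise]; this noise model is an
// assumption for illustration.
function simulateFeedback(aLatent, bLatent, noise = 0, rand = Math.random) {
  const jitter = (2 * rand() - 1) * noise;
  return aLatent - bLatent + jitter > 0 ? 1 : -1;
}

// keep estimates on the [0, 1] scale
const clamp01 = v => Math.min(1, Math.max(0, v));

// Move both estimates only when the feedback contradicts their current order.
// Each step is a random fraction of learningRate. Returns the next positions.
function updateEstimates(aPos, bPos, feedback, learningRate, rand = Math.random) {
  let nextA = aPos;
  let nextB = bPos;
  if (feedback === 1 && aPos <= bPos) {
    // A is bigger but sits at or left of B: move A right, B left
    nextA = aPos + rand() * learningRate;
    nextB = bPos - rand() * learningRate;
  } else if (feedback === -1 && aPos >= bPos) {
    // A is smaller but sits at or right of B: move A left, B right
    nextA = aPos - rand() * learningRate;
    nextB = bPos + rand() * learningRate;
  }
  return { a: clamp01(nextA), b: clamp01(nextB) };
}

const fb = simulateFeedback(0.8, 0.3, 0.1);
console.log(fb, updateEstimates(0.5, 0.5, fb, 0.2));
```

Every estimate starts at 0.5, the center of the scale, so the first pair always moves. For example, `updateEstimates(0.5, 0.5, 1, 0.2, () => 0.5)` gives A about 0.6 and B about 0.4.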

The SVG update is straightforward. At each timestep, the selected pairs are shown as lines between their circles, and the circles transition to new locations based on their current estimated latent values.
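Each tick of the `d3.interval` in the code lasts 1200 ms. The transitions wait `delay` = 400 ms, then move for `move` = 500 ms, which leaves about 300 ms of rest before the next tick. The sketch below reproduces a circle's `cx` over one tick, assuming d3's default transition easing, `d3.easeCubic` (cubic in-out). `cxAt` and `easeCubicInOut` are hypothetical helpers written out in plain JavaScript, not d3 calls.

```javascript
// Cubic in-out easing: the same curve as d3.easeCubic, d3's default
// transition easing.
function easeCubicInOut(t) {
  return t < 0.5 ? 4 * t * t * t : 1 - Math.pow(-2 * t + 2, 3) / 2;
}

// Horizontal position of a circle `elapsed` ms after a tick starts, moving
// from fromX to toX after `delay` ms over `duration` ms (defaults from the block).
function cxAt(elapsed, fromX, toX, delay = 400, duration = 500) {
  const t = Math.min(1, Math.max(0, (elapsed - delay) / duration));
  return fromX + (toX - fromX) * easeCubicInOut(t);
}

// sample one 1200 ms tick for a circle moving from x = 0 to x = 100
for (const ms of [0, 400, 650, 900, 1200]) {
  console.log(ms, cxAt(ms, 0, 100));
}
```

The circle holds still for the first 400 ms and is halfway at 650 ms. It arrives at 900 ms and then rests until the next tick at 1200 ms.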

animated-learning.js

/* global d3 */
function animatedLearning() {
  const width = 384;
  const height = 200; // change the height and the rest of the layout responds ✨
  // the outer svg is 960px by 500px on the page; its viewBox maps the
  // width by height drawing coordinates onto it, scaling everything up
  const parentSVG = d3.select('body').append('svg')
    .attr('width', 960)
    .attr('height', 500)
    .attr('viewBox', `0 0 ${width} ${height}`);
  // .attr('preserveAspectRatio', 'xMinYMin meet');
  const margin = {
    top: height * 0.05,
    bottom: height * 0.05,
    left: 0,
    right: 0
  };
  // calculate some useful values for plot layout
  const innerHeight = height - margin.top - margin.bottom;
  const pointYDistance = innerHeight * 0.3;
  const pointYOffset = (innerHeight - pointYDistance) / 2;
  // map estimated latent values on [0, 1] to horizontal positions
  const x = d3.scaleLinear()
    .domain([0, 1])
    .range([0, width]);
  const svg = parentSVG.append('svg')
    .attr('width', width)
    .attr('height', height);
  const g = svg.append('g')
    .attr('transform', `translate(${margin.left},${margin.top})`);
  // draw the center line
  g.append('line')
    .attr('x1', x.range()[0])
    .attr('x2', x.range()[1])
    .attr('y1', innerHeight / 2)
    .attr('y2', innerHeight / 2)
    .style('stroke', 'black')
    .style('stroke-width', '0.75')
    .style('fill', 'none');
  // draw A elements label
  g.append('text')
    .attr('x', width / 2)
    .attr('y', pointYOffset * 0.4)
    .attr('dy', '0.35em')
    .classed('A-color', true)
    .attr('text-anchor', 'middle')
    .attr('font-size', '12px')
    .text('A elements');
  // draw B elements label
  g.append('text')
    .attr('x', width / 2)
    .attr('y', pointYOffset + pointYDistance + (pointYOffset * 0.6))
    .attr('dy', '0.35em')
    .classed('B-color', true)
    .attr('text-anchor', 'middle')
    .attr('font-size', '12px')
    .text('B elements');
  const learningRate = 0.2;
  const nAs = 10;
  const nBs = 10;
  const nPairs = 8;
  //
  // create elements with latent values and initial positions
  //
  const As = [];
  for (let i = 0; i < nAs; i += 1) {
    As.push({
      id: i,
      latentValue: Math.random(),
      currentPosition: 0.5,
      nextPosition: 0.5
    });
  }
  const Bs = [];
  for (let i = 0; i < nBs; i += 1) {
    Bs.push({
      id: i,
      latentValue: Math.random(),
      currentPosition: 0.5,
      nextPosition: 0.5
    });
  }
  //
  // construct circles
  //
  g.selectAll('.A').data(As, d => d.id)
    .enter().append('circle')
    .attr('class', 'A A-color')
    .attr('cx', width / 2)
    .attr('cy', pointYOffset)
    .attr('r', 3);
  g.selectAll('.B').data(Bs, d => d.id)
    .enter().append('circle')
    .attr('class', 'B B-color')
    .attr('cx', width / 2)
    .attr('cy', pointYOffset + pointYDistance)
    .attr('r', 3);
  //
  // simulation / animation loop
  //
  d3.interval(() => {
    // *** SIMULATION ***
    const pairs = [];
    for (let i = 0; i < nPairs; i += 1) {
      const aId = Math.floor(Math.random() * nAs);
      // pair selection is a stochastic function of the distance between current
      // positions: closer B elements get larger weights (positions lie in [0, 1])
      const weights = Bs.map(d => 1 - Math.abs(d.currentPosition - As[aId].currentPosition));
      // running sums of the non-negative weights, already ascending in index order
      const cumWeights = [];
      weights.reduce((total, w, j) => {
        cumWeights[j] = { v: total + w, id: j };
        return total + w;
      }, 0);
      // fall back to a uniformly random B if every weight is zero
      let bId = Math.floor(Math.random() * nBs);
      const totalWeight = cumWeights[cumWeights.length - 1].v;
      if (totalWeight > 0) {
        const selRandom = Math.random() * totalWeight;
        const sel = cumWeights.find(d => d.v >= selRandom);
        if (sel != null) {
          bId = sel.id;
        }
      }
      pairs.push({ aId, bId });
      // feedback: 1 if A's latent value is bigger than B's, -1 if it is smaller
      const feedback = As[aId].latentValue > Bs[bId].latentValue ? 1 : -1;
      // use the feedback only if it contradicts the current estimates
      if ((feedback === 1) && (As[aId].currentPosition <= Bs[bId].currentPosition)) {
        // A should be right of B: move A right and B left
        As[aId].nextPosition = As[aId].currentPosition + (Math.random() * learningRate);
        Bs[bId].nextPosition = Bs[bId].currentPosition - (Math.random() * learningRate);
      }
      if ((feedback === -1) && (As[aId].currentPosition >= Bs[bId].currentPosition)) {
        // A should be left of B: move A left and B right
        As[aId].nextPosition = As[aId].currentPosition - (Math.random() * learningRate);
        Bs[bId].nextPosition = Bs[bId].currentPosition + (Math.random() * learningRate);
      }
      // keep estimates on the [0, 1] scale
      As[aId].nextPosition = Math.min(1, Math.max(0, As[aId].nextPosition));
      Bs[bId].nextPosition = Math.min(1, Math.max(0, Bs[bId].nextPosition));
    }
    // *** SVG ANIMATION ***
    const delay = 400;
    const move = 500;
    //
    // draw and animate pair lines
    //
    g.selectAll('.pair').remove();
    g.selectAll('.pair')
      .data(pairs)
      .enter().append('line')
      .attr('class', 'pair')
      .attr('y1', pointYOffset)
      .attr('y2', pointYOffset + pointYDistance)
      .style('stroke', '#000')
      .style('stroke-width', 0.25)
      .style('fill', 'none')
      .attr('x1', d => x(As[d.aId].currentPosition))
      .attr('x2', d => x(Bs[d.bId].currentPosition))
      .transition()
      .delay(delay)
      .duration(move)
      .attr('x1', d => x(As[d.aId].nextPosition))
      .attr('x2', d => x(Bs[d.bId].nextPosition));
    //
    // animate circles
    //
    g.selectAll('.A')
      .transition()
      .delay(delay)
      .duration(move)
      .attr('cx', d => x(d.nextPosition));
    g.selectAll('.B')
      .transition()
      .delay(delay)
      .duration(move)
      .attr('cx', d => x(d.nextPosition));
    // *** end of svg animation code ***
    //
    // prep next timestep
    //
    As.forEach((d) => {
      d.currentPosition = d.nextPosition;
    });
    Bs.forEach((d) => {
      d.currentPosition = d.nextPosition;
    });
  }, 1200);
}

<!DOCTYPE html>
<meta charset="utf-8">
<style>
  .A-color {
    fill: #4B90A6;
  }
  .B-color {
    fill: #F3A54A;
  }
</style>
<body>
  <script src="https://d3js.org/d3.v4.min.js"></script>
  <script src="animated-learning.js"></script>
  <script>
    animatedLearning();
  </script>
</body>