Skip to content

Instantly share code, notes, and snippets.

@shaunlebron
Last active December 17, 2019 11:20
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save shaunlebron/82770f8af9081988b7af0ab649085ed1 to your computer and use it in GitHub Desktop.
Save shaunlebron/82770f8af9081988b7af0ab649085ed1 to your computer and use it in GitHub Desktop.
alternative visualization for neural networks

Neural Network Alt Viz

Use both scalar and vector nodes to simplify the presentation and to make the computation explicit as linear algebra.

The weighted edges are replaced by two layers of vector nodes:

  1. the first collects the previous layer's neuron scalars into a single vector
  2. the second is a weight rack (a matrix) whose vectors each hold all the incoming weights for one neuron

A single path, showing the computation performed along each of its edges:

scalar         scalar            weight         scalar
neuron         vector            vector         neuron
  * -----------> O ---------------> O ----------> *
      (insert)             (.)             (=)
                              \dot product/

network

// Look up the drawing surface up front. querySelector() returns null when no
// element matches, and getContext() returns null when the context type is
// unsupported; fail with a clear message instead of a later null-dereference.
const canvas = document.querySelector("#canvas")
if (!canvas) {
  throw new Error('No element matches "#canvas"; add <canvas id="canvas"> to the page.')
}
const ctx = canvas.getContext("2d")
if (!ctx) {
  throw new Error("A 2D canvas rendering context is not available.")
}
// Network shape: number of neurons per layer, input layer first.
const layerSizes = [30, 16, 16, 10]
const nlayers = layerSizes.length
// Height (in neurons) of the tallest layer; every layer is centred against it.
const maxSize = layerSizes.reduce((tallest, size) => Math.max(tallest, size), -Infinity)

// Circle radii, in CSS pixels.
const nodeR = 5     // vector nodes
const scalarR = 3   // scalar neuron dots

// Spacing, in CSS pixels.
const pad = 50      // margin around the drawing
const colPad = 180  // horizontal distance between adjacent layers
const rowPad = 16   // vertical distance between neurons within a layer

// Fractional column offsets (multiples of colPad):
//   di - how far left of its neuron each weight-vector node sits
//   ei - puts the collector node midway between column i and column i+1-di
const di = 0.2
const ei = (1 - di) / 2
// Convert a grid position to canvas pixels, returned as [x, y].
// `col` may be fractional (e.g. i - di or i + ei) to place nodes between
// layers; it is rounded to choose whose layer size drives vertical centring.
// `row` indexes a neuron within that layer, centred against the tallest layer.
const coord = (col, row) => {
  const layerSize = layerSizes[Math.round(col)]
  const x = pad + col * colPad
  const centredRow = row + maxSize / 2 - layerSize / 2
  return [x, pad + centredRow * rowPad]
}
// Logical (CSS-pixel) size of the drawing. Two diagrams are stacked: the
// second is translated down by h - pad, so the total height is h + (h - pad).
const w = pad*2 + (nlayers-1)*colPad
const h = pad*2 + (maxSize-1)*rowPad
const H = h*2 - pad
// Back the canvas with 2x the pixels and scale the context by 2, so lines stay
// crisp on high-density screens while the element keeps its logical size.
canvas.width = w*2
canvas.height = H*2
// CSS lengths need a unit: a bare number is an invalid value in standards
// mode and is ignored, which would leave the canvas displayed at 2x size.
canvas.style.width = `${w}px`
canvas.style.height = `${H}px`
ctx.scale(2,2)
// Colours: nodes and neuron dots are solid black; connecting lines are
// translucent black so the dense fans stay readable.
const nodeStroke = "#000"
const nodeFill = "#000"
const scalarFill = "#000"
const edgeStroke = "rgba(0,0,0,0.5)"  // neuron -> weight-vector edges
// Synapse (fan) lines. Declared with `let` because it is reassigned before
// the second diagram is drawn.
let synStroke = "rgba(0,0,0,0.2)"
// Fill a solid circle of the given radius centred at (cx, cy). A truthy
// `fillStyle` becomes the context's fill style first; otherwise the
// context's current fill style is used.
const dot = (cx, cy, radius, fillStyle) => {
  ctx.beginPath()
  ctx.arc(cx, cy, radius, 0, Math.PI * 2)
  if (fillStyle) {
    ctx.fillStyle = fillStyle
  }
  ctx.fill()
}
// Draw a circle outline centred at (cx, cy). When `fill` is truthy the
// circle is filled with it first. The outline is always stroked, using
// `stroke` if given or the context's current stroke style otherwise.
const circ = (cx, cy, radius, stroke, fill) => {
  ctx.beginPath()
  ctx.arc(cx, cy, radius, 0, Math.PI * 2)
  if (fill) {
    ctx.fillStyle = fill
    ctx.fill()
  }
  if (stroke) {
    ctx.strokeStyle = stroke
  }
  ctx.stroke()
}
// Stroke a straight segment from (fromX, fromY) to (toX, toY). A truthy
// `stroke` becomes the context's stroke style first; otherwise the current
// stroke style is used.
const line = (fromX, fromY, toX, toY, stroke) => {
  ctx.beginPath()
  ctx.moveTo(fromX, fromY)
  ctx.lineTo(toX, toY)
  if (stroke) {
    ctx.strokeStyle = stroke
  }
  ctx.stroke()
}
// Diagram A, the conventional picture: every neuron is a dot joined by a
// synapse line to each neuron of the next layer. A neuron's outgoing lines
// are drawn before its dot, so the dots are painted over the lines.
function drawA() {
  for (let layer = 0; layer < nlayers; layer++) {
    const size = layerSizes[layer]
    for (let row = 0; row < size; row++) {
      const [cx, cy] = coord(layer, row)
      fanA(cx, cy, layer + 1)
      dot(cx, cy, scalarR, scalarFill)
    }
  }
}
// Diagram A synapses: draw one line from the neuron at (x0, y0) to every
// neuron of layer i. Does nothing when i is past the last layer, since the
// output layer has no outgoing synapses.
function fanA(x0,y0,i) {
  if (i >= nlayers) return
  // Declare the loop index locally. An undeclared `j` would leak into the
  // global scope in sloppy mode and throw a ReferenceError in strict mode.
  for (let j=0; j<layerSizes[i]; j++) {
    const [x,y] = coord(i,j)
    line(x0,y0,x,y,synStroke)
  }
}
// Diagram B, the vector-node picture. Each layer's neurons are drawn as dots.
// For every layer after the input, each neuron also gets a weight-vector node,
// placed di columns to its left and joined to it by a horizontal edge. Once a
// layer's neurons are drawn, fanB(layer) draws the fan linking it to the next
// layer.
function drawB() {
  for (let layer = 0; layer < nlayers; layer++) {
    for (let row = 0; row < layerSizes[layer]; row++) {
      const [nx, ny] = coord(layer, row)
      dot(nx, ny, scalarR, scalarFill)
      if (layer === 0) continue  // input neurons have no weight vector
      const [wx] = coord(layer - di, row)
      line(nx, ny, wx, ny, edgeStroke)
      circ(wx, ny, nodeR, nodeStroke, nodeFill)
    }
    fanB(layer)
  }
}
// Diagram B fan between layer i and layer i + 1. A single collector node at
// column i + ei has one line from every neuron of layer i (fan-in) and one line
// out to each weight-vector node of layer i + 1 (fan-out), at column
// i + 1 - di. Does nothing for the last layer.
function fanB(i) {
  if (i>=nlayers-1) return
  // Centre the collector on layer i. Rows run 0..n-1, so the middle row is
  // (n - 1) / 2; using n / 2 would put it half a row below centre.
  const [x,y] = coord(i+ei, (layerSizes[i]-1)/2)
  for (let j=0; j<layerSizes[i]; j++) {
    const [x0,y0] = coord(i,j)
    line(x,y,x0,y0,synStroke)
  }
  const i1 = i+1
  for (let j=0; j<layerSizes[i1]; j++) {
    const [x0,y0] = coord(i1-di,j)
    line(x,y,x0,y0,synStroke)
  }
  // Drawn last so the node sits on top of its fan lines.
  circ(x,y,nodeR,nodeStroke,nodeFill)
}
// Entry point. Draw diagram A (direct neuron-to-neuron synapses) at the top.
// Then shift the origin down by h - pad, darken the synapse lines, and draw
// diagram B (the scalar/vector-node version) beneath it.
drawA();
ctx.translate(0,h-pad)
synStroke = "rgba(0,0,0,0.3)"
drawB();
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment