Skip to content

Instantly share code, notes, and snippets.

@nampdn
Created June 26, 2024 18:52
Show Gist options
  • Save nampdn/1ee2201876b4855313a1ca8aa95f5e8f to your computer and use it in GitHub Desktop.
Save nampdn/1ee2201876b4855313a1ca8aa95f5e8f to your computer and use it in GitHub Desktop.
train-neural-network-visualization.html
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Neural Network Training on Sine Function</title>
<script src="https://cdnjs.cloudflare.com/ajax/libs/Chart.js/3.7.0/chart.min.js"></script>
<style>
body { font-family: Arial, sans-serif; margin: 0; padding: 20px; display: flex; flex-direction: column; align-items: center; }
#visualization-container { display: flex; justify-content: space-around; width: 100%; margin-bottom: 20px; flex-wrap: wrap; }
canvas { border: 1px solid #ccc; margin: 10px; }
#controls { margin-top: 20px; }
button, input { margin: 5px; }
</style>
</head>
<body>
<h1>Neural Network Training on Sine Function</h1>
<div id="visualization-container">
<canvas id="networkCanvas" width="800" height="400"></canvas>
<canvas id="functionPlot" width="400" height="300"></canvas>
<canvas id="errorChart" width="400" height="300"></canvas>
</div>
<div id="controls">
<button id="startBtn">Start Training</button>
<button id="resetBtn">Reset</button>
<div>
Learning Rate: <span id="learningRateValue">0.01</span>
<input type="range" id="learningRateSlider" min="0.001" max="0.1" step="0.001" value="0.01">
</div>
</div>
<script>
const networkCanvas = document.getElementById('networkCanvas');
const networkCtx = networkCanvas.getContext('2d');
const startBtn = document.getElementById('startBtn');
const resetBtn = document.getElementById('resetBtn');
const learningRateSlider = document.getElementById('learningRateSlider');
const learningRateValue = document.getElementById('learningRateValue');
// Neural Network parameters
const layers = [1, 10, 10, 1]; // Input, Hidden layers, Output
let weights = []; // weights[l][j][k]: weight from neuron k of layer l to neuron j of layer l + 1
let biases = []; // biases[l][j]: bias of neuron j in layer l + 1
let activations = []; // activations[l][j]: most recently computed output of neuron j in layer l
// Take the initial rate from the slider instead of hard-coding it: browsers may
// restore a previously chosen slider position on reload, and the training rate
// and the displayed label must both agree with what the slider shows.
let learningRate = parseFloat(learningRateSlider.value);
learningRateValue.textContent = learningRate.toFixed(3);
let isTraining = false;
let epoch = 0; // number of single-sample training steps taken so far
const maxEpochs = 10000;
// Initialize weights, biases, and activations
// (Re)builds the network state from `layers`: every weight and bias is drawn
// uniformly from [-0.5, 0.5), and every activation slot starts at 0.
function initializeNetwork() {
  const randomParam = () => Math.random() - 0.5;
  weights = [];
  biases = [];
  activations = [new Array(layers[0]).fill(0)];
  for (let layer = 1; layer < layers.length; layer++) {
    const fanIn = layers[layer - 1];
    const width = layers[layer];
    const layerWeights = [];
    const layerBiases = [];
    for (let neuron = 0; neuron < width; neuron++) {
      // Draw this neuron's incoming weights before its bias.
      layerWeights.push(Array.from({ length: fanIn }, randomParam));
      layerBiases.push(randomParam());
    }
    weights.push(layerWeights);
    biases.push(layerBiases);
    activations.push(new Array(width).fill(0));
  }
}
// Activation function (tanh)
// Hyperbolic tangent, delegated to Math.tanh; outputs lie in [-1, 1].
function tanh(x) {
  const squashed = Math.tanh(x);
  return squashed;
}
// Forward pass
// Propagates the scalar `input` through the network, overwriting `activations`
// layer by layer with tanh(bias + sum of weight * previous activation), and
// returns the activation of the first (only) neuron in the output layer.
function forwardPass(input) {
  activations[0] = [input];
  for (let layer = 1; layer < layers.length; layer++) {
    const previous = activations[layer - 1];
    const current = activations[layer];
    const layerWeights = weights[layer - 1];
    const layerBiases = biases[layer - 1];
    const fanIn = layers[layer - 1];
    for (let neuron = 0; neuron < layers[layer]; neuron++) {
      const incoming = layerWeights[neuron];
      let preActivation = layerBiases[neuron];
      for (let src = 0; src < fanIn; src++) {
        preActivation += incoming[src] * previous[src];
      }
      current[neuron] = tanh(preActivation);
    }
  }
  const outputLayer = activations[activations.length - 1];
  return outputLayer[0];
}
// Backpropagation
// Performs one stochastic-gradient-descent step on a single (input, target)
// sample, minimizing 0.5 * (output - target)^2, and returns the sample's
// squared error (output - target)^2 measured before the update.
//
// The step runs in two phases. Every layer's delta must be computed with the
// weights that produced the forward pass. Updating weights while still
// back-propagating would feed already-modified downstream weights into the
// upstream deltas and yield an incorrect gradient.
function backpropagate(input, target) {
  const output = forwardPass(input);
  const outputError = output - target;

  // Phase 1: deltas[i][j] = dLoss/d(pre-activation) of neuron j in layer i
  // (index 0, the input layer, is unused).
  const deltas = new Array(layers.length);
  for (let i = layers.length - 1; i > 0; i--) {
    deltas[i] = [];
    for (let j = 0; j < layers[i]; j++) {
      let upstream;
      if (i === layers.length - 1) {
        upstream = outputError;
      } else {
        upstream = 0;
        for (let k = 0; k < layers[i + 1]; k++) {
          upstream += deltas[i + 1][k] * weights[i][k][j];
        }
      }
      // tanh'(z) = 1 - tanh(z)^2, and activations[i][j] already holds tanh(z).
      deltas[i].push(upstream * (1 - activations[i][j] ** 2));
    }
  }

  // Phase 2: apply the gradient-descent updates now that all deltas are known.
  for (let i = 1; i < layers.length; i++) {
    for (let j = 0; j < layers[i]; j++) {
      const delta = deltas[i][j];
      for (let k = 0; k < layers[i - 1]; k++) {
        weights[i - 1][j][k] -= learningRate * delta * activations[i - 1][k];
      }
      biases[i - 1][j] -= learningRate * delta;
    }
  }
  return outputError ** 2;
}
// Training step
// Draws x uniformly from [0, 2π), trains on the pair (x, sin x), and returns
// whatever backpropagate reports for that sample.
function trainingStep() {
  const sampleX = Math.random() * 2 * Math.PI;
  const sampleY = Math.sin(sampleX);
  const result = backpropagate(sampleX, sampleY);
  return result;
}
// Draw neural network
// Redraws the whole network on networkCanvas:
// - Neurons are circles labelled with their current activation, plus
//   "b: <bias>" below every non-input neuron.
// - Each connection is a line from neuron j of layer i to neuron k of layer i + 1.
//   It is colored from red (weight <= -1) through purple (0) to blue
//   (weight >= 1) and is 2·|weight| pixels wide.
// - While training, each connection independently has a 10% chance per redraw
//   of getting a jittered yellow "zap" overlay.
function drawNetwork() {
  const neuronRadius = 15;
  const zapSegments = 10;
  networkCtx.clearRect(0, 0, networkCanvas.width, networkCanvas.height);
  const layerWidth = networkCanvas.width / (layers.length + 1);
  for (let i = 0; i < layers.length; i++) {
    const neurons = layers[i];
    const x = (i + 1) * layerWidth;
    for (let j = 0; j < neurons; j++) {
      const y = (j + 1) * (networkCanvas.height / (neurons + 1));
      // Draw neuron. Reset the outline style explicitly. Otherwise the
      // connection drawing below leaves strokeStyle/lineWidth at the last
      // connection's values, and they bleed into every later neuron's outline.
      networkCtx.beginPath();
      networkCtx.arc(x, y, neuronRadius, 0, 2 * Math.PI);
      networkCtx.fillStyle = 'lightblue';
      networkCtx.fill();
      networkCtx.strokeStyle = 'black';
      networkCtx.lineWidth = 1;
      networkCtx.stroke();
      // Draw activation value
      networkCtx.fillStyle = 'black';
      networkCtx.font = '10px Arial';
      networkCtx.textAlign = 'center';
      networkCtx.fillText(activations[i][j].toFixed(2), x, y);
      // Draw bias (except for input layer)
      if (i > 0) {
        networkCtx.fillText(`b: ${biases[i - 1][j].toFixed(2)}`, x, y + 25);
      }
      // Draw connections to next layer
      if (i < layers.length - 1) {
        const nextNeurons = layers[i + 1];
        const nextX = (i + 2) * layerWidth;
        for (let k = 0; k < nextNeurons; k++) {
          const nextY = (k + 1) * (networkCanvas.height / (nextNeurons + 1));
          const weight = weights[i][k][j];
          // Map [-1, 1] onto [0, 1] and clamp. Trained weights can leave [-1, 1],
          // which would otherwise produce rgb() channels outside 0-255.
          const normalizedWeight = Math.min(1, Math.max(0, (weight + 1) / 2));
          networkCtx.beginPath();
          networkCtx.moveTo(x + neuronRadius, y);
          networkCtx.lineTo(nextX - neuronRadius, nextY);
          const red = Math.round(255 * (1 - normalizedWeight));
          const blue = Math.round(255 * normalizedWeight);
          networkCtx.strokeStyle = `rgb(${red}, 0, ${blue})`;
          // The canvas ignores a lineWidth of 0 (keeping the previous width),
          // so near-zero weights get a hairline instead.
          networkCtx.lineWidth = Math.max(Math.abs(weight) * 2, 0.1);
          networkCtx.stroke();
          // Draw "zap" effect
          if (isTraining && Math.random() < 0.1) {
            networkCtx.beginPath();
            networkCtx.moveTo(x + neuronRadius, y);
            // Use an integer counter. Accumulating `t += 0.1` reaches
            // 0.9999999999999999 (< 1), which adds a stray extra jittered point
            // right at the target neuron.
            for (let s = 1; s < zapSegments; s++) {
              const t = s / zapSegments;
              const zapX = x + neuronRadius + (nextX - x - 2 * neuronRadius) * t;
              const zapY = y + (nextY - y) * t + (Math.random() - 0.5) * 10;
              networkCtx.lineTo(zapX, zapY);
            }
            networkCtx.lineTo(nextX - neuronRadius, nextY);
            networkCtx.strokeStyle = 'yellow';
            networkCtx.lineWidth = 2;
            networkCtx.stroke();
          }
        }
      }
    }
  }
}
// Function plot
// Scatter chart comparing the target sin(x) (dataset 0) with the network's
// output (dataset 1) over x in [0, 2π]; updateFunctionPlot() refills both series.
const functionPlot = new Chart(document.getElementById('functionPlot').getContext('2d'), {
  type: 'scatter',
  data: {
    datasets: [{
      label: 'Target (sin)',
      data: [],
      borderColor: 'rgb(75, 192, 192)',
      showLine: true,
      pointRadius: 0
    }, {
      label: 'Network Output',
      data: [],
      borderColor: 'rgb(255, 99, 132)',
      showLine: true,
      pointRadius: 0
    }]
  },
  options: {
    responsive: true,
    scales: {
      x: {
        type: 'linear',
        position: 'bottom',
        min: 0,
        max: 2 * Math.PI
      },
      y: {
        min: -1,
        max: 1
      }
    },
    // The plot is refreshed every 10 training steps (one step per animation
    // frame). That is far more often than Chart.js's default ~1 s transition,
    // so with animation on, the curve would always be mid-tween. Disable it,
    // as the error chart does.
    animation: {
      duration: 0
    }
  }
});
// Update function plot
// Samples 101 evenly spaced points on [0, 2π] and replaces the chart's target
// (sin x) and network-output series with them, then redraws. Calling
// forwardPass here overwrites `activations`, which end up holding the values
// for x = 2π.
function updateFunctionPlot() {
  const sampleCount = 100;
  const xs = Array.from({ length: sampleCount + 1 }, (_, i) => (i / sampleCount) * 2 * Math.PI);
  const [targetSeries, outputSeries] = functionPlot.data.datasets;
  targetSeries.data = xs.map((x) => ({ x, y: Math.sin(x) }));
  outputSeries.data = xs.map((x) => ({ x, y: forwardPass(x) }));
  functionPlot.update();
}
// Error chart
// Line chart of the training loss. Each point is the squared error of the one
// random sample trained on at that epoch (trainingStep -> backpropagate returns
// (output - target)^2 for a single sample). It is not a mean over a batch, so
// the label says so.
const errorChart = new Chart(document.getElementById('errorChart').getContext('2d'), {
  type: 'line',
  data: {
    labels: [],
    datasets: [{
      label: 'Squared Error (single sample)',
      data: [],
      borderColor: 'rgb(75, 192, 192)',
      tension: 0.1
    }]
  },
  options: {
    responsive: true,
    scales: {
      y: {
        beginAtZero: true
      }
    },
    // Points are appended frequently during training; skip transitions so each
    // update is drawn immediately.
    animation: {
      duration: 0
    }
  }
});
// Update error chart
// Appends the point (epoch, error) and keeps only the newest 100 points as a
// sliding window, then redraws the chart.
function updateErrorChart(error) {
  const chartData = errorChart.data;
  const series = chartData.datasets[0].data;
  chartData.labels.push(epoch);
  series.push(error);
  const overflow = chartData.labels.length > 100;
  if (overflow) {
    chartData.labels.shift();
    series.shift();
  }
  errorChart.update();
}
// Training loop
// Id of the animation frame queued by train(), or null when none is pending.
let trainingFrameId = null;
// Runs one training step per animation frame while isTraining is set. Every 10
// steps it refreshes the network diagram, the function plot and the error
// chart. Once maxEpochs steps have been taken it stops and resets the start
// button label.
function train() {
  // A pause→resume within a single frame calls train() directly while the
  // previous loop's frame is still queued; cancel that frame so only one loop
  // ever runs. When invoked by that very frame, cancelling it is a no-op.
  if (trainingFrameId !== null) {
    cancelAnimationFrame(trainingFrameId);
    trainingFrameId = null;
  }
  // Paused: the click handler has already set the label to "Resume Training";
  // do not overwrite it with "Start Training".
  if (!isTraining) {
    return;
  }
  if (epoch >= maxEpochs) {
    isTraining = false;
    startBtn.textContent = 'Start Training';
    return;
  }
  const error = trainingStep();
  if (epoch % 10 === 0) {
    drawNetwork();
    updateFunctionPlot();
    updateErrorChart(error);
  }
  epoch++;
  trainingFrameId = requestAnimationFrame(train);
}
// Reset network
// Re-randomizes the network and rewinds the epoch counter. It then empties the
// error chart (with fresh arrays) and redraws the network diagram and the
// function plot.
function resetNetwork() {
  initializeNetwork();
  epoch = 0;
  const chartData = errorChart.data;
  chartData.labels = [];
  chartData.datasets[0].data = [];
  errorChart.update();
  drawNetwork();
  updateFunctionPlot();
}
// Event listeners
// Start/pause toggle: flips isTraining, relabels the button, and kicks off the
// training loop when switching on.
startBtn.addEventListener('click', () => {
  isTraining = !isTraining;
  if (isTraining) {
    startBtn.textContent = 'Pause Training';
    train();
  } else {
    startBtn.textContent = 'Resume Training';
  }
});
resetBtn.addEventListener('click', resetNetwork);
// Keep the live learning rate and its label in sync with the slider.
learningRateSlider.addEventListener('input', (event) => {
  const rate = parseFloat(event.target.value);
  learningRate = rate;
  learningRateValue.textContent = rate.toFixed(3);
});
// Initialize
// Build a fresh random network, then render its initial state: the network
// diagram first, then the function plot.
for (const setupStep of [initializeNetwork, drawNetwork, updateFunctionPlot]) {
  setupStep();
}
</script>
</body>
</html>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment