Skip to content

Instantly share code, notes, and snippets.

@mattdesl
Created March 16, 2023 11:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mattdesl/82d462d24e4aec503ae258fdac0d471b to your computer and use it in GitHub Desktop.
Save mattdesl/82d462d24e4aec503ae258fdac0d471b to your computer and use it in GitHub Desktop.

Training a neural net to learn Mixbox's latent space (encoding RGB to a latent vector). The remaining steps — mixing and decoding — can be done without a neural net.

License

Note: this source code has no license associated with it and should be used for personal use only. Mixbox has a commercial license, and it isn't entirely clear whether this type of usage would infringe on it.

const fs = require("fs");
const brain = require("brain.js");
const mixbox = require("mixbox");
const Color = require("canvas-sketch-util/color");
const Random = require("canvas-sketch-util/random");
const { clamp01, clamp } = require("canvas-sketch-util/math");
const { createCanvas } = require("canvas");
// Colour pairs to compare, one rendered band per pair. Each entry is
// [rgbA, rgbB] with 8-bit channels (0..255).
// prettier-ignore
const pairs = [
  [[  0,  33, 133], [252, 211,   0]],
  [[  0,  33, 133], [255, 105,   0]],
  [[ 25,   0,  89], [252, 211,   0]],
  [[123,  72,   0], [107, 148,   4]],
  [[255, 255, 255], [  0,   0,   0]],
  [[255, 255, 255], [  0, 255,   0]],
  [[255, 255, 255], [255,   0,   0]],
  [[  0, 255, 255], [  0,   0, 255]],
];

// Train the encoder and write the comparison image.
run();
/**
 * Train a small brain.js network to approximate mixbox's RGB -> latent
 * encoder, then render a comparison image ("test.png"). For every entry of
 * `pairs` it draws two gradient rows: one using mixbox.floatRgbToLatent and
 * one using the trained network. Both are mixed and decoded by
 * lerpLatent/latentToRgb.
 */
function run() {
  // brain.js options. The inline notes follow the brain.js README, but these
  // values are tuned for this experiment and are NOT the library defaults
  // (the README lists iterations: 20000, errorThresh: 0.005, log: false,
  // logPeriod: 10 and activation: 'sigmoid' as defaults).
  const config = {
    iterations: 4000, // the maximum times to iterate the training data --> number greater than 0
    errorThresh: 0.00005, // the acceptable error percentage from training data --> number between 0 and 1
    log: true, // true to use console.log, when a function is supplied it is used --> Either true or a function
    logPeriod: 100, // iterations between logging out --> number greater than 0
    learningRate: 0.3, // scales with delta to effect training rate --> number between 0 and 1
    momentum: 0.1,
    activation: "leaky-relu", // supported activation types: ['sigmoid', 'relu', 'leaky-relu', 'tanh']
    leakyReluAlpha: 0.01, // supported for activation type 'leaky-relu'
  };

  // Grid spacing (in 0..255 units) of the training colours; toLatent() snaps
  // its input to multiples of the same value.
  const colorSteps = 12;

  // Feed-forward net with a single hidden layer of 7 units; brain.js infers
  // the input/output sizes from the training data.
  const net0 = new brain.NeuralNetwork({ ...config, hiddenLayers: [7] });

  const colors = buildColorGrid(colorSteps);
  console.log("Color count:", colors.length);

  // Each sample maps RGB normalized to 0..1 onto mixbox's latent vector.
  const trainingData = colors.map((rgb) => {
    const input = rgb.map((x) => x / 0xff);
    const output = mixbox.floatRgbToLatent(input);
    return { input, output };
  });
  const trainingStats = net0.train(trainingData);
  console.log(trainingStats);

  const bands = pairs.length;
  const bandHeight = 32;
  const width = 256;
  const pixelRatio = 2;
  // Two rows per pair: the mixbox reference, then the neural-net version.
  const canvas = createCanvas(
    width * pixelRatio,
    pixelRatio * bandHeight * bands * 2
  );
  const context = canvas.getContext("2d");
  context.scale(pixelRatio, pixelRatio);
  // The label style is identical for every row, so set it once.
  context.textAlign = "left";
  context.textBaseline = "middle";
  context.font = "12px monospace";

  for (const [rgb1, rgb2] of pairs) {
    // Reference row: mixbox's own RGB -> latent encoder.
    drawRow(
      "MIXBOX",
      mixbox.floatRgbToLatent(rgb1.map((x) => x / 0xff)),
      mixbox.floatRgbToLatent(rgb2.map((x) => x / 0xff))
    );
    // Comparison row: the trained network as the encoder.
    drawRow("NEURAL NET", toLatent(rgb1), toLatent(rgb2));
  }

  fs.writeFileSync("test.png", canvas.toBuffer());

  // Draw one labelled gradient row, then move the origin down one row so the
  // next call draws beneath it.
  function drawRow(label, latentA, latentB) {
    drawGradient(context, width, bandHeight, latentA, latentB);
    context.fillStyle = "white";
    context.fillText(label, 12, bandHeight / 2);
    context.translate(0, bandHeight);
  }

  // Encode an 8-bit RGB triplet with the trained net. The input is first
  // snapped to multiples of colorSteps, i.e. onto the training-colour grid.
  function toLatent(rgb) {
    const input = roundRGB(rgb, colorSteps).map((x) => x / 0xff);
    return net0.run(input);
  }

  // Fill `width` one-pixel columns with the latent-space blend from latent0
  // (left edge) to latent1 (right edge), decoded back to RGB.
  function drawGradient(context, width, height, latent0, latent1) {
    for (let i = 0; i < width; i++) {
      // Guard width === 1, where i / (width - 1) would be 0 / 0 = NaN.
      const t = width > 1 ? i / (width - 1) : 0;
      const mixed = lerpLatent(latent0, latent1, t);
      const mixedRGB = latentToRgb(mixed).map((x) => clamp(x, 0, 0xff));
      context.fillStyle = Color.parse(mixedRGB).hex;
      context.fillRect(i, 0, 1, height);
    }
  }
}

/**
 * Build every [r, g, b] triplet whose channels are the values 0, step,
 * 2*step, ... up to 255, ordered with r varying slowest and b fastest.
 * @param {number} step - grid spacing; must be a positive number.
 * @returns {number[][]} the grid colours.
 * @throws {RangeError} if step is not positive (the loop would never end).
 */
function buildColorGrid(step) {
  if (!(step > 0)) {
    throw new RangeError(`step must be a positive number, got ${step}`);
  }
  const levels = [];
  for (let v = 0; v <= 0xff; v += step) {
    levels.push(v);
  }
  const grid = [];
  for (const r of levels) {
    for (const g of levels) {
      for (const b of levels) {
        grid.push([r, g, b]);
      }
    }
  }
  return grid;
}
/**
 * Snap each channel of an RGB triplet to the nearest multiple of `steps`
 * (exact halves round up, as Math.round does).
 * @param {number[]} rgb - channel values.
 * @param {number} steps - grid spacing.
 * @returns {number[]} a new array of snapped channels.
 */
function roundRGB(rgb, steps) {
  const snap = (channel) => steps * Math.round(channel / steps);
  return rgb.map(snap);
}
// Note this is from mixbox's licensed code
//
// Maps four weights (c0..c3) to an unclamped [r, g, b] triplet; no clamping
// happens here. latentToRgb() passes latent[0..3] as the weights, then adds
// latent[4..6] to the result and clamps each channel to 0..1.
// NOTE(review): c0..c3 are presumably mixbox's pigment concentrations —
// confirm against the mixbox sources.
//
// The body sums all 20 monomials of total degree 3 in (c0, c1, c2, c3), each
// scaled by its own constant RGB coefficient triple. Keep the terms in this
// order: floating-point addition is not associative, so reordering them could
// change the low-order bits of the result.
// prettier-ignore
function evalPolynomial(c0, c1, c2, c3) {
var r = 0.0;
var g = 0.0;
var b = 0.0;
// Reused squares and pairwise products of the inputs.
var c00 = c0 * c0;
var c11 = c1 * c1;
var c22 = c2 * c2;
var c33 = c3 * c3;
var c01 = c0 * c1;
var c02 = c0 * c2;
var c12 = c1 * c2;
// w holds the current monomial; each line adds w times that term's RGB coefficients.
var w = 0.0;
// Pure cubes: c0^3, c1^3, c2^3, c3^3.
w = c0*c00; r += +0.07717053*w; g += +0.02826978*w; b += +0.24832992*w;
w = c1*c11; r += +0.95912302*w; g += +0.80256528*w; b += +0.03561839*w;
w = c2*c22; r += +0.74683774*w; g += +0.04868586*w; b += +0.00000000*w;
w = c3*c33; r += +0.99518138*w; g += +0.99978149*w; b += +0.99704802*w;
// Terms of the form ci^2 * cj and ci * cj^2 (i != j).
w = c00*c1; r += +0.04819146*w; g += +0.83363781*w; b += +0.32515377*w;
w = c01*c1; r += -0.68146950*w; g += +1.46107803*w; b += +1.06980936*w;
w = c00*c2; r += +0.27058419*w; g += -0.15324870*w; b += +1.98735057*w;
w = c02*c2; r += +0.80478189*w; g += +0.67093710*w; b += +0.18424500*w;
w = c00*c3; r += -0.35031003*w; g += +1.37855826*w; b += +3.68865000*w;
w = c0*c33; r += +1.05128046*w; g += +1.97815239*w; b += +2.82989073*w;
w = c11*c2; r += +3.21607125*w; g += +0.81270228*w; b += +1.03384539*w;
w = c1*c22; r += +2.78893374*w; g += +0.41565549*w; b += -0.04487295*w;
w = c11*c3; r += +3.02162577*w; g += +2.55374103*w; b += +0.32766114*w;
w = c1*c33; r += +2.95124691*w; g += +2.81201112*w; b += +1.17578442*w;
w = c22*c3; r += +2.82677043*w; g += +0.79933038*w; b += +1.81715262*w;
w = c2*c33; r += +2.99691099*w; g += +1.22593053*w; b += +1.80653661*w;
// Products of three distinct inputs.
w = c01*c2; r += +1.87394106*w; g += +2.05027182*w; b += -0.29835996*w;
w = c01*c3; r += +2.56609566*w; g += +7.03428198*w; b += +0.62575374*w;
w = c02*c3; r += +4.08329484*w; g += -1.40408358*w; b += +2.14995522*w;
w = c12*c3; r += +6.00078678*w; g += +2.55552042*w; b += +1.90739502*w;
return [r, g, b];
}
/**
 * Decode a latent vector back to 8-bit RGB.
 *
 * Components 0..3 feed evalPolynomial(). Components 4..6 are added to its
 * r, g, b output. Each sum is then clamped to [0, 1], scaled to 0..255 and
 * rounded to an integer.
 * @param {ArrayLike<number>} latent - 7-component latent vector.
 * @returns {number[]} [r, g, b] integers in 0..255.
 */
function latentToRgb(latent) {
  const base = evalPolynomial(latent[0], latent[1], latent[2], latent[3]);
  return base.map(
    (channel, i) => (clamp01(channel + latent[4 + i]) * 255.0 + 0.5) | 0
  );
}
/**
 * Linearly interpolate two latent vectors component-wise:
 * (1 - t) * latent1[i] + t * latent2[i].
 *
 * Works for any vector length (mixbox latents have 7 components). The
 * inputs are not modified.
 * @param {ArrayLike<number>} latent1 - value at t = 0.
 * @param {ArrayLike<number>} latent2 - value at t = 1; same length as latent1.
 * @param {number} t - blend factor, normally in [0, 1].
 * @returns {number[]} a new plain array with latent1's length.
 */
function lerpLatent(latent1, latent2, t) {
  return Array.from(latent1, (a, i) => (1.0 - t) * a + t * latent2[i]);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment