A Pen by Philipp Naderer on CodePen.
Last active
October 8, 2019 13:43
-
-
Save botic/e240d14e290fc4c0d98781ba3aa2d369 to your computer and use it in GitHub Desktop.
TensorFlow.js with SharedMobility.ai Datasets
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<h1>TensorFlow.js ❤️ SharedMobility.ai</h1> | |
<button id="train">Train</button> | |
<ul id="logs"> | |
</ul> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Luxon's DateTime class and the on-page log list are looked up once, at load time.
const { DateTime } = luxon;
const $logs = document.getElementById("logs");

/**
 * Appends one plain-text entry to the on-page log list (#logs).
 *
 * @param {string} line - Message to display. It is assigned via textContent,
 *   so it is never interpreted as HTML.
 */
function logLine(line) {
  $logs.appendChild(Object.assign(document.createElement("li"), { textContent: line }));
}
// Loads the demo station dataset and trains a model when "Train" is clicked.
document.getElementById("train").addEventListener("click", async () => {
  const datasetURL = new URL("https://storage.googleapis.com/smai-public-datasets/citybikewien/service_2-station_42_citybikewien-oper-1055-2019-05-07_2019-05-31.csv");
  try {
    await train(datasetURL);
    logLine("Finished training.");
  } catch (err) {
    // Without this, a failed download or training run becomes an unhandled
    // promise rejection and the page log silently stops updating.
    logLine(`Training failed: ${String(err)}`);
    console.error(err);
  }
});
/**
 * Builds and trains a small dense classifier that maps the feature vector of
 * recordToInput() to the four load buckets of recordToOutput().
 *
 * @param {URL} datasetURL - Location of the station CSV file.
 * @returns {Promise<tf.Sequential>} The trained model.
 */
async function train(datasetURL) {
  const INPUT_LENGTH = 48;  // length of the vector returned by recordToInput()
  const OUTPUT_LENGTH = 4;  // number of load buckets returned by recordToOutput()
  const SAMPLE_COUNT = 150; // rows read from the CSV
  const BATCH_SIZE = 64;

  logLine(`Loading dataset: ${datasetURL.href}`);
  // model.fitDataset() has no batchSize/shuffle options and expects the dataset
  // to yield batches, so shuffling and batching are done on the dataset itself.
  // Each element is a pair of plain number arrays; Dataset.batch() stacks them
  // into [batch, INPUT_LENGTH] and [batch, OUTPUT_LENGTH] tensors.
  const trainingSet = tf.data.csv(datasetURL.href)
    .take(SAMPLE_COUNT)
    .map(record => ({
      // Number() maps boolean cells to 0/1 so the batch is not inferred as a bool tensor.
      xs: recordToInput(record).map(Number),
      ys: recordToOutput(record)
    }))
    .shuffle(SAMPLE_COUNT)
    .batch(BATCH_SIZE);
  logLine("Converted the input dataset.");

  const model = tf.sequential();
  // build the layered model
  model.add(tf.layers.dense({units: INPUT_LENGTH * 2, activation: "relu", inputShape: [INPUT_LENGTH]}));
  model.add(tf.layers.dense({units: OUTPUT_LENGTH, activation: "softmax"}));
  // Compile model to prepare for training.
  // NOTE(review): categoricalCrossentropy is the conventional loss for a softmax
  // output with one-hot labels; meanSquaredError is kept to preserve current behaviour.
  model.compile({
    optimizer: tf.train.rmsprop(0.001),
    loss: tf.losses.meanSquaredError,
    metrics: ["accuracy"]
  });
  model.summary();

  logLine("Starting the training ...");
  const history = await model.fitDataset(trainingSet, {
    epochs: 5,
    callbacks: {
      onEpochEnd: (epoch, logs) => {
        logLine(`Finished epoch ${epoch}, acc: ${logs.acc}, loss: ${logs.loss}`);
      }
    }
  });

  const loss = history.history.loss.slice(-1)[0];
  const acc = history.history.acc.slice(-1)[0];
  logLine(`Model for ${datasetURL.pathname} => acc ${acc} | loss ${loss}`);
  return model;
}
/**
 * Converts one parsed CSV row into a single-example training pair.
 *
 * The tensor shapes are derived from the encoders' output, so they cannot
 * drift out of sync with recordToInput()/recordToOutput(). float32 matches
 * the default dtype of dense-layer weights, so no int32 -> float32 cast is needed.
 *
 * @param {Object} record - One row object as yielded by tf.data.csv().
 * @returns {{xs: tf.Tensor2D, ys: tf.Tensor2D}} Tensors of shape
 *   [1, featureCount] and [1, labelCount].
 */
function convertCsvRecord(record) {
  const features = recordToInput(record);
  const labels = recordToOutput(record);
  return {
    xs: tf.tensor2d([features], [1, features.length], "float32"),
    ys: tf.tensor2d([labels], [1, labels.length], "float32")
  };
}
/**
 * Encodes a CSV row as a 48-element feature vector:
 *   [holiday,
 *    weekday one-hot (7), hour one-hot (24), quarter-hour one-hot (4),
 *    rain flags (4), sunshine flags (4), temperature flags (4)].
 *
 * The timestamp is parsed with zone "Europe/Vienna" and converted to UTC
 * before the calendar fields are read.
 * NOTE(review): in UTC the local hour shifts by one across DST changes — confirm this is intended.
 *
 * Weather features are cumulative "value >= threshold" flags, not one-hot;
 * the last sunshine flag is set only when sunshine is exactly 100.
 */
function recordToInput(record) {
  const when = DateTime.fromISO(record.timestamp, { zone: "Europe/Vienna" }).toUTC();

  // One-hot vector of the given size; a NaN index (invalid date) leaves it all zeros.
  const oneHot = (size, hot) => Array.from({ length: size }, (_, i) => (i === hot ? 1 : 0));
  // One flag per threshold, set when value >= threshold.
  const atLeast = (value, thresholds) => thresholds.map(t => (value >= t ? 1 : 0));

  const sunshineFlags = atLeast(record.sunshine, [25, 50, 75]);
  sunshineFlags.push(record.sunshine === 100 ? 1 : 0);

  return [].concat(
    record.holiday,
    oneHot(7, when.weekday - 1),
    oneHot(24, when.hour),
    oneHot(4, Math.floor(when.minute / 15)),
    atLeast(record.rain, [1, 2.5, 5, 9]),
    sunshineFlags,
    atLeast(record.temperature, [15, 20, 25, 30])
  );
}
/**
 * One-hot encodes a station's load, where
 * load = vehicles_available / (vehicles_available + boxes_available):
 *   index 0: load > 0.8          (full of bikes)
 *   index 1: 0.5 < load <= 0.8
 *   index 2: 0.2 < load <= 0.5
 *   index 3: load <= 0.2         (very low load)
 * If both counts are 0 the load is NaN and the label is all zeros.
 */
function recordToOutput(record) {
  const vehicles = record.vehicles_available;
  const capacity = record.vehicles_available + record.boxes_available;
  const load = vehicles / capacity;
  const within = (lowExclusive, highInclusive) => load > lowExclusive && load <= highInclusive;
  return [
    Number(load > 0.8),
    Number(within(0.5, 0.8)),
    Number(within(0.2, 0.5)),
    Number(load <= 0.2)
  ];
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<script src="https://cdnjs.cloudflare.com/ajax/libs/tensorflow/1.2.10/tf.min.js"></script> | |
<script src="https://cdn.jsdelivr.net/npm/luxon@1.19.3/build/global/luxon.min.js"></script> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
ul { | |
font-family: monospace; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment