Skip to content

Instantly share code, notes, and snippets.

@botic
Last active October 8, 2019 13:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save botic/e240d14e290fc4c0d98781ba3aa2d369 to your computer and use it in GitHub Desktop.
Save botic/e240d14e290fc4c0d98781ba3aa2d369 to your computer and use it in GitHub Desktop.
TensorFlow.js with SharedMobility.ai Datasets
<h1>TensorFlow.js ❤️ SharedMobility.ai</h1>
<button id="train">Train</button>
<ul id="logs">
</ul>
const { DateTime } = luxon;
const $logs = document.getElementById("logs");

/**
 * Appends one entry to the on-page log list (#logs).
 * The text is assigned through `textContent`, so it is never parsed as HTML.
 *
 * @param {string} line - message to display
 */
function logLine(line) {
  const entry = Object.assign(document.createElement("li"), { textContent: line });
  $logs.appendChild(entry);
}
// Start a training run when the "Train" button is clicked. The button stays
// disabled while a run is in progress, so repeated clicks cannot launch
// several trainings in parallel. Any failure is reported on the page.
document.getElementById("train").addEventListener("click", async (event) => {
  // Read synchronously: `currentTarget` is reset once event dispatch finishes.
  const $button = event.currentTarget;
  const datasetURL = new URL("https://storage.googleapis.com/smai-public-datasets/citybikewien/service_2-station_42_citybikewien-oper-1055-2019-05-07_2019-05-31.csv");
  $button.disabled = true;
  try {
    await train(datasetURL);
    logLine("Finished training.");
  } catch (err) {
    // Network, CSV-parsing and tensor-shape errors would otherwise surface
    // only as an unhandled promise rejection in the console.
    logLine(`Training failed: ${err}`);
    console.error(err);
  } finally {
    $button.disabled = false;
  }
});
/**
 * Loads a SharedMobility.ai CSV dataset, builds a small dense classifier for
 * the four station-load classes and trains it, logging progress to the page.
 *
 * @param {URL} datasetURL - location of the CSV file to train on
 * @returns {Promise<tf.Sequential>} the trained model
 */
async function train(datasetURL) {
  const INPUT_LENGTH = 48; // width of the feature vector built per CSV row
  const NUM_CLASSES = 4;   // width of the one-hot load label
  const SAMPLE_COUNT = 150;
  logLine(`Loading dataset: ${datasetURL.href}`);
  // fitDataset() does not accept `shuffle`/`batchSize` options, so shuffling
  // happens on the dataset itself. Each element produced by convertCsvRecord
  // is already a [1, n] tensor pair, i.e. one sample per training step.
  const trainingSet = tf.data.csv(datasetURL.href)
    .take(SAMPLE_COUNT)
    .shuffle(SAMPLE_COUNT)
    .map(obj => convertCsvRecord(obj));
  logLine("Converted the input dataset.");
  const model = tf.sequential();
  // build the layered model
  model.add(tf.layers.dense({units: INPUT_LENGTH * 2, activation: "relu", inputShape: [ INPUT_LENGTH ]}));
  model.add(tf.layers.dense({units: NUM_CLASSES, activation: "softmax"}));
  // Compile model to prepare for training.
  model.compile({
    optimizer: tf.train.rmsprop(0.001),
    // The softmax output and one-hot labels call for categorical cross-entropy.
    // Mean squared error gives weak gradients through a softmax.
    loss: "categoricalCrossentropy",
    metrics: ["accuracy"]
  });
  model.summary();
  logLine("Starting the training ...");
  const history = await model.fitDataset(trainingSet, {
    epochs: 5,
    callbacks: {
      onEpochEnd: (epoch, logs) => {
        logLine(`Finished epoch ${epoch}, acc: ${logs.acc}, loss: ${logs.loss}`);
      }
    }
  });
  const loss = history.history.loss.slice(-1)[0];
  const acc = history.history.acc.slice(-1)[0];
  logLine(`Model for ${datasetURL.pathname} => acc ${acc} | loss ${loss}`);
  return model;
}
/**
 * Turns one parsed CSV row into a training example for fitDataset():
 * `xs` is a [1, 48] int32 feature tensor built from recordToInput, and
 * `ys` is a [1, 4] int32 label tensor built from recordToOutput.
 *
 * @param {Object} record - one parsed CSV row
 * @returns {{xs: tf.Tensor2D, ys: tf.Tensor2D}} single-sample example
 */
function convertCsvRecord(record) {
  const features = recordToInput(record);
  const xs = tf.tensor2d([features], [1, 48], "int32");
  const label = recordToOutput(record);
  const ys = tf.tensor2d([label], [1, 4], "int32");
  return { xs, ys };
}
/**
 * Encodes one CSV row as the flat feature vector fed to the model:
 *   holiday value,
 *   weekday one-hot (7 slots, luxon weekday 1..7 -> index 0..6),
 *   hour one-hot (24 slots),
 *   quarter-hour one-hot (4 slots),
 *   rain, sunshine and temperature "at least" buckets (4 slots each).
 *
 * The timestamp is read in Europe/Vienna and converted to UTC before the
 * time features are taken.
 * NOTE(review): `record.holiday` is copied through unchanged; presumably it is
 * already 0/1 in the CSV. Confirm against the dataset.
 *
 * @param {Object} record - one parsed CSV row
 * @returns {Array} feature vector
 */
function recordToInput(record) {
  const when = DateTime.fromISO(record.timestamp, { zone: "Europe/Vienna" }).toUTC();

  // Zero vector of `size` with a single 1 written at `position`.
  const oneHot = (size, position) => {
    const vec = Array.from({ length: size }, () => 0);
    vec[position] = 1;
    return vec;
  };

  // Quarter of the hour: 0-14 -> 0, 15-29 -> 1, 30-44 -> 2, 45-59 -> 3.
  // Minutes outside [0, 60) leave every slot at 0.
  const quarter = Array.from({ length: 4 }, () => 0);
  const slot = Math.floor(when.minute / 15);
  if (slot >= 0 && slot < 4) {
    quarter[slot] = 1;
  }

  // Cumulative encoding: every limit the value reaches sets its slot to 1.
  const atLeast = (value, limits) => limits.map((limit) => (value >= limit ? 1 : 0));

  const rain = atLeast(record.rain, [1, 2.5, 5, 9]);
  const sunshine = [
    ...atLeast(record.sunshine, [25, 50, 75]),
    record.sunshine === 100 ? 1 : 0, // top bucket only for exactly 100
  ];
  const temperature = atLeast(record.temperature, [15, 20, 25, 30]);

  return [
    record.holiday,
    oneHot(7, when.weekday - 1),
    oneHot(24, when.hour),
    quarter,
    rain,
    sunshine,
    temperature,
  ].flat(1);
}
/**
 * Buckets a station's fill level into a four-class one-hot label:
 *   [load > 0.8, 0.5 < load <= 0.8, 0.2 < load <= 0.5, load <= 0.2]
 * where load = vehicles_available / (vehicles_available + boxes_available).
 *
 * A row with 0 vehicles and 0 free boxes has no bikes to rent. It is labelled
 * as the lowest-load class rather than computing 0/0 = NaN, which would give
 * an all-zero label matching no class. Missing or non-numeric counts still
 * produce NaN and therefore an all-zero label.
 *
 * @param {Object} record - one parsed CSV row
 * @returns {number[]} one-hot label of length 4
 */
function recordToOutput(record) {
  // Coerce explicitly so numeric strings are summed, not concatenated.
  const bikes = Number(record.vehicles_available);
  const capacity = bikes + Number(record.boxes_available);
  const load = capacity === 0 ? 0 : bikes / capacity;
  return [
    load > 0.8 ? 1 : 0, // full of bikes
    load > 0.5 && load <= 0.8 ? 1 : 0,
    load > 0.2 && load <= 0.5 ? 1 : 0,
    load <= 0.2 ? 1 : 0 // very low load
  ];
}
<script src="https://cdnjs.cloudflare.com/ajax/libs/tensorflow/1.2.10/tf.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/luxon@1.19.3/build/global/luxon.min.js"></script>
ul {
font-family: monospace;
}

TensorFlow.js with SharedMobility.ai Datasets

The pen loads a custom SharedMobility.ai dataset and trains a TensorFlow.js model on it.

A Pen by Philipp Naderer on CodePen.

License.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment