@t-tera
Created March 3, 2022 15:35
TensorFlow.js LSTM sample
import {promises as fs} from 'fs';
import * as tf from '@tensorflow/tfjs-node';
const BEHIND_LEN = 10;
const NUM_EPOCHS = 20;
const MODEL_SAVE_DIR_URL = 'file://./model';
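// Note (descriptive, inferred from the code below): each sample is the previous
// BEHIND_LEN characters, one-hot encoded over the 0x80 ASCII code points, and the
// model learns to predict the next character.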
// Convert a prediction input string into a tensor
function makeTensorForPrediction(str) {
  // Only characters in the range U+0001-U+007F are accepted
  if (str.match(/[^\u0001-\u007F]/)) {
    throw new Error(`Bad input: ${str}`);
  }
  // Prepend BEHIND_LEN 0x00 characters
  let padded = '\x00'.repeat(BEHIND_LEN) + str;
  // Keep only the last BEHIND_LEN characters
  let behind = padded.substring(padded.length - BEHIND_LEN);
  let buf = tf.buffer([1, BEHIND_LEN, 0x80], 'bool');
  // For simplicity, one-hot encode each character by its ASCII code
  for (let i = 0; i < behind.length; i++) {
    buf.set(1, 0, i, behind.charCodeAt(i));
  }
  return buf.toTensor();
}
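// Example (descriptive): makeTensorForPrediction('PostgreS') pads to ten 0x00 bytes
// plus 'PostgreS', keeps the last 10 characters ('\x00\x00' + 'PostgreS'), and returns
// a one-hot tensor of shape [1, BEHIND_LEN, 0x80].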
// Build the training input (X) and label (y) tensors
function makeTensorsForTraining(lines) {
  // Array of the preceding BEHIND_LEN-character windows
  let behindList = [];
  // Array of the correct next characters (one character each)
  let aheadList = [];
  lines.forEach((line) => {
    // Prepend BEHIND_LEN 0x00 characters
    let padded = '\x00'.repeat(BEHIND_LEN) + line;
    for (let j = 0; j < line.length; j++) {
      let behind = padded.substring(j, j + BEHIND_LEN);
      behindList.push(behind);
      aheadList.push(line[j]);
    }
  });
  let buf_X = tf.buffer([behindList.length, BEHIND_LEN, 0x80], 'bool');
  let buf_y = tf.buffer([behindList.length, 0x80], 'bool');
  for (let i = 0; i < behindList.length; i++) {
    buf_y.set(1, i, aheadList[i].charCodeAt(0));
    for (let j = 0; j < behindList[i].length; j++) {
      buf_X.set(1, i, j, behindList[i].charCodeAt(j));
    }
  }
  return {X: buf_X.toTensor(), y: buf_y.toTensor()};
}
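// Example of the sliding windows, using a made-up line "cat" (not from the actual
// dataset) with BEHIND_LEN = 10:
//   "\x00" x 10        -> "c"
//   "\x00" x 9  + "c"  -> "a"
//   "\x00" x 8  + "ca" -> "t"
// i.e. one (behind, ahead) pair per character, left-padded with 0x00.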
async function doTraining() {
  // Load the training/test data and convert it to tensors
  const {test_lines: testLines, train_lines: trainLines} = JSON.parse(await fs.readFile('./data.json', 'utf8'));
  const {X: X_Train, y: y_Train} = makeTensorsForTraining(trainLines);
  const {X: X_Test, y: y_Test} = makeTensorsForTraining(testLines);
  // Define the model
  const model = tf.sequential();
  const optimizer = tf.train.rmsprop(0.01);
  model.add(tf.layers.lstm({units: 128, inputShape: [BEHIND_LEN, 0x80]}));
  model.add(tf.layers.dense({units: 0x80}));
  model.add(tf.layers.activation({activation: 'softmax'}));
  model.compile({loss: 'categoricalCrossentropy', optimizer, metrics: ['accuracy']});
  // Train
  const args = {batchSize: 128, epochs: NUM_EPOCHS, validationData: [X_Test, y_Test]};
  await model.fit(X_Train, y_Train, args);
  // Print loss and accuracy on the test data
  const score = model.evaluate(X_Test, y_Test);
  console.log(`Test score: ${(await score[0].data())[0]}`);
  console.log(`Test accuracy: ${(await score[1].data())[0]}`);
  // Save the model to disk
  await model.save(MODEL_SAVE_DIR_URL);
}
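// Note (descriptive): saving to a file:// URL writes model.json (topology) plus a
// binary weights file into ./model; doPrediction() below reloads the model from
// that model.json.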
// Predict the character that follows str
async function doPrediction(str) {
  // Load the model from disk
  const model = await tf.loadLayersModel(`${MODEL_SAVE_DIR_URL}/model.json`);
  // Build the input tensor for the model
  const tensor = makeTensorForPrediction(str);
  // Run the prediction
  let preds = await model.predict(tensor).data();
  // Sort by probability, highest first
  preds = [...preds.entries()];
  preds.sort((a, b) => {return b[1] - a[1]});
  // Print the top five candidates in order of probability
  console.log(str + '?');
  preds.slice(0, 5).forEach(([cd, probability], rank) => {
    let chr = String.fromCharCode(cd);
    chr = cd < 0x20 || cd === 0x7F ? encodeURIComponent(chr) : chr;
    console.log(`#${rank + 1}\t${chr}\t${(probability * 100).toFixed(1)}%`);
  });
}
// Train, then predict
doTraining().then(() => {doPrediction('PostgreS')});
// Predict only (when a trained model file already exists)
//doPrediction('Ora');
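// The data file itself is not included in this gist. A minimal sketch of ./data.json,
// inferred from the destructuring in doTraining(); the string values are made up
// purely for illustration:
// {
//   "train_lines": ["PostgreSQL", "MariaDB", ...],
//   "test_lines": ["Oracle", ...]
// }
// Every line must contain only U+0001-U+007F characters (see makeTensorForPrediction).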