Created
March 3, 2022 15:35
-
-
Save t-tera/3a67444f442c6561da7609a2b352e94d to your computer and use it in GitHub Desktop.
TensorFlow.js LSTM sample
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import {promises as fs} from 'fs';
import process from 'process';

import * as tf from '@tensorflow/tfjs-node';
// Number of preceding characters fed to the model as context for one prediction.
const BEHIND_LEN = 10;
// Number of epochs used when fitting the model.
const NUM_EPOCHS = 20;
// URL the trained model is saved to and later loaded from.
const MODEL_SAVE_DIR_URL = 'file://./model';
/**
 * Converts a prediction input string into a one-hot tensor.
 *
 * The string is left-padded with BEHIND_LEN U+0000 characters and only the
 * last BEHIND_LEN characters are kept. Each kept character sets one entry on
 * the last axis, indexed by its character code.
 *
 * @param {string} str - Input text; every character must be in U+0001-U+007F.
 * @returns {tf.Tensor} Boolean tensor built from a [1, BEHIND_LEN, 0x80] buffer.
 * @throws {Error} If `str` contains a character outside U+0001-U+007F.
 */
function makeTensorForPrediction(str) {
  // Only U+0001-U+007F is accepted; U+0000 is reserved for the padding below.
  if (str.match(/[^\u0001-\u007F]/)) {
    throw new Error(`Bad input: ${str}`);
  }
  // Prepend BEHIND_LEN NUL characters so short inputs still fill the window.
  const padded = '\x00'.repeat(BEHIND_LEN) + str;
  // Keep only the trailing BEHIND_LEN characters.
  const behind = padded.substring(padded.length - BEHIND_LEN);
  const buf = tf.buffer([1, BEHIND_LEN, 0x80], 'bool');
  // For simplicity, the ASCII code itself is used as the one-hot index.
  for (let i = 0; i < behind.length; i++) {
    buf.set(1, 0, i, behind.charCodeAt(i));
  }
  return buf.toTensor();
}
/**
 * Builds the training input (X) and target (y) tensors from text lines.
 *
 * For every character of every line, one sample is produced: the BEHIND_LEN
 * characters preceding it (left-padded with U+0000 at the start of a line)
 * form the input window, and the character itself is the target. Both are
 * one-hot encoded by character code on an axis of size 0x80.
 *
 * @param {string[]} lines - Training text; every character must be in U+0000-U+007F.
 * @returns {{X: tf.Tensor, y: tf.Tensor}} X is built from a
 *     [samples, BEHIND_LEN, 0x80] buffer and y from a [samples, 0x80] buffer.
 * @throws {Error} If any line contains a character above U+007F.
 */
function makeTensorsForTraining(lines) {
  // Validate before allocating anything: a character code of 0x80 or above
  // would index past the 0x80-wide one-hot axis used below.
  lines.forEach((line) => {
    if (/[^\u0000-\u007F]/.test(line)) {
      throw new Error(`Bad input: ${line}`);
    }
  });

  // Windows of the BEHIND_LEN characters preceding each target character.
  const behindList = [];
  // Target characters (one per window).
  const aheadList = [];
  const padding = '\x00'.repeat(BEHIND_LEN);
  lines.forEach((line) => {
    // Prepend BEHIND_LEN NUL characters so the first characters have a full window.
    const padded = padding + line;
    for (let j = 0; j < line.length; j++) {
      behindList.push(padded.substring(j, j + BEHIND_LEN));
      aheadList.push(line[j]);
    }
  });

  const buf_X = tf.buffer([behindList.length, BEHIND_LEN, 0x80], 'bool');
  const buf_y = tf.buffer([behindList.length, 0x80], 'bool');
  behindList.forEach((behind, i) => {
    buf_y.set(1, i, aheadList[i].charCodeAt(0));
    for (let j = 0; j < behind.length; j++) {
      buf_X.set(1, i, j, behind.charCodeAt(j));
    }
  });
  return {X: buf_X.toTensor(), y: buf_y.toTensor()};
}
/**
 * Trains the next-character model and saves it.
 *
 * Reads `train_lines` and `test_lines` from ./data.json, fits an LSTM model
 * on the training lines (validating against the test lines), logs the test
 * loss and accuracy, and saves the model to MODEL_SAVE_DIR_URL.
 *
 * @returns {Promise<void>} Resolves once the model has been saved.
 */
async function doTraining() {
  // Load the training/test lines.
  const {test_lines: testLines, train_lines: trainLines} = JSON.parse(await fs.readFile('./data.json', 'utf8'));

  // Tensors created here; all of them are released in the finally block.
  const tensors = [];
  try {
    const {X: X_Train, y: y_Train} = makeTensorsForTraining(trainLines);
    tensors.push(X_Train, y_Train);
    const {X: X_Test, y: y_Test} = makeTensorsForTraining(testLines);
    tensors.push(X_Test, y_Test);

    // Model definition.
    const model = tf.sequential();
    const optimizer = tf.train.rmsprop(0.01);
    model.add(tf.layers.lstm({units: 128, inputShape: [BEHIND_LEN, 0x80]}));
    model.add(tf.layers.dense({units: 0x80}));
    model.add(tf.layers.activation({activation: 'softmax'}));
    model.compile({loss: 'categoricalCrossentropy', optimizer, metrics: ['accuracy']});

    // Training.
    const args = {batchSize: 128, epochs: NUM_EPOCHS, validationData: [X_Test, y_Test]};
    await model.fit(X_Train, y_Train, args);

    // Report loss and accuracy on the test data.
    const score = model.evaluate(X_Test, y_Test);
    tensors.push(...score);
    const [loss, accuracy] = await Promise.all([score[0].data(), score[1].data()]);
    console.log(`Test score: ${loss[0]}`);
    console.log(`Test accuracy: ${accuracy[0]}`);

    // Save the model to disk.
    await model.save(MODEL_SAVE_DIR_URL);
  } finally {
    tf.dispose(tensors);
  }
}
/**
 * Predicts the character that follows `str` and prints the top five candidates.
 *
 * Loads the saved model from MODEL_SAVE_DIR_URL, runs it on the encoded
 * input, and logs `str` followed by '?' and then one line per candidate with
 * its rank, character and probability as a percentage.
 *
 * @param {string} str - Text to continue; must satisfy makeTensorForPrediction.
 * @returns {Promise<void>} Resolves once the candidates have been printed.
 */
async function doPrediction(str) {
  // Load the model from disk.
  const model = await tf.loadLayersModel(`${MODEL_SAVE_DIR_URL}/model.json`);

  // Tensors created here; released together with the model in the finally block.
  const tensors = [];
  let ranked;
  try {
    // Build the model input.
    const input = makeTensorForPrediction(str);
    tensors.push(input);
    // Run the prediction.
    const output = model.predict(input);
    tensors.push(output);
    const probabilities = await output.data();
    // Pair each character code with its probability, highest probability first.
    ranked = Array.from(probabilities, (probability, cd) => [cd, probability]);
    ranked.sort((a, b) => b[1] - a[1]);
  } finally {
    tf.dispose(tensors);
    model.dispose();
  }

  // Print ranks 1 through 5.
  console.log(`${str}?`);
  ranked.slice(0, 5).forEach(([cd, probability], rank) => {
    let chr = String.fromCharCode(cd);
    // Control characters are shown percent-encoded so the output stays readable.
    chr = cd < 0x20 || cd === 0x7F ? encodeURIComponent(chr) : chr;
    console.log(`#${rank + 1}\t${chr}\t${(probability * 100).toFixed(1)}%`);
  });
}
// Train, then predict. The prediction promise is returned so that failures in
// either step reach the single rejection handler.
doTraining()
  .then(() => doPrediction('PostgreS'))
  .catch((err) => {
    console.error(err);
    process.exitCode = 1;
  });
// To predict only (when a trained model file already exists), call
// doPrediction('Ora') instead of the chain above.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment