LSTM Auto-Complete Model

Train your own auto-complete model in the browser with TensorFlow.js.

Browser support

  1. Chrome

Usage

  1. Download predicts.html and open it in Chrome
  2. Upload dataset.txt
  3. Train (the trained model is downloaded automatically; see the reload sketch below)
  4. Test by typing into the Input Text field
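
The Train step saves the trained model through the browser's download manager via model.save('downloads://autocorrect_model'). As a minimal sketch of reloading it later instead of retraining (the two file inputs jsonUpload and weightsUpload are hypothetical, and the filenames are the defaults written by the downloads:// handler):

// Minimal reload sketch (run inside an async function).
// jsonUpload / weightsUpload are assumed <input type="file"> elements
// holding autocorrect_model.json and autocorrect_model.weights.bin.
const model = await tf.loadLayersModel(
  tf.io.browserFiles([jsonUpload.files[0], weightsUpload.files[0]])
);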
dataset.txt

awesome,test,grove,one,two,apple

predicts.html
<!DOCTYPE html>
<html>
<head>
<title>Auto-complete</title>
<!-- Import TensorFlow.js -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.0.0/dist/tf.min.js"></script>
<!-- Import tfjs-vis -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis@1.0.2/dist/tfjs-vis.umd.min.js"></script>
<!-- Import visual frameworks -->
<link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.5.2/css/bootstrap.min.css" integrity="sha384-JcKb8q3iqJ61gNV9KGb8thSsNjpSL0n8PARn9HuZOnIxN0hoP+VmmDGMN5t9UJ0Z" crossorigin="anonymous">
<link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.7.0/css/all.css" integrity="sha384-lZN37f5QGtY3VHgisS14W3ExzMWZxybE1SJSEsQp9S+oqd12jhcu+A56Ebc1zFSJ" crossorigin="anonymous">
</head>
<body class="">
<div class="container p-0">
<div class="container p-0" >
<header >
<nav class="navbar navbar-dark bg-dark rounded">
<a class="navbar-brand">
<i class="fas fa-comment"></i>
<strong>Auto-complete</strong>
</a>
</nav>
</header>
</div>
<div class="container mt-3">
<div class="row">
<div class="col">
<label class="form-label">Input Text</label>
<input class="form-control" type="input" id="pred_features" pattern="^[a-z]*$">
</div>
<div class="col">
<label class="form-label">Predicted Text</label>
<input class="form-control" type="input" id="pred_labels" disabled>
</div>
</div>
<div>
<div class="card mt-3 bg-light">
<h5 class="card-header">Training Parameters</h5>
<div class="card-body">
<div class="row">
<div class="col-2">
<label>Max word length</label>
</div>
<div class="col-10">
<input type="number" id="max_len" onchange="max_len=parseInt(this.value);document.getElementById('pred_features').maxLength=parseInt(this.value)" min=1>
</div>
</div>
<div class="row mt-3">
<div class="col-2">
<label>Epochs</label>
</div>
<div class="col-10">
<input type="number" id="epochs" onchange="epochs=parseInt(this.value)" min=1>
</div>
</div>
<div class="row mt-3">
<div class="col-2">
<label>Batch Size</label>
</div>
<div class="col-10">
<input type="number" id="batch_size" onchange="batch_size=parseInt(this.value)" min=1>
</div>
</div>
<div class="row mt-3">
<div class="col-2">
<button class="btn btn-secondary btn-sm" style="cursor:pointer;" type="button" onclick="document.getElementById('file').click()">
<i class="fa fa-upload" aria-hidden="true"></i>
Dataset
</button>
<input style="display:none;" type="file" id="file">
</div>
<div class="col-10">
<label id="file_name">Supported format: csv, tsv, txt.</label>
</div>
</div>
</div>
</div>
<input class="btn btn-secondary btn-block mt-3" type="button" id="train" value="Train">
<div class="d-flex justify-content-center mt-3" onclick="showVizer()">
<span class="spinner-grow spinner-grow" style="display:none" id="status" aria-hidden="true"></span>
</div>
<div class="float-right">
<a href="https://www.youtube.com/user/chirpieful">
<i class="fab fa-youtube text-danger" aria-hidden="true"></i>
</a>
<a href="https://ohyicong.medium.com/">
<i class="fab fa-medium text-dark" aria-hidden="true"></i>
</a>
</div>
</div>
</body>
<!-- main script file -->
<script>
const ALPHA_LEN = 27; // 26 letters plus 1 padding slot (index 0); depth 26 would push 'z' (26) out of the one-hot range
var sample_len = 1;
var batch_size = 32;
var epochs = 250;
var max_len = 10;
var words = [];
var model = create_model(max_len,ALPHA_LEN); // untrained model, so predictions before training don't crash
var status = "";
function setup(){
document.getElementById('file')
.addEventListener('change', function() {
var fr=new FileReader();
fr.onload=function(){
let result = fr.result;
let filesize = result.length;
let delimiters = ['\r\n',',','\t',' '];
document.getElementById('file_name').innerText = "Supported format: csv, tsv, txt.";
// try each delimiter in turn; the last one that splits the file into more
// than one token (without splitting it into single characters) wins
for (let i in delimiters){
let length = result.split(delimiters[i]).length;
if(length!=filesize && length>1){
words=result.split(delimiters[i]);
document.getElementById('file_name').innerText = document.getElementById('file').files[0].name;
}
}
}
fr.readAsText(this.files[0]);
})
document.getElementById('train')
.addEventListener('click',async ()=>{
if(words.length<=0){
alert("No dataset");
return
}
document.getElementById("status").style.display = "block";
document.getElementById("train").style.display = "none";
try{
let filtered_words = preprocessing_stage_1(words,max_len);
let int_words = preprocessing_stage_2(filtered_words,max_len);
let train_features = preprocessing_stage_3(int_words,max_len,sample_len);
let train_labels = preprocessing_stage_4(int_words,max_len,sample_len);
train_features = preprocessing_stage_5(train_features,max_len,ALPHA_LEN);
train_labels = preprocessing_stage_5(train_labels,max_len,ALPHA_LEN);
model = create_model(max_len,ALPHA_LEN);
await trainModel(model, train_features, train_labels);
await model.save('downloads://autocorrect_model');
//memory management
train_features.dispose();
train_labels.dispose();
}catch (err){
alert("No enough GPU space. Please reduce your dataset size.");
}
document.getElementById("status").style.display = "none";
document.getElementById("train").style.display = "block";
})
document.getElementById('pred_features')
.addEventListener('keyup',()=>{
console.log( document.getElementById('pred_features').value);
let pattern = new RegExp("^[a-z]{1,"+max_len+"}$");
let pred_features = []
pred_features.push(document.getElementById('pred_features').value);
if(pred_features[0].length<sample_len+1 || !pattern.test(pred_features[0])){
document.getElementById('pred_labels').value="";
return;
}
pred_features = preprocessing_stage_2(pred_features,max_len);
pred_features = preprocessing_stage_5(pred_features,max_len,ALPHA_LEN);
let pred_labels = model.predict(pred_features);
pred_labels = postprocessing_stage_1(pred_labels)
pred_labels = postprocessing_stage_2(pred_labels,max_len)[0]
document.getElementById('pred_labels').value=pred_labels.join("");
})
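// Prediction pipeline sketch: "ca" -> preprocessing stage 2 -> [3,1,0,...,0]
// -> stage 5 one-hot -> model.predict -> postprocessing stage 1 (argMax per
// position) -> postprocessing stage 2 (ints back to letters) -> joined into
// the Predicted Text box.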
document.getElementById("max_len").value=max_len
document.getElementById("epochs").value=epochs
document.getElementById("batch_size").value=batch_size
document.getElementById("pred_features").maxLength = document.getElementById("max_len").value;
}
function showVizer(){
const visorInstance = tfvis.visor();
if (!visorInstance.isOpen()) {
visorInstance.toggle();
}
}
function preprocessing_stage_1(words,max_len){
// function to filter the wordlist
// string [] = words
// int = max_len
status = "Preprocessing Data 1";
console.log(status);
let filtered_words = [];
var pattern = new RegExp("^[a-z]{1,"+max_len+"}$");
for (let i in words){
var is_valid = pattern.test(words[i]);
if (is_valid) filtered_words.push(words[i]);
}
return filtered_words;
}
function preprocessing_stage_2(words,max_len){
// function to convert the wordlist to int
// string [] = words
// int = max_len
status = "Preprocessing Data 2";
console.log(status);
let int_words = [];
for (let i in words){
int_words.push(word_to_int(words[i],max_len))
}
return int_words;
}
function preprocessing_stage_3(words,max_len,sample_len){
// function to perform sliding window on wordlist
// int [] = words
// int = max_len, sample_len
status = "Preprocessing Data 3";
console.log(status);
let input_data = [];
for (let x in words){
// emit every zero-padded prefix of the word, from length sample_len+1 up to max_len
for (let y=sample_len+1;y<max_len+1;y++){
input_data.push(words[x].slice(0,y).concat(Array(max_len-y).fill(0)));
}
}
return input_data;
}
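// Worked example: with max_len=5 and sample_len=1, "cat" encoded as
// [3,1,20,0,0] yields the padded prefixes
// [3,1,0,0,0], [3,1,20,0,0], [3,1,20,0,0], [3,1,20,0,0]
// i.e. every prefix of length 2..max_len, zero-padded back to max_len.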
function preprocessing_stage_4(words,max_len,sample_len){
// function to ensure that training data size y == x
// int [] = words
// int = max_len, sample_len
status = "Preprocessing Data 4";
console.log(status);
let output_data = [];
for (let x in words){
for (let y=sample_len+1;y<max_len+1;y++){
output_data.push(words[x]);
}
}
return output_data;
}
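// Stage 4 pairs each prefix from stage 3 with the full word as its label:
// for "cat" above, the label [3,1,20,0,0] is pushed four times, keeping
// train_features and train_labels the same length.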
function preprocessing_stage_5(words,max_len,alpha_len){
// function to convert int to onehot encoding
// int [] = words
// int = max_len, alpha_len
status = "Preprocessing Data 5";
console.log(status);
return tf.oneHot(tf.tensor2d(words,[words.length,max_len],'int32'), alpha_len);
}
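// Output shape: [num_samples, max_len, alpha_len]. Slot 0 is the padding
// symbol; slots 1..26 are the letters a..z.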
function postprocessing_stage_1(words){
//function to decode onehot encoding
return words.argMax(-1).arraySync();
}
function postprocessing_stage_2(words,max_len){
//function to convert int to words
let results = [];
for (let i in words){
results.push(int_to_word(words[i],max_len));
}
return results;
}
function word_to_int (word,max_len){
// char [] = word
// int = max_len
let encode = [];
for (let i = 0; i < max_len; i++) {
if(i<word.length){
let letter = word.slice(i, i+1);
encode.push(letter.charCodeAt(0)-96);
}else{
encode.push(0)
}
}
return encode;
}
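// Example: word_to_int("cat", 5) -> [3,1,20,0,0]
// ('a' -> 1 ... 'z' -> 26, with 0 padding the word out to max_len)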
function int_to_word (word,max_len){
// int [] = word
// int = max_len
let decode = []
for (let i = 0; i < max_len; i++) {
if(word[i]==0){
decode.push("");
}else{
decode.push(String.fromCharCode(word[i]+96))
}
}
return decode;
}
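// Example: int_to_word([3,1,20,0,0], 5) -> ["c","a","t","",""],
// which join("") turns back into "cat".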
function create_model(max_len,alpha_len){
var model = tf.sequential();
// LSTM over the one-hot character sequence; returnSequences:true emits one
// output vector per character position (model.add is synchronous, no await needed)
model.add(tf.layers.lstm({
units:alpha_len*2,
inputShape:[max_len,alpha_len],
dropout:0.2,
recurrentDropout:0.2,
useBias: true,
returnSequences:true,
activation:"relu"
}))
// per-position softmax over the alphabet
model.add(tf.layers.timeDistributed({
layer: tf.layers.dense({
units: alpha_len,
activation:"softmax"
})
}));
model.summary();
return model;
}
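// Shape sketch: input [batch, max_len, alpha_len] -> LSTM with
// returnSequences [batch, max_len, alpha_len*2] -> time-distributed
// softmax [batch, max_len, alpha_len], i.e. a letter distribution
// for every character position.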
async function trainModel(model, train_features, train_labels) {
status = "Training Model";
console.log(status)
// Prepare the model for training.
model.compile({
optimizer: tf.train.adam(),
loss: 'categoricalCrossentropy',
metrics: ['mse']
})
await model.fit(train_features, train_labels, {
epochs,
batchSize: batch_size,
shuffle: true,
callbacks: tfvis.show.fitCallbacks(
{ name: 'Training' },
['loss', 'mse'],
{ height: 200, callbacks: ['onEpochEnd'] }
)
});
return;
}
setup();
</script>
</html>
@momen-hafez

Hi, this is great work. I want to ask if it's possible to extend it to predict sentences, not only single words, and also to rank the predictions according to the dataset.
Regards.
