Train your own auto-complete model on the web with TensorFlow.js
- Chrome web browser
- Download predicts.html
- Upload dataset.txt
- Train
- Test
Example dataset.txt:

```
awesome,test,grove,one,two,apple
```
The complete predicts.html:

```html
<!DOCTYPE html>
<html>
<head>
  <title>Auto-complete</title>
  <!-- Import TensorFlow.js -->
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.0.0/dist/tf.min.js"></script>
  <!-- Import tfjs-vis -->
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis@1.0.2/dist/tfjs-vis.umd.min.js"></script>
  <!-- Import visual frameworks -->
  <link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.5.2/css/bootstrap.min.css" integrity="sha384-JcKb8q3iqJ61gNV9KGb8thSsNjpSL0n8PARn9HuZOnIxN0hoP+VmmDGMN5t9UJ0Z" crossorigin="anonymous">
  <link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.7.0/css/all.css" integrity="sha384-lZN37f5QGtY3VHgisS14W3ExzMWZxybE1SJSEsQp9S+oqd12jhcu+A56Ebc1zFSJ" crossorigin="anonymous">
</head>
<body>
  <div class="container p-0">
    <header>
      <nav class="navbar navbar-dark bg-dark rounded">
        <a class="navbar-brand">
          <i class="fas fa-comment"></i>
          <strong>Auto-complete</strong>
        </a>
      </nav>
    </header>
  </div>
  <div class="container mt-3">
    <div class="row">
      <div class="col">
        <label class="form-label">Input Text</label>
        <input class="form-control" type="text" id="pred_features" pattern="^[a-z]*$">
      </div>
      <div class="col">
        <label class="form-label">Predicted Text</label>
        <input class="form-control" type="text" id="pred_labels" disabled>
      </div>
    </div>
    <div class="card mt-3 bg-light">
      <h5 class="card-header">Training Parameters</h5>
      <div class="card-body">
        <div class="row">
          <div class="col-2">
            <label>Max word length</label>
          </div>
          <div class="col-10">
            <input type="number" id="max_len" onchange="max_len=parseInt(this.value);document.getElementById('pred_features').maxLength=parseInt(this.value)" min="1">
          </div>
        </div>
        <div class="row mt-3">
          <div class="col-2">
            <label>Epochs</label>
          </div>
          <div class="col-10">
            <input type="number" id="epochs" onchange="epochs=parseInt(this.value)" min="1">
          </div>
        </div>
        <div class="row mt-3">
          <div class="col-2">
            <label>Batch Size</label>
          </div>
          <div class="col-10">
            <input type="number" id="batch_size" onchange="batch_size=parseInt(this.value)" min="1">
          </div>
        </div>
        <div class="row mt-3">
          <div class="col-2">
            <button class="btn btn-secondary btn-sm" style="cursor:pointer;" type="button" onclick="document.getElementById('file').click()">
              <i class="fa fa-upload" aria-hidden="true"></i>
              Dataset
            </button>
            <input style="display:none;" type="file" id="file">
          </div>
          <div class="col-10">
            <label id="file_name">Supported format: csv, tsv, txt.</label>
          </div>
        </div>
      </div>
    </div>
    <input class="btn btn-secondary btn-block mt-3" type="button" id="train" value="Train">
    <div class="d-flex justify-content-center mt-3" onclick="showVizer()">
      <span class="spinner-grow" style="display:none" id="status" aria-hidden="true"></span>
    </div>
    <div class="float-right">
      <a href="https://www.youtube.com/user/chirpieful">
        <i class="fab fa-youtube text-danger" aria-hidden="true"></i>
      </a>
      <a href="https://ohyicong.medium.com/">
        <i class="fab fa-medium text-dark" aria-hidden="true"></i>
      </a>
    </div>
  </div>
  <!-- main script -->
  <script>
    const ALPHA_LEN = 27;   // 26 letters plus the zero-padding class; tf.oneHot needs depth 27 so 'z' (26) gets a valid index
    var sample_len = 1;     // minimum number of typed letters before predicting
    var batch_size = 32;
    var epochs = 250;
    var max_len = 10;       // maximum word length; shorter words are zero-padded
    var words = [];         // dataset loaded from the uploaded file
    var model = create_model(max_len, ALPHA_LEN);
    var status = "";

    function setup(){
      // Read the uploaded dataset and detect its delimiter.
      document.getElementById('file')
        .addEventListener('change', function(){
          const fr = new FileReader();
          fr.onload = function(){
            const result = fr.result;
            const filesize = result.length;
            const delimiters = ['\r\n', ',', '\t', ' '];
            document.getElementById('file_name').innerText = "Supported format: csv, tsv, txt.";
            // Heuristic: accept a delimiter if it splits the file into more than one token.
            for (const d of delimiters){
              const tokens = result.split(d);
              if (tokens.length !== filesize && tokens.length > 1){
                words = tokens;
                document.getElementById('file_name').innerText = document.getElementById('file').files[0].name;
              }
            }
          };
          fr.readAsText(this.files[0]);
        });

      // Preprocess the dataset, train, and download the trained model.
      document.getElementById('train')
        .addEventListener('click', async () => {
          if (words.length <= 0){
            alert("No dataset");
            return;
          }
          document.getElementById("status").style.display = "block";
          document.getElementById("train").style.display = "none";
          try{
            const filtered_words = preprocessing_stage_1(words, max_len);
            const int_words = preprocessing_stage_2(filtered_words, max_len);
            const train_features = preprocessing_stage_5(preprocessing_stage_3(int_words, max_len, sample_len), max_len, ALPHA_LEN);
            const train_labels = preprocessing_stage_5(preprocessing_stage_4(int_words, max_len, sample_len), max_len, ALPHA_LEN);
            model = create_model(max_len, ALPHA_LEN);
            await trainModel(model, train_features, train_labels);
            await model.save('downloads://autocorrect_model');
            // Memory management: free the training tensors on the GPU.
            train_features.dispose();
            train_labels.dispose();
          }catch (err){
            alert("Not enough GPU memory. Please reduce your dataset size.");
          }
          document.getElementById("status").style.display = "none";
          document.getElementById("train").style.display = "block";
        });

      // Predict on every keystroke once enough letters have been typed.
      document.getElementById('pred_features')
        .addEventListener('keyup', () => {
          const pattern = new RegExp("^[a-z]{1," + max_len + "}$");
          let pred_features = [document.getElementById('pred_features').value];
          if (pred_features[0].length < sample_len + 1 || !pattern.test(pred_features[0])){
            document.getElementById('pred_labels').value = "";
            return;
          }
          pred_features = preprocessing_stage_2(pred_features, max_len);
          const input = preprocessing_stage_5(pred_features, max_len, ALPHA_LEN);
          const output = model.predict(input);
          let pred_labels = postprocessing_stage_1(output);
          input.dispose();
          output.dispose();
          pred_labels = postprocessing_stage_2(pred_labels, max_len)[0];
          document.getElementById('pred_labels').value = pred_labels.join("");
        });

      // Initialise the form with the default training parameters.
      document.getElementById("max_len").value = max_len;
      document.getElementById("epochs").value = epochs;
      document.getElementById("batch_size").value = batch_size;
      document.getElementById("pred_features").maxLength = max_len;
    }

    function showVizer(){
      const visorInstance = tfvis.visor();
      if (!visorInstance.isOpen()){
        visorInstance.toggle();
      }
    }

    function preprocessing_stage_1(words, max_len){
      // Filter the word list: keep only lowercase words of 1..max_len letters.
      // words: string[], max_len: int
      status = "Preprocessing Data 1";
      console.log(status);
      const filtered_words = [];
      const pattern = new RegExp("^[a-z]{1," + max_len + "}$");
      for (const word of words){
        if (pattern.test(word)) filtered_words.push(word);
      }
      return filtered_words;
    }

    function preprocessing_stage_2(words, max_len){
      // Convert each word to an integer sequence, zero-padded to max_len.
      // words: string[], max_len: int
      status = "Preprocessing Data 2";
      console.log(status);
      const int_words = [];
      for (const word of words){
        int_words.push(word_to_int(word, max_len));
      }
      return int_words;
    }

    function preprocessing_stage_3(words, max_len, sample_len){
      // Sliding window: for each word, emit every prefix longer than
      // sample_len, zero-padded to max_len. These are the model inputs.
      // words: int[][]; max_len, sample_len: int
      status = "Preprocessing Data 3";
      console.log(status);
      const input_data = [];
      for (const word of words){
        for (let y = sample_len + 1; y < max_len + 1; y++){
          input_data.push(word.slice(0, y).concat(Array(max_len - y).fill(0)));
        }
      }
      return input_data;
    }

    function preprocessing_stage_4(words, max_len, sample_len){
      // Labels: the full padded word, repeated once per prefix from stage 3,
      // so that features and labels line up one-to-one.
      // words: int[][]; max_len, sample_len: int
      status = "Preprocessing Data 4";
      console.log(status);
      const output_data = [];
      for (const word of words){
        for (let y = sample_len + 1; y < max_len + 1; y++){
          output_data.push(word);
        }
      }
      return output_data;
    }

    function preprocessing_stage_5(words, max_len, alpha_len){
      // Convert the integer sequences to one-hot encoding.
      // words: int[][]; max_len, alpha_len: int
      status = "Preprocessing Data 5";
      console.log(status);
      return tf.oneHot(tf.tensor2d(words, [words.length, max_len], 'int32'), alpha_len);
    }

    function postprocessing_stage_1(words){
      // Decode one-hot predictions back to integer sequences.
      return words.argMax(-1).arraySync();
    }

    function postprocessing_stage_2(words, max_len){
      // Convert integer sequences back to words.
      const results = [];
      for (const word of words){
        results.push(int_to_word(word, max_len));
      }
      return results;
    }

    function word_to_int(word, max_len){
      // Encode letters as integers: 'a' -> 1 ... 'z' -> 26, zero-padded to max_len.
      const encode = [];
      for (let i = 0; i < max_len; i++){
        encode.push(i < word.length ? word.charCodeAt(i) - 96 : 0);
      }
      return encode;
    }

    function int_to_word(word, max_len){
      // Decode integers back to letters: 1 -> 'a' ... 26 -> 'z', 0 -> ''.
      const decode = [];
      for (let i = 0; i < max_len; i++){
        decode.push(word[i] == 0 ? "" : String.fromCharCode(word[i] + 96));
      }
      return decode;
    }

    function create_model(max_len, alpha_len){
      // Sequence model: an LSTM that emits a prediction at every timestep,
      // followed by a time-distributed softmax over the alphabet.
      const model = tf.sequential();
      model.add(tf.layers.lstm({
        units: alpha_len * 2,
        inputShape: [max_len, alpha_len],
        dropout: 0.2,
        recurrentDropout: 0.2,
        useBias: true,
        returnSequences: true,
        activation: "relu"
      }));
      model.add(tf.layers.timeDistributed({
        layer: tf.layers.dense({
          units: alpha_len,
          activation: "softmax"
        })
      }));
      model.summary();
      return model;
    }

    async function trainModel(model, train_features, train_labels){
      status = "Training Model";
      console.log(status);
      // Prepare the model for training.
      model.compile({
        optimizer: tf.train.adam(),
        loss: 'categoricalCrossentropy',
        metrics: ['mse']
      });
      // Note: tf.js expects camelCase batchSize in the fit config.
      await model.fit(train_features, train_labels, {
        epochs: epochs,
        batchSize: batch_size,
        shuffle: true,
        callbacks: tfvis.show.fitCallbacks(
          { name: 'Training' },
          ['loss', 'mse'],
          { height: 200, callbacks: ['onEpochEnd'] }
        )
      });
    }

    setup();
  </script>
</body>
</html>
```
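For a concrete picture of what the preprocessing feeds the network, here is a minimal standalone sketch (plain JavaScript, using the same a=1…z=26/zero-padding convention as `word_to_int` above, with the default `sample_len = 1` and a small `max_len = 5` for readability):

```javascript
// Worked example of the sliding-window preprocessing for a single word.
// Conventions match the page above: a=1..z=26, 0 = padding.
const max_len = 5, sample_len = 1;

// Stage 2: "one" -> [15, 14, 5, 0, 0]
const one = [15, 14, 5, 0, 0];

// Stage 3: every prefix longer than sample_len, zero-padded to max_len.
const features = [];
for (let y = sample_len + 1; y < max_len + 1; y++){
  features.push(one.slice(0, y).concat(Array(max_len - y).fill(0)));
}
console.log(features);
// [ [15, 14, 0, 0, 0],   // "on"
//   [15, 14, 5, 0, 0],   // "one"
//   [15, 14, 5, 0, 0],   // prefixes past the end of the word repeat it
//   [15, 14, 5, 0, 0] ]

// Stage 4 pairs every row above with the full word [15, 14, 5, 0, 0] as its
// label, and stage 5 one-hot encodes both sides over ALPHA_LEN classes,
// giving tensors of shape [num_samples, max_len, ALPHA_LEN].
```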
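Training ends with `model.save('downloads://autocorrect_model')`, which downloads the model topology (`autocorrect_model.json`) and its weights file to your machine. To reuse the trained model on another page you can load it back with `tf.loadLayersModel`; the URL below is an assumption about where you host the downloaded files:

```javascript
// Load the model saved by the Train button. The path is hypothetical --
// point it at wherever autocorrect_model.json and autocorrect_model.weights.bin
// are hosted (they must sit in the same directory).
async function loadTrainedModel(){
  const model = await tf.loadLayersModel('models/autocorrect_model.json');
  model.summary();
  return model;
}
```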