Last active
June 27, 2024 05:18
-
-
Save thomasdavis/fadbca01605fb3cb6491107794053a8e to your computer and use it in GitHub Desktop.
output of gpt-4o (some parts added by me, was too lazy to remove)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
export default function Talk() { | |
// Component state.
// modelReady: flag initialised to false (no writer is visible in this part of the file).
// fullModel: the model used by predict(); null until train() stores one.
// training: toggled by train() around the fitting call.
// chat: text state initialised to '' (its consumer is not visible here).
const [modelReady, setModelReady] = useState(false);
const [fullModel, setFullModel] = useState(null);
const [training, setTraining] = useState(false);
const [chat, setChat] = useState('');
// Cut the book into consecutive chunks of at most `chunkSize` characters.
// Chunks are cut by character count, so a word may be split across two chunks.
const chunkSize = 2000;
const bookParts = [];
for (let start = 0; start < book.length; start += chunkSize) {
  bookParts.push(book.slice(start, start + chunkSize));
}
// Number of passes over the training data made by train().
const epochs = 2;
// Training data: the first 10 book chunks, each treated as one "sentence".
const sentences = bookParts.slice(0, 10);
console.log({ sentences });
// Preprocessing function to convert sentences to sequences of numbers | |
/**
 * Build a vocabulary from the sentences and encode every sentence as a
 * sequence of word ids.
 *
 * Sentences are tokenised on runs of whitespace (spaces, tabs, newlines) and
 * empty tokens are dropped, so repeated or trailing whitespace never produces
 * an empty-string "word". Ids start at 1 in first-seen order; 0 is left free
 * for padding.
 *
 * @param {string[]} sentences - Raw text, one entry per training sequence.
 * @returns {{sequences: number[][], wordIndex: Object<string, number>, vocabulary: string[]}}
 *   `sequences[i]` holds the ids of `sentences[i]`; `wordIndex` maps a word to
 *   its id; `vocabulary[id - 1]` is the word for `id`.
 */
const preprocess = (sentences) => {
  const tokenize = (sentence) =>
    sentence.split(/\s+/).filter((token) => token !== '');
  const tokenized = sentences.map(tokenize);

  // A Set keeps insertion order, so ids follow first appearance in the text.
  const seen = new Set();
  for (const tokens of tokenized) {
    for (const token of tokens) {
      seen.add(token);
    }
  }
  const vocabulary = Array.from(seen);

  // Null-prototype dictionary: words such as "__proto__" or "constructor"
  // become ordinary keys, and lookups of unknown words never hit inherited
  // Object.prototype members.
  const wordIndex = Object.create(null);
  vocabulary.forEach((word, position) => {
    wordIndex[word] = position + 1; // ids start at 1
  });

  const sequences = tokenized.map((tokens) =>
    tokens.map((token) => wordIndex[token])
  );

  return { sequences, wordIndex, vocabulary };
};
// Encode the training sentences as sequences of word ids.
const { sequences, wordIndex, vocabulary } = preprocess(sentences);
// Model parameters.
const vocabSize = vocabulary.length + 1; // +1 because id 0 is reserved for padding
const embeddingDim = 8; // dimension of the embedding space
// Length of the longest sequence. Folding from 0 (instead of spreading into
// Math.max) yields 0 rather than -Infinity for an empty corpus and avoids the
// argument-count limit on very large inputs.
const maxSeqLength = sequences.reduce(
  (longest, seq) => Math.max(longest, seq.length),
  0
);
// Pad sequences to the same length | |
/**
 * Bring every sequence to exactly `maxLen` entries.
 *
 * Shorter sequences are right-padded with 0. Longer sequences are truncated
 * to their LAST `maxLen` entries, so every returned row has the same length
 * and can be fed to a fixed-shape tensor. The input arrays are not modified.
 *
 * @param {number[][]} sequences - Sequences of word ids.
 * @param {number} maxLen - Target length of every returned row.
 * @returns {number[][]} New arrays, each of length `maxLen`.
 */
const padSequences = (sequences, maxLen) => {
  return sequences.map((seq) => {
    // Keep the tail when the sequence is too long; explicit start index so
    // that maxLen === 0 yields an empty slice.
    const kept = seq.slice(Math.max(0, seq.length - maxLen));
    const padding = Array(maxLen - kept.length).fill(0);
    return kept.concat(padding);
  });
};
// Pad every training sequence to maxSeqLength.
const paddedSequences = padSequences(sequences, maxSeqLength);
// Model input: one row of word ids per training sequence.
const inputs = tf.tensor2d(paddedSequences, [
  paddedSequences.length,
  maxSeqLength,
]);
// Targets: each sequence shifted one step to the left with a trailing 0, so
// the target at position t is the word id found at position t + 1 of the input.
const shiftedSequences = paddedSequences.map((seq) => [...seq.slice(1), 0]);
// One-hot encode the labels | |
/**
 * One-hot encode sequences of word ids.
 *
 * Each id becomes an array of `vocabSize` zeros with a 1 at position `id`.
 * Ids that are not greater than 0 (the padding id 0, or a missing id) become
 * an all-zero row.
 *
 * @param {number[][]} sequences - Sequences of word ids.
 * @param {number} vocabSize - Length of every one-hot row.
 * @returns {number[][][]} One array of one-hot rows per input sequence.
 */
const oneHotEncode = (sequences, vocabSize) => {
  const toOneHot = (wordId) => {
    const row = Array(vocabSize).fill(0);
    if (wordId > 0) {
      row[wordId] = 1;
    }
    return row;
  };
  return sequences.map((wordIds) => wordIds.map((wordId) => toOneHot(wordId)));
};
// One-hot encode the shifted sequences to obtain the training targets.
const oneHotLabels = oneHotEncode(shiftedSequences, vocabSize);
// Targets as a 3D tensor: [number of sequences, maxSeqLength, vocabSize].
const labels = tf.tensor3d(oneHotLabels, [
  paddedSequences.length,
  maxSeqLength,
  vocabSize,
]);
// Define the model: embedding -> LSTM -> per-timestep softmax over the vocabulary.
const model = tf.sequential();
model.add(
  tf.layers.embedding({
    inputDim: vocabSize,
    outputDim: embeddingDim,
    inputLength: maxSeqLength,
  })
);
// returnSequences: true makes the LSTM emit an output for every timestep, so
// the dense layer below produces one distribution per input position.
model.add(tf.layers.lstm({ units: 32, returnSequences: true }));
model.add(tf.layers.dense({ units: vocabSize, activation: 'softmax' }));
// Compile the model; the loss matches the one-hot encoded labels.
model.compile({
  optimizer: 'adam',
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy'],
});
// Train the model | |
/**
 * Fit the model on `inputs` / `labels` and publish it through `setFullModel`.
 *
 * `model.fit` is asynchronous, so it is awaited: the model is only stored
 * once fitting has finished, and the `training` flag stays true for the whole
 * run. The flag is cleared in `finally` so a failed run does not leave the UI
 * stuck in the training state.
 *
 * @returns {Promise<void>} Resolves when training has finished; rejects with
 *   the error raised by `model.fit`, in which case no model is stored.
 */
async function train() {
  setTraining(true);
  setFullModel(null);
  try {
    await model.fit(inputs, labels, {
      epochs,
      batchSize: 1,
      callbacks: {
        onEpochEnd: (epoch, logs) => {
          console.log(`Epoch ${epoch + 1}: loss = ${logs.loss}`);
        },
      },
    });
    setFullModel(model);
  } finally {
    setTraining(false);
  }
}
/**
 * Predict the word that follows `testSentence` using the trained model.
 *
 * Words missing from the vocabulary are ignored. If more than `maxSeqLength`
 * known words are given, only the last `maxSeqLength` are used. The model
 * emits one distribution per timestep and was trained on targets shifted one
 * step to the left, so the next word is read at the position of the LAST
 * input word, not at position 0.
 *
 * @param {string} testSentence - Text whose continuation should be predicted.
 * @returns {string|undefined} The predicted word, or undefined when no
 *   trained model is available, no input word is known, or the model predicts
 *   the padding id 0.
 */
function predict(testSentence) {
  if (!fullModel) {
    console.warn('predict() called before a trained model is available');
    return undefined;
  }

  // Own-property check: inherited keys such as "constructor" must not be
  // mistaken for vocabulary entries.
  const isKnownWord = (word) =>
    Object.prototype.hasOwnProperty.call(wordIndex, word);
  const knownIds = testSentence
    .split(/\s+/)
    .filter(isKnownWord)
    .map((word) => wordIndex[word]);
  // Keep the most recent words when the input is longer than the model's window.
  const testSequence = knownIds.slice(
    Math.max(0, knownIds.length - maxSeqLength)
  );
  if (testSequence.length === 0) {
    console.warn('predict() found no known words in the input');
    return undefined;
  }

  const paddedTestSequence = padSequences([testSequence], maxSeqLength)[0];

  // tf.tidy disposes the input and output tensors created inside the callback.
  const predictedIndex = tf.tidy(() => {
    const testInput = tf.tensor2d([paddedTestSequence], [1, maxSeqLength]);
    console.log({ testSequence, paddedTestSequence, testInput });
    // One predicted id per timestep; take the one at the last real input word.
    const idsPerStep = fullModel.predict(testInput).argMax(-1).dataSync();
    return idsPerStep[testSequence.length - 1];
  });

  // Ids are 1-based positions in `vocabulary`; id 0 (padding) maps to undefined.
  const predictedWord = vocabulary[predictedIndex - 1];
  console.log(`Predicted next word: ${predictedWord}`);
  return predictedWord;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment