Created
April 25, 2024 14:22
-
-
Save bennyschmidt/ba79ba64faa5ba18334b4ae06c857641 to your computer and use it in GitHub Desktop.
A simple 8-dimensional word embedding (e.g. word2vec) from scratch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Bigram frequency table: bigrams[token][next] is the number of times
// `next` was observed immediately after `token` in the toy corpus.
//
// FIX: the original literal declared the key `the` twice; in a JS object
// literal the later entry silently replaces the earlier one, so
// { quick: 1, sun: 9 } was lost at runtime. The two entries are merged here.
const bigrams = {
  the: {
    quick: 1,
    sun: 9,
    lazy: 1,
    second: 4,
    entire: 3,
    whole: 6,
    reason: 8
  },
  quick: {
    brown: 1,
    and: 9
  },
  fast: {
    as: 4,
    for: 2
  },
  brown: {
    fox: 3,
    recluse: 8
  },
  red: {
    light: 11,
    robin: 5
  },
  fox: {
    jumped: 3,
    hole: 3
  },
  jumped: {
    on: 3,
    over: 6
  },
  on: {
    the: 12,
    their: 5
  },
  over: {
    by: 3,
    the: 6,
    there: 4
  },
  lazy: {
    dog: 6,
    millennial: 3
  },
  dog: {
    park: 5,
    treat: 4,
    walker: 8
  },
  randomly: {
    i: 7,
    he: 3
  },
  light: {
    rain: 2
  },
  robin: {
    hood: 4
  }
};
// utils

// Penn Treebank POS tags ordered from least to most "specific";
// a tag's index in this list serves as its specificity score
// (see getPOSSpecificity). Order is significant — do not re-sort.
const posSpecificityList = [
  "TO", "CD", "UH", "FW", "CC", "EX", "LS", "RP", "SYM",   // function words / misc
  "DT", "WDT", "MD", "IN", "POS", "PRP", "PRP$",           // determiners, pronouns
  "RB", "RBR", "RBS", "WRB", "PDT",                        // adverbs
  "JJ", "JJR", "JJS",                                      // adjectives
  "VB", "VBD", "VBG", "VBN", "VBP", "VBZ",                 // verbs
  "WP", "WP$",                                             // wh-pronouns
  "NN", "NNS", "NNP", "NNPS"                               // nouns (most specific)
];

// Lowercase Latin letters and vowels, as single-character arrays.
const alphabet = [..."abcdefghijklmnopqrstuvwxyz"];
const vowels = [..."aeiou"];
// Number of distinct next-tokens recorded after `token`.
const getTokenPrevalence = token => Object.keys(bigrams[token]).length;

// Highest prevalence across all tokens.
// FIX: the original sorted with `(a, b) => a > b ? 1 : -1` and popped —
// that comparator never returns 0 and is inconsistent on ties (violates
// the sort contract), and the sort is O(n log n) where Math.max is O(n).
const getMaxTokenPrevalence = () => Math.max(
  ...Object.keys(bigrams).map(getTokenPrevalence)
);

// Next token with the highest recorded frequency after `token`
// (first-listed key wins on ties).
const getMostFrequentNextToken = token => Object
  .keys(bigrams[token])
  .reduce((best, next) => (
    bigrams[token][next] > bigrams[token][best] ? next : best
  ));

// Frequency value of that most frequent next token.
const getMostFrequentNextTokenValue = token => (
  bigrams[token][getMostFrequentNextToken(token)]
);

// Highest next-token frequency recorded for `token`.
const getMaxTokenFrequency = token => Math.max(
  ...Object.values(bigrams[token])
);

// Highest next-token frequency anywhere in the bigram table.
const getMaxFrequency = () => Math.max(
  ...Object.keys(bigrams).map(getMaxTokenFrequency)
);
// First vowel character found in `string` (case-insensitive),
// or undefined when the string contains no vowels.
const getFirstVowel = string => (
  [...string].find(character => vowels.includes(character.toLowerCase()))
);

// A tag's position in the specificity list is its specificity score
// (-1 for an unknown tag).
const getPOSSpecificity = tag => posSpecificityList.indexOf(tag);
// Builds a normalized 8-dimensional embedding for every token in `sequence`.
// Dimensions: [frequency, prevalence, specificity, length,
//              firstLetter, lastLetter, firstVowel, lastVowel],
// each divided by a per-dimension maximum and clamped to [0, 1].
// Returns an object mapping token -> embedding array.
const getSequenceEmbeddings = (sequence = "The quick brown fox jumped over the lazy dog.") => {
  console.log('QUERY', sequence);

  const embeddings = {};

  // tokenize: lowercase, strip punctuation/symbols, split on spaces
  const tokens = sequence
    .toLowerCase()
    .trim()
    .replace(/[\p{P}$+<=>^`|~]/gu, '')
    .split(' ');

  // Per-dimension maximums are loop-invariant; compute them once.
  // (The original re-ran the corpus-wide scans for every token.)
  const maximums = [
    getMaxFrequency(),
    getMaxTokenPrevalence(),
    posSpecificityList.length - 1,
    20,   // assumed maximum token length
    25,   // letter dimensions index into alphabet: 0..25
    25,
    25,
    25
  ];

  for (const token of tokens) {
    console.log('TOKEN', token);

    // FIX: skip tokens absent from the bigram table instead of crashing
    // with a TypeError on `Object.keys(undefined)` as the original did.
    if (!Object.hasOwn(bigrams, token)) {
      console.log('SKIPPED (unknown token)', token);
      continue;
    }

    // pos tag (hard-coded placeholder — a real POS tagger would go here)
    const posTag = "RB";
    console.log('POS', posTag);

    // frequency of the most common next token
    const frequency = getMostFrequentNextTokenValue(token);
    console.log('PART #0: FREQUENCY', frequency);

    // how many distinct tokens follow this one
    const prevalence = getTokenPrevalence(token);
    console.log('PART #1: PREVALENCE', prevalence);

    // rank of the POS tag in the specificity list
    const specificity = getPOSSpecificity(posTag);
    console.log('PART #2: SPECIFICITY', specificity);

    // token length in characters
    const { length } = token;
    console.log('PART #3: LENGTH', length);

    // alphabet indexes of first/last letter and first/last vowel
    // (each is -1 when not found, e.g. no vowel; clamped to 0 below)
    const firstLetter = alphabet.indexOf(token.charAt(0));
    console.log('PART #4: FIRST LETTER', firstLetter);

    const lastLetter = alphabet.indexOf(token.charAt(token.length - 1));
    console.log('PART #5: LAST LETTER', lastLetter);

    const firstVowel = alphabet.indexOf(getFirstVowel(token));
    console.log('PART #6: FIRST VOWEL', firstVowel);

    const reversed = token.split('').reverse().join('');
    const lastVowel = alphabet.indexOf(getFirstVowel(reversed));
    console.log('PART #7: LAST VOWEL', lastVowel);

    const prenormalized = [
      frequency,
      prevalence,
      specificity,
      length,
      firstLetter,
      lastLetter,
      firstVowel,
      lastVowel
    ];
    console.log('PRENORMALIZED EMBEDDING:', prenormalized);

    // Normalize each dimension to [0, 1].
    const embedding = prenormalized.map((value, index) => (
      Math.max(0, Math.min(1, value / maximums[index]))
    ));
    console.log(`EMBEDDING "${token}":`, embedding);

    embeddings[token] = embedding;
  }

  return embeddings;
};
// Build embeddings for the default demo sentence and print the raw vectors.
const embeddings = getSequenceEmbeddings();
console.log('EMBEDDINGS:', Object.values(embeddings));
// Sum of a vector's components.
// FIX: pass an initial value to reduce — the original threw a TypeError
// ("Reduce of empty array with no initial value") on an empty vector.
const getSum = vector => (
  vector.reduce((a, b) => a + b, 0)
);

// Dot product of two equal-length numeric vectors (0 when empty).
// Simplified from the original map-then-reduce into a single reduce,
// with an initial value for the same empty-input reason as getSum.
const getDotProduct = (a, b) => (
  a.reduce((sum, value, index) => sum + value * b[index], 0)
);
// Ranks every embedded token by dot-product similarity to `token`,
// ascending: least similar first, most similar last. Note this is a raw
// dot product, not cosine similarity — vectors are not re-normalized here.
const getSimilarityIndexByToken = token => {
  const products = {};
  const tokenEmbedding = embeddings[token];

  for (const tokenComparison of Object.keys(embeddings)) {
    const embedding = embeddings[tokenComparison];
    products[tokenComparison] = getDotProduct(tokenEmbedding, embedding);
  }

  // FIX: numeric comparator. The original `a > b ? 1 : -1` never returned 0
  // and was inconsistent on ties, which violates the sort contract and
  // makes tie order engine-dependent.
  const similarityIndex = Object.keys(products).sort((a, b) => (
    products[a] - products[b]
  ));

  return similarityIndex;
};
// Demo: rank all embedded tokens by similarity to "quick". The index is
// ascending, so its last entry is the most similar token (the query word
// itself is included in the comparison set).
const testWord = 'quick';
const similarityIndex = getSimilarityIndexByToken(testWord);
console.log(`The most similar word to "${testWord}" is "${similarityIndex[similarityIndex.length - 1]}".`, similarityIndex, '(ascending)');
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment