@bennyschmidt · Created April 25, 2024 14:22
A simple 8-dimensional word embedding (in the spirit of word2vec), from scratch
// Toy bigram counts: bigrams[a][b] = how often token b followed token a
const bigrams = {
  the: {
    quick: 1,
    sun: 9,
    lazy: 1,
    second: 4,
    entire: 3,
    whole: 6,
    reason: 8
  },
  quick: {
    brown: 1,
    and: 9
  },
  fast: {
    as: 4,
    for: 2
  },
  brown: {
    fox: 3,
    recluse: 8
  },
  red: {
    light: 11,
    robin: 5
  },
  fox: {
    jumped: 3,
    hole: 3
  },
  jumped: {
    on: 3,
    over: 6
  },
  on: {
    the: 12,
    their: 5
  },
  over: {
    by: 3,
    the: 6,
    there: 4
  },
  lazy: {
    dog: 6,
    millennial: 3
  },
  dog: {
    park: 5,
    treat: 4,
    walker: 8
  },
  randomly: {
    i: 7,
    he: 3
  },
  light: {
    rain: 2
  },
  robin: {
    hood: 4
  }
};
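
// A hedged sketch of how counts like the above could be derived from raw
// text instead of being written by hand (whitespace tokenization and the
// buildBigrams name are illustrative assumptions, not part of the original):
const buildBigrams = text => {
  const counts = {};
  const words = text.toLowerCase().trim().split(/\s+/);

  for (let i = 0; i < words.length - 1; i++) {
    const a = words[i];
    const b = words[i + 1];

    counts[a] = counts[a] || {};
    counts[a][b] = (counts[a][b] || 0) + 1;
  }

  return counts;
};
// e.g. buildBigrams("the quick brown fox")
// => { the: { quick: 1 }, quick: { brown: 1 }, brown: { fox: 1 } }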
// utils
// Penn Treebank POS tags, ordered roughly from function words (least
// "specific") to content words (most "specific")
const posSpecificityList = [
  "TO", "CD", "UH", "FW", "CC", "EX", "LS", "RP", "SYM",
  "DT", "WDT", "MD", "IN", "POS", "PRP", "PRP$",
  "RB", "RBR", "RBS", "WRB", "PDT",
  "JJ", "JJR", "JJS",
  "VB", "VBD", "VBG", "VBN", "VBP", "VBZ",
  "WP", "WP$",
  "NN", "NNS", "NNP", "NNPS"
];
const alphabet = "abcdefghijklmnopqrstuvwxyz".split('');
const vowels = "aeiou".split('');

// how many distinct tokens follow this token
const getTokenPrevalence = token => Object.keys(bigrams[token]).length;

const getMaxTokenPrevalence = () => Object.keys(bigrams)
  .map(key => getTokenPrevalence(key))
  .sort((a, b) => a - b)
  .pop();

const getMostFrequentNextToken = token => Object
  .keys(bigrams[token])
  .sort((a, b) => bigrams[token][a] - bigrams[token][b])
  .pop();

const getMostFrequentNextTokenValue = token => (
  bigrams[token][getMostFrequentNextToken(token)]
);

const getMaxTokenFrequency = token => Object
  .values(bigrams[token])
  .sort((a, b) => a - b)
  .pop();

const getMaxFrequency = () => Object
  .keys(bigrams)
  .map(getMaxTokenFrequency)
  .sort((a, b) => a - b)
  .pop();

const getFirstVowel = string => {
  const isVowel = character => (
    vowels.indexOf(character.toLowerCase()) !== -1
  );

  for (const character of string) {
    if (isVowel(character)) {
      return character;
    }
  }
};

const getPOSSpecificity = tag => posSpecificityList.indexOf(tag);
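
// The embedding below hardcodes an "RB" (adverb) tag for every token. A
// hedged sketch of a stand-in tagger that could supply real tags (the
// suffix heuristics and the naivePosTag name are illustrative assumptions,
// not part of the original):
const naivePosTag = token => {
  if (token.endsWith('ly')) return 'RB';   // adverb
  if (token.endsWith('ed')) return 'VBD';  // past-tense verb
  if (token.endsWith('ing')) return 'VBG'; // gerund
  if (token.endsWith('s')) return 'NNS';   // plural noun
  return 'NN';                             // default: singular noun
};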
const getSequenceEmbeddings = (sequence = "The quick brown fox jumped over the lazy dog.") => {
  console.log('QUERY', sequence);

  const embeddings = {};

  // tokenize: lowercase, strip punctuation, split on spaces
  const tokens = sequence
    .toLowerCase()
    .trim()
    .replace(/[\p{P}$+<=>^`|~]/gu, '')
    .split(' ');

  // per-dimension maximums, used to normalize each part into [0, 1]
  // (these are corpus-wide, so they are computed once, outside the loop)
  const maximums = [
    getMaxFrequency(),             // #0 frequency
    getMaxTokenPrevalence(),       // #1 prevalence
    posSpecificityList.length - 1, // #2 specificity
    20,                            // #3 length (capped at 20)
    25,                            // #4 first letter ('z')
    25,                            // #5 last letter
    25,                            // #6 first vowel
    25                             // #7 last vowel
  ];

  for (const token of tokens) {
    console.log('TOKEN', token);

    // pos tag (hardcoded for this demo; a tagger like the naivePosTag
    // sketch above could supply a real tag)
    const posTag = "RB";
    console.log('POS', posTag);

    // frequency
    const frequency = getMostFrequentNextTokenValue(token);
    console.log('PART #0: FREQUENCY', frequency);

    // prevalence
    const prevalence = getTokenPrevalence(token);
    console.log('PART #1: PREVALENCE', prevalence);

    // specificity
    const specificity = getPOSSpecificity(posTag);
    console.log('PART #2: SPECIFICITY', specificity);

    // length
    const { length } = token;
    console.log('PART #3: LENGTH', length);

    // first letter
    const firstLetter = alphabet.indexOf(token.charAt(0));
    console.log('PART #4: FIRST LETTER', firstLetter);

    // last letter
    const lastLetter = alphabet.indexOf(token.charAt(token.length - 1));
    console.log('PART #5: LAST LETTER', lastLetter);

    // first vowel
    const firstVowel = alphabet.indexOf(getFirstVowel(token));
    console.log('PART #6: FIRST VOWEL', firstVowel);

    // last vowel (the first vowel of the reversed token)
    const lastVowel = alphabet.indexOf(getFirstVowel(token.split('').reverse().join('')));
    console.log('PART #7: LAST VOWEL', lastVowel);

    // Embeddings
    const prenormalized = [
      frequency,
      prevalence,
      specificity,
      length,
      firstLetter,
      lastLetter,
      firstVowel,
      lastVowel
    ];

    console.log('PRENORMALIZED EMBEDDING:', prenormalized);

    // normalize: divide each part by its maximum and clamp to [0, 1]
    const embedding = prenormalized.map((value, index) => (
      Math.max(0, Math.min(1, value / maximums[index]))
    ));

    console.log(`EMBEDDING "${token}":`, embedding);

    embeddings[token] = embedding;
  }

  return embeddings;
};
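
// Worked example: for "fox", prenormalized = [3, 2, 16, 3, 5, 23, 14, 14]
// (most frequent next-token count 3, 2 distinct followers, "RB" at index 16
// in posSpecificityList, length 3, 'f' = 5, 'x' = 23, first/last vowel
// 'o' = 14), and each part is then divided by its maximum and clamped.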
const embeddings = getSequenceEmbeddings();
console.log('EMBEDDINGS:', Object.values(embeddings));
// sum of a vector's components (defined here but not used below)
const getSum = vector => (
  vector.reduce((a, b) => a + b)
);

const getDotProduct = (a, b) => (
  a.map((_, index) => a[index] * b[index]).reduce((m, n) => m + n)
);
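// e.g. getDotProduct([1, 2, 3], [4, 5, 6]) === 1*4 + 2*5 + 3*6 === 32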
const getSimilarityIndexByToken = token => {
  const products = {};
  const tokenEmbedding = embeddings[token];

  for (const tokenComparison of Object.keys(embeddings)) {
    // skip comparing the token with itself
    if (tokenComparison === token) continue;

    const embedding = embeddings[tokenComparison];
    const dotProduct = getDotProduct(tokenEmbedding, embedding);

    products[tokenComparison] = dotProduct;
  }

  // tokens sorted by dot product, ascending (most similar last)
  const similarityIndex = Object.keys(products).sort((a, b) => (
    products[a] - products[b]
  ));

  return similarityIndex;
};
const testWord = 'quick';
const similarityIndex = getSimilarityIndexByToken(testWord);
console.log(`The most similar word to "${testWord}" is "${similarityIndex[similarityIndex.length - 1]}".`, similarityIndex, '(ascending)');
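
// Raw dot products favor embeddings with large norms. A hedged sketch of
// cosine similarity, which normalizes for magnitude (an alternative to,
// not part of, the gist above):
const getMagnitude = vector => Math.sqrt(getDotProduct(vector, vector));

const getCosineSimilarity = (a, b) => (
  getDotProduct(a, b) / (getMagnitude(a) * getMagnitude(b))
);
// e.g. getCosineSimilarity(embeddings.quick, embeddings.fox) falls in [0, 1]
// here because every component is non-negative; 1 means parallel vectors.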