{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/OverPoweredDev/034258cb827c6b0932c778277312cd1a/skip-gram.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_5KKfBfnenmb"
},
"outputs": [],
"source": [
"from nltk.corpus import gutenberg\n",
"from string import punctuation\n",
"import nltk \n",
"import numpy as np\n",
"from keras.preprocessing import text\n",
"from keras.preprocessing.sequence import skipgrams \n",
"from keras.layers import *\n",
"from keras.layers.core import Dense, Reshape\n",
"from keras.layers.embeddings import Embedding\n",
"from keras.models import Model,Sequential \n",
"import re"
]
},
{
"cell_type": "code",
"source": [
"nltk.download('gutenberg')\n",
"nltk.download('punkt')\n",
"nltk.download('stopwords')\n",
"stop_words = nltk.corpus.stopwords.words('english')"
],
"metadata": {
"id": "BspAGfvAeuKe"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"with open(\"text\", \"r\") as f:\n",
" bible = f.read()\n",
"remove_terms = punctuation + '0123456789'\n",
"wpt = nltk.WordPunctTokenizer()\n",
"def normalize_document(doc):\n",
" # lower case and remove special characters\\whitespaces\n",
" # doc = re.sub(r'[^a-zA-Z\\s]', '', doc,re.I|re.A)\n",
" # print(doc)\n",
" doc = doc.lower()\n",
" doc = doc.strip()\n",
" # tokenize document\n",
" tokens = wpt.tokenize(doc)\n",
" # filter stopwords out of document\n",
" filtered_tokens = [token for token in tokens if token not in stop_words]\n",
" # re-create document from filtered tokens\n",
" doc = ' '.join(filtered_tokens)\n",
" return doc\n",
"normalize_corpus = np.vectorize(normalize_document)"
],
"metadata": {
"id": "uFXoxvade4JF"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"norm_bible = [[word.lower() for word in sent.split() if word not in remove_terms] for sent in bible.split('\\n')]\n",
"print(norm_bible)\n",
"norm_bible = [' '.join(tok_sent) for tok_sent in norm_bible]\n",
"norm_bible = filter(None, normalize_corpus(norm_bible))\n",
"norm_bible = [tok_sent for tok_sent in norm_bible if len(tok_sent.split()) > 2]\n",
"tokenizer = text.Tokenizer()\n",
"tokenizer.fit_on_texts(norm_bible)\n",
"word2id = tokenizer.word_index\n",
"id2word = {v:k for k, v in word2id.items()}\n",
"vocab_size = len(word2id) + 1\n",
"wids = [[word2id[w] for w in text.text_to_word_sequence(doc)] for doc in norm_bible]\n",
"print('Vocabulary Size:', vocab_size)\n",
"print('Vocabulary Sample:', list(word2id.items())[:5])"
],
"metadata": {
"id": "R1CMFnysfMx5"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# generate skip-grams\n",
"skip_grams = [skipgrams(wid, vocabulary_size=vocab_size, window_size=10) for wid in wids]\n",
"# view sample skip-grams\n",
"pairs, labels = skip_grams[0][0], skip_grams[0][1]\n",
"for i in range(10):\n",
" print(\"({:s} ({:d}), {:s} ({:d})) -> {:d}\".format(\n",
" id2word[pairs[i][0]], pairs[i][0], \n",
" id2word[pairs[i][1]], pairs[i][1], \n",
" labels[i])) "
],
"metadata": {
"id": "-8nae4xWfT0Q"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# build skip-gram architecture\n",
"embed_size = 100\n",
"word_model = Sequential()\n",
"word_model.add(Embedding(vocab_size, embed_size,\n",
" embeddings_initializer=\"glorot_uniform\",\n",
" input_length=1))\n",
"word_model.add(Reshape((embed_size, )))\n",
"context_model = Sequential()\n",
"context_model.add(Embedding(vocab_size, embed_size,\n",
" embeddings_initializer=\"glorot_uniform\",\n",
" input_length=1))\n",
"context_model.add(Reshape((embed_size,)))\n",
"merged_output = add([word_model.output, context_model.output]) \n",
"model_combined = Sequential()\n",
"model_combined.add(Dense(1, kernel_initializer=\"glorot_uniform\", activation=\"sigmoid\"))\n",
"final_model = Model([word_model.input, context_model.input], model_combined(merged_output))\n",
"final_model.compile(loss=\"mean_squared_error\", optimizer=\"rmsprop\")\n",
"final_model.summary()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "b73IPQyHfeqv",
"outputId": "7adcc433-bbea-4512-818d-5248d8799462"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Model: \"model_2\"\n",
"__________________________________________________________________________________________________\n",
" Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
" embedding_4_input (InputLayer) [(None, 1)] 0 [] \n",
" \n",
" embedding_5_input (InputLayer) [(None, 1)] 0 [] \n",
" \n",
" embedding_4 (Embedding) (None, 1, 100) 25800 ['embedding_4_input[0][0]'] \n",
" \n",
" embedding_5 (Embedding) (None, 1, 100) 25800 ['embedding_5_input[0][0]'] \n",
" \n",
" reshape_4 (Reshape) (None, 100) 0 ['embedding_4[0][0]'] \n",
" \n",
" reshape_5 (Reshape) (None, 100) 0 ['embedding_5[0][0]'] \n",
" \n",
" add_2 (Add) (None, 100) 0 ['reshape_4[0][0]', \n",
" 'reshape_5[0][0]'] \n",
" \n",
" sequential_8 (Sequential) (None, 1) 101 ['add_2[0][0]'] \n",
" \n",
"==================================================================================================\n",
"Total params: 51,701\n",
"Trainable params: 51,701\n",
"Non-trainable params: 0\n",
"__________________________________________________________________________________________________\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"for epoch in range(1, 10):\n",
" loss = 0\n",
" for i, elem in enumerate(skip_grams):\n",
" pair_first_elem = np.array(list(zip(*elem[0]))[0], dtype='int32')\n",
" pair_second_elem = np.array(list(zip(*elem[0]))[1], dtype='int32')\n",
" labels = np.array(elem[1], dtype='int32')\n",
" X = [pair_first_elem, pair_second_elem]\n",
" Y = labels\n",
" if i % 10000 == 0:\n",
" print('Processed {} (skip_first, skip_second, relevance) pairs'.format(i))\n",
" loss += final_model.train_on_batch(X,Y) \n",
" print('Epoch:', epoch, 'Loss:', loss) "
],
"metadata": {
"id": "U9BF1Wb1foLa"
},
"execution_count": null,
"outputs": []
},
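{
"cell_type": "markdown",
"source": [
"Optional: a quick sketch for inspecting the learned embeddings. It pulls the trained weight matrix out of `word_model`'s embedding layer, drops the padding row at index 0, and labels each row with its word (assumes `pandas` is available, as it is on Colab)."
],
"metadata": {
"id": "embedding_inspect_md"
}
},
{
"cell_type": "code",
"source": [
"# optional sketch: view the learned embeddings as a word-indexed DataFrame (assumes pandas is installed)\n",
"import pandas as pd\n",
"\n",
"embedding_weights = word_model.layers[0].get_weights()[0][1:]  # drop the padding row (id 0)\n",
"embedding_df = pd.DataFrame(embedding_weights,\n",
"                            index=[id2word[i] for i in range(1, vocab_size)])\n",
"embedding_df.head()"
],
"metadata": {
"id": "embedding_inspect_code"
},
"execution_count": null,
"outputs": []
},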
{
"cell_type": "code",
"source": [
"from sklearn.metrics.pairwise import euclidean_distances\n",
" \n",
"word_embed_layer = word_model.layers[0]\n",
"weights = word_embed_layer.get_weights()[0][1:]\n",
"distance_matrix = euclidean_distances(weights)\n",
"print(distance_matrix.shape)\n",
"similar_words = {search_term: [id2word[idx] for idx in distance_matrix[word2id[search_term]-1].argsort()[1:6]+1] \n",
" for search_term in ['king']}\n",
"similar_words "
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4vxCeVQhf2QJ",
"outputId": "a513b60f-a1b3-4656-9dbe-6a45f50815c7"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"(257, 257)\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'king': ['archon', 'consort', 'wheeled', 'back', 'child']}"
]
},
"metadata": {},
"execution_count": 57
}
]
}
]
}