Skip to content

Instantly share code, notes, and snippets.

Created June 13, 2022 10:44
Show Gist options
Save louismagowan/dc465417bdfd47b4ce0c123a08cc3a45 to your computer and use it in GitHub Desktop.
spaCy and Keras embedding layers
# Imports
import numpy as np
import spacy
from keras.initializers import Constant
from keras.layers import Embedding
def spacy_embedding(tokenizer, maxlen=500, show_progress=False,
                    model="en_core_web_sm"):
    """Build a frozen Keras ``Embedding`` layer initialised from spaCy vectors.

    Each word in ``tokenizer.word_index`` is run through the spaCy pipeline.
    Its ``.vector`` is written into the matrix row equal to the word's integer
    id, so rows line up with the ids the tokenizer emits.

    Args:
        tokenizer: Fitted tokenizer exposing a ``word_index`` dict that maps
            word -> integer id. This presumably is a Keras ``Tokenizer``,
            whose ids run 1..len(word_index) with 0 reserved for padding;
            confirm against the caller.
        maxlen: Sequence length passed to the layer as ``input_length``.
        show_progress: If True, print a progress percentage every 10,000
            words.
        model: Name of the spaCy pipeline to load.

    Returns:
        A non-trainable ``Embedding`` layer named ``"spacy_embedding"`` whose
        weights are the spaCy-derived matrix.
    """
    # Load the spacy pipeline.
    # NOTE(review): small spaCy pipelines may ship without static word
    # vectors; confirm the chosen model gives meaningful ``.vector`` values.
    nlp = spacy.load(model)
    word_index = tokenizer.word_index
    # +1 leaves row 0 unused (padding id); assumes contiguous ids from 1.
    vocab_size = len(word_index) + 1
    # Number of embedding dimensions the pipeline produces.
    embedding_dim = nlp("any_word").vector.shape[0]
    embedding_matrix = np.zeros((vocab_size, embedding_dim))
    # Iterate (word, id) pairs so each vector lands in the row the tokenizer
    # will actually look up; enumerate() over the keys would start at 0 and
    # shift every word one row off. Slow: runs the full pipeline per word.
    for count, (word, idx) in enumerate(word_index.items(), start=1):
        embedding_matrix[idx] = nlp(word).vector
        if show_progress and count % 10000 == 0:
            print(round(count * 100 / vocab_size, 3), "% complete")
    # Use the matrix as the layer's initial weights and freeze it, since the
    # vectors are already "learned".
    return Embedding(
        input_dim=vocab_size,
        output_dim=embedding_dim,
        embeddings_initializer=Constant(embedding_matrix),
        input_length=maxlen,
        trainable=False,
        name="spacy_embedding",
    )
def keras_embedding(tokenizer, embedding_dim=256, maxlen=500):
    """Build a randomly initialised, trainable Keras ``Embedding`` layer.

    Args:
        tokenizer: Fitted tokenizer exposing a ``word_index`` dict. The
            layer's input vocabulary is ``len(word_index) + 1`` so that row 0
            stays free for padding. This assumes ids start at 1; confirm
            against the caller.
        embedding_dim: Size of each embedding vector.
        maxlen: Sequence length passed to the layer as ``input_length``.

    Returns:
        An ``Embedding`` layer named ``"keras_embedding"``. It is left
        trainable, so its weights are learned along with the model.
    """
    vocab_size = len(tokenizer.word_index) + 1
    # No pre-computed weights here: Keras initialises the matrix itself.
    return Embedding(
        input_dim=vocab_size,
        output_dim=embedding_dim,
        input_length=maxlen,
        name="keras_embedding",
    )
# Build both embedding layers from the same tokenizer, keyed by source.
# NOTE(review): `tokenizer` is not defined in this snippet; presumably it is a
# fitted tokenizer created earlier in the notebook/script. Confirm.
embed_dict = {
    "spacy": spacy_embedding(tokenizer, show_progress=True, maxlen=500),
    "keras": keras_embedding(tokenizer, maxlen=500),
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment