Functions for dealing with embeddings
import json

import numpy as np
import pandas as pd
from gensim.models import KeyedVectors, Word2Vec
from gensim.models.fasttext import FastText as FT_gensim
from nltk.tokenize import sent_tokenize, word_tokenize
from tqdm import tqdm, trange

# NOTE: this file targets the gensim 3.x API (Word2Vec(iter=...), model.iter);
# gensim 4 renamed these parameters (iter -> epochs, size -> vector_size).
tqdm.pandas()  # registers pandas.Series.progress_apply, used in read_corpus

W2V_PATH = 'data/GoogleNews-vectors-negative300.bin'

def load_word2vec(path):
    return KeyedVectors.load_word2vec_format(path, binary=True)


def save_word2vec(model, path):
    model.wv.save_word2vec_format(path, binary=True)


def load_google_word2vec():
    return load_word2vec(W2V_PATH)

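# Usage sketch: load the pretrained Google News vectors and query them
# (assumes the binary file is actually present at W2V_PATH):
#
#   vectors = load_google_word2vec()
#   vectors.most_similar('king', topn=3)
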
def read_corpus(preprocessed_text):
    """
    Transform text into tokens for learning embeddings
    :param preprocessed_text: str, pandas.Series or other iterable of texts
    :return: pandas.Series if preprocessed_text is an instance of pandas.Series, else list
    """
    if isinstance(preprocessed_text, str):
        preprocessed_text = sent_tokenize(preprocessed_text)
    if isinstance(preprocessed_text, pd.Series):
        corpus = preprocessed_text.progress_apply(word_tokenize)
    else:
        corpus = [word_tokenize(text) for text in preprocessed_text]
    return corpus

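# Usage sketch (made-up text): a raw string is split into sentences first,
# then each sentence into word tokens:
#
#   read_corpus('First sentence. Second sentence.')
#   # -> [['First', 'sentence', '.'], ['Second', 'sentence', '.']]
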
def get_step_epochs(epochs):
    """
    Convert absolute checkpoint epochs into per-step increments,
    e.g. [10, 50, 100] -> [10, 40, 50].
    """
    epochs = sorted(epochs)  # work on a sorted copy instead of mutating the argument
    step_epochs = [epochs[0]]
    for i in range(1, len(epochs)):
        step_epochs.append(epochs[i] - epochs[i - 1])
    return step_epochs

def train_embedding_model(model, corpus, epochs, save_name):
    if isinstance(epochs, list):
        # A list of epochs means "save a checkpoint at each of these points"
        if not save_name:
            raise ValueError('save_name is required when saving checkpoints for several epochs')
        epochs = sorted(epochs)
        step_epochs = get_step_epochs(epochs)
        for ind, num_epochs in enumerate(step_epochs):
            for i in trange(num_epochs):
                model.train(corpus, total_examples=model.corpus_count, epochs=model.iter)
            model.save('{}_e{}.w2v'.format(save_name, epochs[ind]))
    else:
        for i in trange(epochs):
            model.train(corpus, total_examples=model.corpus_count, epochs=model.iter)
        if save_name:
            model.save(save_name + '.w2v')
    return model

def create_embeddings(preprocessed_text,
                      intersect_with=None,
                      size=100,
                      window=5,
                      min_count=10,
                      workers=8,
                      iter_by_epoch=10,
                      epochs=300,
                      skip_gram=1,
                      save_name=None):
    """
    Train word2vec embeddings on preprocessed_text.
    :param intersect_with: optional path to a binary word2vec file; vectors for
        words already in the vocabulary are initialised from it
    :param epochs: int, or a list of ints to save a checkpoint at each value
    """
    new_model = Word2Vec(size=size,
                         sg=skip_gram,
                         window=window,
                         min_count=min_count,
                         workers=workers,
                         iter=iter_by_epoch)
    corpus = read_corpus(preprocessed_text)
    new_model.build_vocab(corpus)
    if intersect_with is not None:
        new_model.intersect_word2vec_format(intersect_with, binary=True)
    return train_embedding_model(new_model, corpus, epochs, save_name)

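# Usage sketch (hypothetical data and paths): train skip-gram vectors on a
# pandas Series of preprocessed documents, warm-started from the Google News
# vectors, with checkpoints written after the 100th and 300th training steps
# (each step runs iter_by_epoch internal epochs):
#
#   texts = pd.Series(['first preprocessed document', 'second preprocessed document'])
#   model = create_embeddings(texts,
#                             intersect_with=W2V_PATH,
#                             epochs=[100, 300],
#                             save_name='data/my_embeddings')
#   # -> writes data/my_embeddings_e100.w2v and data/my_embeddings_e300.w2v
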
def create_fasttext_embeddings(preprocessed_text,
                               size=100,
                               window=5,
                               min_count=10,
                               workers=8,
                               epochs=300,
                               save_name=None):
    corpus = read_corpus(preprocessed_text)
    model_gensim = FT_gensim(size=size,
                             window=window,
                             min_count=min_count,
                             workers=workers,
                             iter=epochs)
    model_gensim.build_vocab(corpus)
    model_gensim.train(corpus, total_examples=model_gensim.corpus_count, epochs=model_gensim.iter)
    if save_name:
        model_gensim.save(save_name + '.ft')
    return model_gensim

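# Usage sketch (hypothetical data): unlike word2vec, fastText can produce
# vectors for out-of-vocabulary words from character n-grams:
#
#   ft_model = create_fasttext_embeddings(texts, epochs=50, save_name='data/my_fasttext')
#   ft_model.wv['unseenword']  # works even if the word never appeared in the corpus
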
def load_embeddings(path):
    if path.endswith('.w2v'):
        return Word2Vec.load(path)
    elif path.endswith('.ft'):
        return FT_gensim.load(path)
    else:
        raise NotImplementedError('Only .w2v and .ft files are supported')

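# Usage sketch (hypothetical path): the loader dispatches on the extensions
# used by the save functions above:
#
#   model = load_embeddings('data/my_embeddings_e100.w2v')
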
def save_model(w2v_model, top_words=5000):
    """
    Export the top_words most frequent embeddings for use outside gensim.
    Writes the word -> row-index map to data/vocab.json and the matching
    embedding matrix to data/weights.npz.
    :param w2v_model: trained gensim KeyedVectors (e.g. model.wv)
    :param top_words: how many of the most frequent words to keep
    """
    # gensim assigns vocab indices in order of descending frequency
    vocab = dict(sorted(w2v_model.vocab.items(), key=lambda kv: kv[1].index)[:top_words])
    vocab = {word: entry.index for word, entry in vocab.items()}
    indices = list(vocab.values())
    weights = np.take(w2v_model.vectors, indices, axis=0, mode='raise')
    # Renumber the kept words 0..top_words-1 so they match the rows of weights
    idx2word = dict([(v, k) for k, v in vocab.items()])
    vocab = {idx2word[old_index]: new_index for new_index, old_index in enumerate(indices)}
    with open('data/vocab.json', 'w') as f:
        f.write(json.dumps(vocab))
    # np.save writes .npy-format data despite the .npz extension; np.load reads it fine
    np.save(open('data/weights.npz', 'wb'), weights)

def load_vocab(vocab_path='data/vocab.json'):
    """
    Load word -> index and index -> word mappings
    :param vocab_path: where the word-index map is saved
    :return: word2idx, idx2word
    """
    with open(vocab_path, 'r') as f:
        word2idx = json.loads(f.read())
    idx2word = dict([(v, k) for k, v in word2idx.items()])
    return word2idx, idx2word

def load_weights(weights_path='data/weights.npz'):
    # np.load detects the .npy format written by save_model from the file's
    # magic bytes, so the misleading extension is harmless
    return np.load(weights_path)
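
# End-to-end sketch (hypothetical paths): export the most frequent vectors from
# a trained model, then reload the matrix and the word-index maps:
#
#   model = load_embeddings('data/my_embeddings.w2v')
#   save_model(model.wv)
#   word2idx, idx2word = load_vocab()
#   weights = load_weights()
#   assert weights.shape[0] == len(word2idx)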