Skip to content

Instantly share code, notes, and snippets.

@bricksdont
Last active June 25, 2019 15:35
Show Gist options
  • Save bricksdont/98127e1333467debb4d384612053f6a2 to your computer and use it in GitHub Desktop.
Save bricksdont/98127e1333467debb4d384612053f6a2 to your computer and use it in GitHub Desktop.
# import embeddings
!ls shared/MUSE
import musevecs
# Load the multilingual MUSE embeddings for English, German and French.
# Presumably `{0}` in the path template is filled with each language code and
# `nmax` caps the number of vectors read per language -- TODO confirm against
# the `musevecs` module.
mdls = musevecs.MUSE('shared/MUSE/wiki.multi.{0}.vecfull.txt', {'en', 'de', 'fr'}, nmax=200000)
# Handle on the English model; the token-to-vector lookup further down uses it.
enmdl = mdls.vecmap['en']
# document classification
def open_split(path, label, encoding=None):
    """Read a text file with one example per line and label every line.

    Args:
        path: Path of the text file to read.
        label: Class label assigned to every line of the file.
        encoding: Text encoding handed to ``open``. The default ``None``
            keeps the platform default, i.e. the previous behaviour.

    Returns:
        A pair ``(X, y)``. ``X`` is a list of token lists, one per line,
        obtained by stripping the line and splitting it on single spaces.
        ``y`` holds ``label`` once per line.
    """
    X, y = [], []
    # The context manager closes the handle deterministically; a bare
    # ``open(path)`` in the loop header leaves that to the garbage collector.
    with open(path, encoding=encoding) as handle:
        for line in handle:
            X.append(line.strip().split(" "))
            y.append(label)
    return X, y
# Training corpora: class label -> path of the English training file.
TRAINSETS = dict(
    europarl="shared/textclass/europarl.en",
    subs="shared/textclass/subs.en",
    wiki="shared/textclass/wiki.en",
)
# Accumulate the tokenised sentences and their labels across all corpora.
X, y = [], []
for train_label, train_path in TRAINSETS.items():
    X_part, y_part = open_split(train_path, train_label)
    X += X_part
    y += y_part
# Notebook-style sanity check: both lists should have the same length.
len(X), len(y)
import numpy as np
def replace_tokens_with_vectors(X, model=None, dim=300):
    """Map every token of every example to its embedding vector.

    Tokens are lower-cased before lookup, and tokens for which the model
    raises ``KeyError`` are dropped. An example without a single known token
    is represented by one all-zero vector so that it still yields a vector
    when averaged later. The number of such examples is printed.

    Args:
        X: Iterable of token lists, one per example.
        model: Object exposing ``getWordVec(token)``; it is expected to raise
            ``KeyError`` for unknown tokens. Defaults to the module-level
            ``enmdl``, which was previously the only option.
        dim: Length of the zero vector used for examples without any known
            token. Defaults to 300, the value that used to be hard-coded.

    Returns:
        A list with one non-empty list of vectors per example.
    """
    if model is None:
        # Resolved at call time so the default tracks the current global.
        model = enmdl
    X_vec = []
    bad_counter = 0
    for tokens in X:
        tokens_vec = []
        for token in tokens:
            try:
                tokens_vec.append(model.getWordVec(token.lower()))
            except KeyError:
                # Out-of-vocabulary token: deliberately skipped.
                continue
        if not tokens_vec:
            tokens_vec = [np.zeros(dim)]
            bad_counter += 1
        X_vec.append(tokens_vec)
    print("Num sentences that are all zeros: %d" % bad_counter)
    return X_vec
def average_word_vectors(X_vec):
    """Collapse each example's word vectors into their element-wise mean.

    Args:
        X_vec: Iterable of sequences of equal-length vectors, one sequence
            per example.

    Returns:
        A list with one averaged NumPy vector per example, in input order.
    """
    # axis=0 averages across the words of an example, keeping the embedding axis.
    return [np.mean(np.array(vectors), axis=0) for vectors in X_vec]
# Keep a reference to the tokenised sentences before X is overwritten.
X_text = X
# Embed every token, then collapse each sentence into one averaged vector.
X = average_word_vectors(replace_tokens_with_vectors(X))
print(len(X), len(y))
# Notebook-style check: list the shape of any example that is not (300,).
[vec.shape for vec in X if vec.shape != (300,)]
from sklearn.neighbors import KNeighborsClassifier
# Fit a 9-nearest-neighbour classifier on the averaged sentence vectors;
# ``fit`` returns the estimator itself, so the call can be chained.
clf = KNeighborsClassifier(n_neighbors=9).fit(X, y)
# Held-out test sets: one German and one English file per domain, each
# labelled with its domain name.
test_europarl_de = open_split("shared/textclass/test/europarl.de", "europarl")
test_europarl_en = open_split("shared/textclass/test/europarl.en", "europarl")
test_subs_de = open_split("shared/textclass/test/subs.de", "subs")
test_subs_en = open_split("shared/textclass/test/subs.en", "subs")
test_wiki_de = open_split("shared/textclass/test/wiki.de", "wiki")
test_wiki_en = open_split("shared/textclass/test/wiki.en", "wiki")
# Name -> (X, y) pair, in the order the results are reported.
TESTS = dict(
    test_europarl_de=test_europarl_de,
    test_europarl_en=test_europarl_en,
    test_subs_de=test_subs_de,
    test_subs_en=test_subs_en,
    test_wiki_de=test_wiki_de,
    test_wiki_en=test_wiki_en,
)
from sklearn.metrics import accuracy_score, classification_report
# Report the classifier's accuracy on every held-out test set.
# NOTE(review): all test sets, including the German ones, are embedded via
# replace_tokens_with_vectors, which looks tokens up in the English model
# `enmdl` -- confirm this is intended.
for test_label in TESTS:
    test_data = TESTS[test_label]
    X, y_true = test_data
    # Same preprocessing as for training: embed tokens, then average.
    X = average_word_vectors(replace_tokens_with_vectors(X))
    y_pred = clf.predict(X)
    print(test_label)
    print(accuracy_score(y_true, y_pred))
    print()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment