Skip to content

Instantly share code, notes, and snippets.

@kudkudak
Created March 2, 2017 12:28
Show Gist options
  • Save kudkudak/b7f8e91b7e1cf8938af83b6f43e9d100 to your computer and use it in GitHub Desktop.
Save kudkudak/b7f8e91b7e1cf8938af83b6f43e9d100 to your computer and use it in GitHub Desktop.
import sys
import numpy as np
from web import embedding
import web
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.decomposition import RandomizedPCA
import logging
# Module-level logger used by the transformers below (e.g. to report output
# shapes and warn about skipped out-of-vocabulary words).
logger = logging.getLogger(__name__)
class SentenceEmbedder(BaseEstimator, TransformerMixin):
    """
    Transforms sentences into vector representation using word embeddings.

    Every sentence (a sequence of word tokens) becomes one row of the output
    matrix, either by averaging its word vectors (``method="avg"``) or by
    concatenating them in order (``method="concat"``).

    Parameters
    ----------
    embedding : web.embedding.Embedding
        Word embedding used for lookups.
    method : {"avg", "concat"}
        How the word vectors of a sentence are combined.
    on_missing : {"raise", "skip"}
        What to do with words that are not in ``embedding``. With "skip",
        "avg" averages only the found words (a sentence with no found words
        yields a zero row) and "concat" leaves zeros at the positions of the
        missing words so the remaining words keep their positions.

    Notes
    -----
    Remember about cleaning and lowering words to match embedding vocabulary!
    It is best if you don't change ``on_missing``, because it is a good test
    that your intentions are met (i.e. no words are skipped because you forgot
    to lower either the Embedding or the text dataset).

    ``web.embedding.Embedding`` is required because pandas .loc/.iloc
    indexing can be very slow.

    To be fair, it is worth restricting all embeddings to a common vocabulary.
    """

    def __init__(self, embedding, method="avg", on_missing="raise"):
        self.on_missing = on_missing
        if self.on_missing not in ['raise', 'skip']:
            raise NotImplementedError("Not implemented {} on_missing arg".format(self.on_missing))
        self.embedding = embedding
        if not isinstance(self.embedding, web.embedding.Embedding):
            raise NotImplementedError()
        self.method = method
        if self.method not in ['avg', 'concat']:
            raise NotImplementedError()

    def fit(self, X, y=None):
        """No-op: the transformer has nothing to learn. Returns ``self``."""
        return self

    def _lookup(self, query):
        """Return one vector per word of *query*, with None for skipped OOV words."""
        vectors = []
        for word in query:
            if word in self.embedding:
                vectors.append(self.embedding[word])
            elif self.on_missing == "raise":
                raise RuntimeError("Not found word")
            elif self.on_missing == "skip":
                # Placeholder keeps word positions aligned for "concat".
                vectors.append(None)
            else:
                raise NotImplementedError()
        return vectors

    def transform(self, X, y=None):
        """
        Embed every sentence of ``X``.

        Parameters
        ----------
        X : sequence of sequences of str
            Tokenized sentences. For "concat" all sentences must have the
            same number of tokens.
        y : ignored

        Returns
        -------
        numpy.ndarray
            ``(len(X), d)`` for "avg" or ``(len(X), k * d)`` for "concat",
            where ``d = len(embedding.vectors[0])`` and ``k`` is the sentence
            length; dtype of ``embedding.vectors``.

        Raises
        ------
        RuntimeError
            On an out-of-vocabulary word when ``on_missing="raise"``, or on
            sentences of different lengths with "concat".
        """
        dim = len(self.embedding.vectors[0])
        dtype = self.embedding.vectors.dtype
        n_rows = len(X)  # len() also accepts plain lists, unlike X.shape[0]

        if self.method == 'avg':
            X_tr = np.zeros(shape=(n_rows, dim), dtype=dtype)
        elif self.method == "concat":
            n_words = len(X[0]) if n_rows else 0
            X_tr = np.zeros(shape=(n_rows, n_words * dim), dtype=dtype)
            if not all(len(query) == n_words for query in X):
                raise RuntimeError("Variable sized sentences don't work with 'concat")
        else:
            raise NotImplementedError()
        logger.info("Creating embedding via {} of shape {}".format(self.method, X_tr.shape))

        skipped = 0
        empty = 0
        for row, query in enumerate(X):
            vectors = self._lookup(query)
            skipped += sum(vector is None for vector in vectors)
            if self.method == "avg":
                found = [vector for vector in vectors if vector is not None]
                if found:
                    X_tr[row] = np.mean(np.array(found), axis=0).reshape(-1)
                else:
                    # Mean of no vectors is undefined; keep the zero row.
                    empty += 1
            else:  # "concat"
                for position, vector in enumerate(vectors):
                    if vector is not None:
                        X_tr[row, position * dim:(position + 1) * dim] = vector.reshape(-1)

        if skipped > 0:
            logger.warning("Skipped {} out of vocabulary words".format(skipped))
        if empty > 0:
            logger.warning("{} sentences had no in-vocabulary words and were left as zeros".format(empty))
        return X_tr
class DotProductEmbedder(BaseEstimator, TransformerMixin):
    """
    Transforms sentences into element-wise products of required word pairs.

    ``pairs = [(0, 1)]`` means join the vectors of word 0 and word 1 of every
    sentence via ``method``. With ``method="diag"`` the join is the
    element-wise product (the diagonal of the outer product; its sum is the
    dot product), so each pair contributes ``d`` output columns, where
    ``d = len(embedding.vectors[0])``.

    Parameters
    ----------
    embedding : web.embedding.Embedding
        Word embedding used for lookups.
    method : {"diag"}
        How the two vectors of a pair are joined.
    on_missing : {"raise", "skip"}
        What to do with words that are not in ``embedding``. With "skip",
        word positions are preserved and every pair that involves a missing
        word is left as zeros.
    pairs : sequence of (int, int)
        Word positions to join. Defaults to the single pair ``(0, 1)``.
    """

    def __init__(self, embedding, method="diag", on_missing="raise", pairs=((0, 1),)):
        self.on_missing = on_missing
        # Stored as given (scikit-learn parameter convention); the default is
        # an immutable tuple so it cannot be shared and mutated across calls.
        self.pairs = pairs
        if self.on_missing not in ['raise', 'skip']:
            raise NotImplementedError("Not implemented {} on_missing arg".format(self.on_missing))
        self.embedding = embedding
        if not isinstance(self.embedding, web.embedding.Embedding):
            raise NotImplementedError()
        self.method = method
        if self.method not in ['diag']:
            raise NotImplementedError()

    def fit(self, X, y=None):
        """No-op: the transformer has nothing to learn. Returns ``self``."""
        return self

    def _lookup(self, query):
        """Return one vector per word of *query*, with None for skipped OOV words."""
        vectors = []
        for word in query:
            if word in self.embedding:
                vectors.append(self.embedding[word])
            elif self.on_missing == "raise":
                raise RuntimeError("Not found word")
            elif self.on_missing == "skip":
                # Placeholder keeps positions aligned with the indices in pairs.
                vectors.append(None)
            else:
                raise NotImplementedError()
        return vectors

    def transform(self, X, y=None):
        """
        Join the configured word pairs of every sentence of ``X``.

        Parameters
        ----------
        X : sequence of sequences of str
            Tokenized sentences, all of the same length.
        y : ignored

        Returns
        -------
        numpy.ndarray
            Shape ``(len(X), d * len(pairs))``, dtype of ``embedding.vectors``;
            block ``k`` holds the join of ``pairs[k]``.

        Raises
        ------
        RuntimeError
            On an out-of-vocabulary word when ``on_missing="raise"``, or on
            sentences of different lengths.
        """
        dim = len(self.embedding.vectors[0])
        if self.method != 'diag':
            raise NotImplementedError()
        X_tr = np.zeros(shape=(len(X), dim * len(self.pairs)), dtype=self.embedding.vectors.dtype)
        if not all(len(query) == len(X[0]) for query in X):
            raise RuntimeError("Variable sized sentences don't work with 'diag")

        skipped = 0
        for row, query in enumerate(X):
            vectors = self._lookup(query)
            skipped += sum(vector is None for vector in vectors)
            for k, (i, j) in enumerate(self.pairs):
                left, right = vectors[i], vectors[j]
                if left is None or right is None:
                    # The pair touches a skipped word: leave its block as zeros.
                    continue
                X_tr[row, k * dim:(k + 1) * dim] = np.multiply(left, right)

        if skipped > 0:
            logger.warning("Skipped {} out of vocabulary words".format(skipped))
        return X_tr
class DoubleListEmbedder(BaseEstimator, TransformerMixin):
    """
    Transforms list of lists of words into list of lists of vectors.

    Parameters
    ----------
    embedding : web.embedding.Embedding
        Word embedding used for lookups.
    on_missing : {"raise", "skip"}
        Either raise on a word absent from ``embedding`` or drop it (the
        corresponding inner list simply gets shorter).
    """

    def __init__(self, embedding, on_missing="raise"):
        self.on_missing = on_missing
        if on_missing not in ('raise', 'skip'):
            raise NotImplementedError("Not implemented {} on_missing arg".format(on_missing))
        self.embedding = embedding
        if not isinstance(embedding, web.embedding.Embedding):
            raise NotImplementedError()

    def fit(self, X, y=None):
        """Nothing to learn; hand back the estimator itself."""
        return self

    def transform(self, X, y=None):
        """
        Replace every word of every inner list of ``X`` with its vector.

        Returns a plain Python list of lists (not a numpy array) so inner
        lists may have different lengths.
        """
        n_missing = 0
        embedded = []
        for words in X:
            row = []
            for token in words:
                if token not in self.embedding:
                    if self.on_missing == "raise":
                        raise RuntimeError("Not found word")
                    if self.on_missing != "skip":
                        raise NotImplementedError()
                    n_missing += 1
                    continue
                row.append(self.embedding[token])
            embedded.append(row)
        if n_missing:
            logger.warning("Skipped {} out of vocabulary words".format(n_missing))
        return embedded
class DoubleListDotProduct(BaseEstimator, TransformerMixin):
    """
    Transforms list of lists of vectors into numpy array.

    Each element of ``X`` holds two vectors ``v`` and ``w`` (of length ``d``,
    taken from the first vector of the first element) that are joined into
    one output row according to ``method``:

    - "single": ``[v . w]`` (1 column)
    - "diagonal": ``v * w`` element-wise (d columns)
    - "double_diagonal": ``[v * w, log((1 + v * w) / 2)]`` (2d columns)
    - "triple_diagonal": ``[v, w, v * w]`` (3d columns)
    - "quadruple_diagonal": ``[v, w, v * w, log((1 + v * w) / 2)]`` (4d columns)
    - "concat": ``[v, w]`` (2d columns)
    - "all": flattened outer product ``v w^T`` (d**2 columns)
    - "all_sym": flattened ``(v w^T + w v^T) / 2`` (d**2 columns)

    Currently only n x 2 input (``pairs`` equal to ``[[0, 1]]``) is supported.
    """

    _METHODS = ("single", "diagonal", "double_diagonal", "triple_diagonal",
                "quadruple_diagonal", "concat", "all", "all_sym")

    def __init__(self, method="single", pairs=((0, 1),)):
        # Validate with exceptions (asserts vanish under ``python -O``); the
        # previous whitelist also rejected the implemented "triple_diagonal"
        # and "concat" methods.
        if method not in self._METHODS:
            raise NotImplementedError("Not implemented {} method arg".format(method))
        # Compare by value so both [[0, 1]] and ((0, 1),) are accepted.
        if [list(p) for p in pairs] != [[0, 1]]:
            raise NotImplementedError("Only pairs=[[0, 1]] is supported")
        self.method = method
        # Stored as given (scikit-learn parameter convention); immutable default.
        self.pairs = pairs

    def fit(self, X, y=None):
        """No-op: the transformer has nothing to learn. Returns ``self``."""
        return self

    @staticmethod
    def _joiner(method, d):
        """Return ``(n_cols, joiner)`` for ``method`` and input vectors of length ``d``."""
        # NOTE(review): the log features are -inf/nan wherever v * w <= -1;
        # presumably inputs are expected to be normalized -- confirm.
        if method == "single":
            return 1, lambda v, w: np.dot(v.reshape((1, -1)), w.reshape((-1, 1))).ravel()
        if method == "diagonal":
            return d, lambda v, w: np.multiply(v, w).ravel()
        if method == "double_diagonal":
            return d * 2, lambda v, w: np.hstack([np.multiply(v, w).ravel(),
                                                  np.log((1 + np.multiply(v, w).ravel()) / 2.)])
        if method == "triple_diagonal":
            return d * 3, lambda v, w: np.hstack([v, w, np.multiply(v, w).ravel()])
        if method == "quadruple_diagonal":
            return d * 4, lambda v, w: np.hstack([v, w, np.multiply(v, w).ravel(),
                                                  np.log((1 + np.multiply(v, w).ravel()) / 2.)])
        if method == "concat":
            return d * 2, lambda v, w: np.hstack([v, w])
        if method == "all":
            return d ** 2, lambda v, w: np.dot(v.reshape((-1, 1)), w.reshape((1, -1))).ravel()
        if method == "all_sym":
            return d ** 2, lambda v, w: ((np.dot(v.reshape((-1, 1)), w.reshape((1, -1)))
                                          + np.dot(w.reshape((-1, 1)), v.reshape((1, -1)))) / 2).ravel()
        # Reachable if ``method`` was changed after construction (set_params).
        raise NotImplementedError()

    def transform(self, X, y=None):
        """
        Join the two vectors of every element of ``X`` into one row.

        Returns a ``(len(X), n_cols)`` array with the dtype of ``X[0][0]``;
        ``n_cols`` depends on ``method`` (see the class docstring).
        """
        V = X[0][0]
        n_cols, joiner = self._joiner(self.method, V.shape[0])
        X_tr = np.zeros(shape=(len(X), n_cols), dtype=V.dtype)
        for i, pair in enumerate(X):
            X_tr[i, :] = joiner(pair[0], pair[1])
        return X_tr
class DoubleListConcat(BaseEstimator, TransformerMixin):
    """
    Transforms list of lists of vectors into numpy array.

    Row ``i`` of the output is ``np.hstack(X[i])``. The output width is
    fixed by the first element (number of vectors times the length of its
    first vector), so currently only n x k input is supported: every element
    must hold the same number k of equally sized vectors.
    """

    def __init__(self):
        """Takes no parameters."""

    def fit(self, X, y=None):
        """Nothing to learn; hand back the estimator itself."""
        return self

    def transform(self, X, y=None):
        """Return a ``(len(X), k * d)`` array, dtype of ``X[0][0]``."""
        first = X[0][0]
        width = len(X[0]) * first.shape[0]
        out = np.zeros((len(X), width), dtype=first.dtype)
        # Each ``row`` is a view into ``out``, so writing it fills the result.
        for row, vectors in zip(out, X):
            row[:] = np.hstack(vectors)
        return out
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment