Created
April 5, 2022 21:06
-
-
Save yeus/a4d7cc6c97485597eb1e0d7fd720b4e3 to your computer and use it in GitHub Desktop.
spacy > 3.0 transformers contextual vectors pipeline component
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
from spacy.language import Language
from spacy.tokens import Doc
@Language.factory('trf_vectors')
class TrfContextualVectors:
    """
    spaCy pipeline component that adds transformer vectors to each token via user hooks.

    For every processed ``Doc`` the component reads ``doc._.trf_data``, sums the
    transformer output rows aligned to each spaCy token, stores the result in
    ``doc._.trf_token_vecs`` and installs ``vector`` / ``has_vector`` token hooks
    that serve those rows.

    NOTE(review): this presumably has to run after a component that fills
    ``doc._.trf_data`` (e.g. the spacy-transformers ``transformer`` pipe) --
    confirm against the pipeline configuration.

    https://spacy.io/usage/processing-pipelines#custom-components-user-hooks
    https://github.com/explosion/spaCy/discussions/6511
    """

    def __init__(self, nlp: Language, name: str):
        """Store the pipeline and register the ``trf_token_vecs`` Doc extension.

        Args:
            nlp: The pipeline this component belongs to; used by ``__call__``
                to convert raw strings into ``Doc`` objects.
            name: The instance name of the component in the pipeline.
        """
        self.name = name
        # __call__ needs the pipeline to process plain-string input.
        self._nlp = nlp
        # Doc.set_extension raises if the extension is already registered,
        # which happens when the factory is instantiated more than once.
        if not Doc.has_extension("trf_token_vecs"):
            Doc.set_extension("trf_token_vecs", default=None)

    def __call__(self, sdoc):
        """Attach per-token transformer vectors and vector hooks to ``sdoc``.

        Args:
            sdoc: A ``Doc`` carrying ``_.trf_data``, or a plain string which is
                first run through the stored pipeline.

        Returns:
            The ``Doc`` with ``_.trf_token_vecs`` set and the token-level
            ``vector`` and ``has_vector`` hooks installed.
        """
        if isinstance(sdoc, str):
            sdoc = self._nlp(sdoc)

        trf_data = sdoc._.trf_data
        alignment = trf_data.align

        # Cumulative alignment lengths mark where each token's group of
        # transformer-row indices ends inside the flat alignment data.
        vec_idx_splits = np.cumsum(alignment.lengths)

        # Flatten the transformer output into one large continuous
        # (rows, width) matrix. The width is taken from the tensor itself so
        # models with a hidden size other than 768 work as well.
        hidden = trf_data.tensors[0]
        trf_vecs = hidden.reshape(-1, hidden.shape[-1])

        # One index group per spaCy token. Splitting at the cumulative sums
        # leaves a trailing remainder piece, which is dropped below.
        vec_idxs = np.split(alignment.dataXd, vec_idx_splits)

        # A token's vector is the sum of the transformer rows aligned to it.
        vecs = np.stack([trf_vecs[idx].sum(0) for idx in vec_idxs[:-1]])
        sdoc._.trf_token_vecs = vecs

        # Only token-level hooks are installed; span- and doc-level "vector"
        # and all "similarity" hooks are intentionally left at spaCy defaults.
        sdoc.user_token_hooks["vector"] = self.vector
        sdoc.user_token_hooks["has_vector"] = self.has_vector
        return sdoc

    def vector(self, token):
        """Return the precomputed transformer vector for ``token``."""
        return token.doc._.trf_token_vecs[token.i]

    def has_vector(self, token):
        """Report that every token has a vector (always ``True``)."""
        return True
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi @yeus, this is an amazing implementation. Any hint on how I could get vectors from my model based on your class:
I'm getting an error