Skip to content

Instantly share code, notes, and snippets.

@dchaplinsky
Last active April 7, 2023 11:41
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dchaplinsky/9b87817ff3e351d3126544691a0ef7c5 to your computer and use it in GitHub Desktop.
Save dchaplinsky/9b87817ff3e351d3126544691a0ef7c5 to your computer and use it in GitHub Desktop.
A script to embed sentences using different pooling strategies and RNN-like Flair embeddings
import argparse
from flair.data import Sentence
from flair.embeddings import (
DocumentEmbeddings,
FlairEmbeddings,
DocumentLMEmbeddings,
DocumentPoolEmbeddings,
)
from torch import Tensor
def embed(token: str, embeddings: DocumentEmbeddings) -> Tensor:
    """
    Compute a document-level embedding for a piece of text.

    The text is wrapped in a flair ``Sentence``, ``embeddings.embed`` is run
    on it, and the embedding stored on the sentence is returned.

    :param token: text to embed; despite the name, it is passed to
        ``Sentence`` as-is (the script below calls this with a two-word
        phrase).
    :param embeddings: document embedding object used to embed the sentence.
    :return: the value of ``get_embedding()`` on the embedded sentence.
    """
    wrapped = Sentence(token)
    # embed() attaches its result to the Sentence object; the embedding is
    # read back from the sentence afterwards.
    embeddings.embed(wrapped)
    document_vector = wrapped.get_embedding()
    return document_vector
if __name__ == "__main__":
    # Command-line interface: both options are mandatory.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--pooling", type=str, choices=["mean", "max", "min", "rnn"], required=True
    )
    cli.add_argument(
        "--embeddings", type=str, choices=["medium", "large"], required=True
    )
    args = cli.parse_args()

    # Pick the (forward, backward) model identifiers first, then load only the
    # selected pair so the unused models are never instantiated.
    if args.embeddings == "medium":
        model_ids = ("uk-forward", "uk-backward")
    else:
        model_ids = (
            "/data/flair/uk-large/forward/best-lm.pt",
            "/data/flair/uk-large/backward/best-lm.pt",
        )
    embeddings = [FlairEmbeddings(model_id) for model_id in model_ids]

    # "rnn" selects DocumentLMEmbeddings; every other allowed choice
    # ("mean", "max", "min") is forwarded to DocumentPoolEmbeddings.
    if args.pooling == "rnn":
        document_embeddings = DocumentLMEmbeddings(embeddings)
    else:
        document_embeddings = DocumentPoolEmbeddings(embeddings, pooling=args.pooling)

    # Embed a fixed sample phrase and print the resulting embedding.
    print(embed("капуста білоголова", document_embeddings))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment