Skip to content

Instantly share code, notes, and snippets.

@nikogamulin
Created June 20, 2019 11:01
Show Gist options
  • Save nikogamulin/173bc65dca79ff0a6cefe5e856c4e5d4 to your computer and use it in GitHub Desktop.
Save nikogamulin/173bc65dca79ff0a6cefe5e856c4e5d4 to your computer and use it in GitHub Desktop.
flair-test.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-06-20T10:46:01.166626Z",
"end_time": "2019-06-20T10:46:04.936750Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# numpy helpers used later by cosine_similarity\nfrom numpy import dot\nfrom numpy.linalg import norm\n\nfrom flair.data import Sentence\nfrom flair.models import SequenceTagger\n\n# Load the pre-trained 'ner' sequence tagger\ntagger = SequenceTagger.load('ner')\n\n# Whitespace-separated example sentence\nsentence = Sentence('I love Berlin .')\n\n# Tag the sentence; as the last expression, the call's return value is displayed\ntagger.predict(sentence)",
"execution_count": 30,
"outputs": [
{
"output_type": "stream",
"text": "2019-06-20 12:46:01,169 loading file /Users/nikogamulin/.flair/models/en-ner-conll03-v0.4.pt\n",
"name": "stdout"
},
{
"output_type": "execute_result",
"execution_count": 30,
"data": {
"text/plain": "[Sentence: \"I love Berlin .\" - 4 Tokens]"
},
"metadata": {}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "The Sentence now has entity annotations. Print the sentence to see what the tagger found."
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-06-20T09:51:46.524368Z",
"end_time": "2019-06-20T09:51:46.529038Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Show the sentence followed by a header line (one print call, newline-separated)\nprint(sentence, 'The following NER tags are found:', sep='\\n')\n\n# One line per entity span stored under the 'ner' tag type\nfor entity in sentence.get_spans('ner'):\n    print(entity)",
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"text": "Sentence: \"I love Berlin .\" - 4 Tokens\nThe following NER tags are found:\nLOC-span [3]: \"Berlin\"\n",
"name": "stdout"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "# Tutorial 1: NLP Base Types"
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-06-20T09:52:46.936673Z",
"end_time": "2019-06-20T09:52:46.941986Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# A Sentence object wraps the text that we later want to embed or tag\nfrom flair.data import Sentence\n\n# The input is already whitespace-tokenized: each space-separated piece becomes a token\nsentence = Sentence('The grass is green .')\n\n# Printing shows the text together with its token count\nprint(sentence)",
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"text": "Sentence: \"The grass is green .\" - 5 Tokens\n",
"name": "stdout"
}
]
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-06-20T09:53:11.320247Z",
"end_time": "2019-06-20T09:53:11.325115Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Two ways to reach the same token:\n# get_token() takes the 1-based token id, indexing takes the 0-based position\nprint(sentence.get_token(4), sentence[3], sep='\\n')",
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"text": "Token: 4 green\nToken: 4 green\n",
"name": "stdout"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Tokenization"
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-06-20T09:54:10.798334Z",
"end_time": "2019-06-20T09:54:10.803539Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Raw text this time: with use_tokenizer=True flair does the tokenization itself\nsentence = Sentence('The grass is green.', use_tokenizer=True)\n\n# The trailing period now shows up as its own token\nprint(sentence)",
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"text": "Sentence: \"The grass is green .\" - 5 Tokens\n",
"name": "stdout"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Adding Tags to Tokens"
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-06-20T09:54:56.470612Z",
"end_time": "2019-06-20T09:54:56.474827Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Manually label the token at index 3 ('green') with an 'ner' tag whose value is 'color'\nsentence[3].add_tag('ner', 'color')\n\n# Render the sentence with each tag shown inline after its token\nprint(sentence.to_tagged_string())",
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"text": "The grass is green <color> .\n",
"name": "stdout"
}
]
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-06-20T09:55:33.444253Z",
"end_time": "2019-06-20T09:55:33.448702Z"
},
"trusted": true
},
"cell_type": "code",
"source": "from flair.data import Label\n\n# Read back the 'ner' label that was attached to the token at index 3\ntag: Label = sentence[3].get_tag('ner')\n\n# A Label exposes the tag value and a confidence score\nprint(f'\"{sentence[3]}\" is tagged as \"{tag.value}\" with confidence score \"{tag.score}\"')",
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"text": "\"Token: 4 green\" is tagged as \"color\" with confidence score \"1.0\"\n",
"name": "stdout"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "# Tutorial 2: Tagging your Text"
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-06-20T09:57:15.561146Z",
"end_time": "2019-06-20T09:57:15.564011Z"
}
},
"cell_type": "markdown",
"source": "## Tagging with Pre-Trained Sequence Tagging Models"
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-06-20T09:57:29.257212Z",
"end_time": "2019-06-20T09:57:31.188681Z"
},
"trusted": true
},
"cell_type": "code",
"source": "from flair.models import SequenceTagger\n\n# Load the pre-trained 'ner' tagger (the log line shows which cached model file it resolves to)\ntagger = SequenceTagger.load('ner')",
"execution_count": 11,
"outputs": [
{
"output_type": "stream",
"text": "2019-06-20 11:57:29,259 loading file /Users/nikogamulin/.flair/models/en-ner-conll03-v0.4.pt\n",
"name": "stdout"
}
]
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-06-20T09:57:49.695126Z",
"end_time": "2019-06-20T09:57:49.919593Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Whitespace-tokenized input in which 'Washington' appears both as a person and as a place\nsentence = Sentence('George Washington went to Washington .')\n\n# Annotate the sentence in place with predicted NER tags\ntagger.predict(sentence)\n\n# Show the per-token tags inline\nprint(sentence.to_tagged_string())",
"execution_count": 12,
"outputs": [
{
"output_type": "stream",
"text": "George <B-PER> Washington <E-PER> went to Washington <S-LOC> .\n",
"name": "stdout"
}
]
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-06-20T09:58:19.095615Z",
"end_time": "2019-06-20T09:58:19.099759Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Entity spans group consecutive tagged tokens into one entity (e.g. tokens 1-2 form a single PER span)\nfor entity in sentence.get_spans('ner'):\n    print(entity)",
"execution_count": 13,
"outputs": [
{
"output_type": "stream",
"text": "PER-span [1,2]: \"George Washington\"\nLOC-span [5]: \"Washington\"\n",
"name": "stdout"
}
]
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-06-20T09:59:24.533186Z",
"end_time": "2019-06-20T09:59:24.537410Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Untokenized input: let flair split the text into tokens\nsentence = Sentence(\n    \"Burger King is not a best choice for a healthy food.\",\n    use_tokenizer=True,\n)",
"execution_count": 14,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-06-20T09:59:48.188423Z",
"end_time": "2019-06-20T09:59:48.527928Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Run the NER tagger on the new sentence, then list each entity span it found\ntagger.predict(sentence)\nfor entity in sentence.get_spans('ner'):\n    print(entity)",
"execution_count": 15,
"outputs": [
{
"output_type": "stream",
"text": "ORG-span [1,2]: \"Burger King\"\n",
"name": "stdout"
}
]
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-06-20T10:00:29.771210Z",
"end_time": "2019-06-20T10:00:29.775402Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Dump the sentence as a plain dict; 'entities' carries the text, character offsets, type and confidence of each NER span\nprint(sentence.to_dict(tag_type='ner'))",
"execution_count": 16,
"outputs": [
{
"output_type": "stream",
"text": "{'text': 'Burger King is not a best choice for a healthy food.', 'labels': [], 'entities': [{'text': 'Burger King', 'start_pos': 0, 'end_pos': 11, 'type': 'ORG', 'confidence': 0.6567533910274506}]}\n",
"name": "stdout"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Embedding"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Word Embedding"
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-06-20T10:04:30.966438Z",
"end_time": "2019-06-20T10:06:44.197381Z"
},
"trusted": true
},
"cell_type": "code",
"source": "from flair.embeddings import WordEmbeddings\n\n# GloVe word vectors (downloaded into the local flair cache when not present yet)\nglove_embedding = WordEmbeddings('glove')",
"execution_count": 17,
"outputs": [
{
"output_type": "stream",
"text": "2019-06-20 12:04:31,118 https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/glove.gensim.vectors.npy not found in cache, downloading to /var/folders/nn/wsc7gmvs5cz7pf2mdj7d2pwh0000gn/T/tmpxi_2v3d1\n",
"name": "stdout"
},
{
"output_type": "stream",
"text": "100%|██████████| 160000128/160000128 [01:54<00:00, 1402777.43B/s]",
"name": "stderr"
},
{
"output_type": "stream",
"text": "2019-06-20 12:06:25,499 copying /var/folders/nn/wsc7gmvs5cz7pf2mdj7d2pwh0000gn/T/tmpxi_2v3d1 to cache at /Users/nikogamulin/.flair/embeddings/glove.gensim.vectors.npy\n",
"name": "stdout"
},
{
"output_type": "stream",
"text": "\n",
"name": "stderr"
},
{
"output_type": "stream",
"text": "2019-06-20 12:06:25,970 removing temp file /var/folders/nn/wsc7gmvs5cz7pf2mdj7d2pwh0000gn/T/tmpxi_2v3d1\n2019-06-20 12:06:26,517 https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/glove.gensim not found in cache, downloading to /var/folders/nn/wsc7gmvs5cz7pf2mdj7d2pwh0000gn/T/tmplxnhk39b\n",
"name": "stdout"
},
{
"output_type": "stream",
"text": "100%|██████████| 21494764/21494764 [00:16<00:00, 1318189.15B/s]",
"name": "stderr"
},
{
"output_type": "stream",
"text": "2019-06-20 12:06:43,180 copying /var/folders/nn/wsc7gmvs5cz7pf2mdj7d2pwh0000gn/T/tmplxnhk39b to cache at /Users/nikogamulin/.flair/embeddings/glove.gensim\n2019-06-20 12:06:43,231 removing temp file /var/folders/nn/wsc7gmvs5cz7pf2mdj7d2pwh0000gn/T/tmplxnhk39b\n",
"name": "stdout"
},
{
"output_type": "stream",
"text": "\n/Users/nikogamulin/workspace/pytorch-lab/venv/lib/python3.7/site-packages/smart_open/smart_open_lib.py:398: UserWarning: This function is deprecated, use smart_open.open instead. See the migration notes for details: https://github.com/RaRe-Technologies/smart_open/blob/master/README.rst#migrating-to-the-new-open-function\n 'See the migration notes for details: %s' % _MIGRATION_NOTES_URL\n",
"name": "stderr"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### BERT, ELMo, and Flair Embeddings"
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-06-20T10:08:04.262075Z",
"end_time": "2019-06-20T10:08:57.968889Z"
},
"trusted": true
},
"cell_type": "code",
"source": "from flair.embeddings import FlairEmbeddings\n\n# The 'news-forward' Flair embedding model\nflair_embedding_forward = FlairEmbeddings('news-forward')\n\n# Whitespace-tokenized example sentence\nsentence = Sentence('The grass is green .')\n\n# Embed the words of the sentence; the call's return value is shown as the cell output\nflair_embedding_forward.embed(sentence)",
"execution_count": 18,
"outputs": [
{
"output_type": "stream",
"text": "2019-06-20 12:08:04,476 https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings-v0.4.1/big-news-forward--h2048-l1-d0.05-lr30-0.25-20/news-forward-0.4.1.pt not found in cache, downloading to /var/folders/nn/wsc7gmvs5cz7pf2mdj7d2pwh0000gn/T/tmpj1qzymrj\n",
"name": "stdout"
},
{
"output_type": "stream",
"text": "100%|██████████| 73034624/73034624 [00:52<00:00, 1387700.78B/s]",
"name": "stderr"
},
{
"output_type": "stream",
"text": "2019-06-20 12:08:57,501 copying /var/folders/nn/wsc7gmvs5cz7pf2mdj7d2pwh0000gn/T/tmpj1qzymrj to cache at /Users/nikogamulin/.flair/embeddings/news-forward-0.4.1.pt\n",
"name": "stdout"
},
{
"output_type": "stream",
"text": "\n",
"name": "stderr"
},
{
"output_type": "stream",
"text": "2019-06-20 12:08:57,674 removing temp file /var/folders/nn/wsc7gmvs5cz7pf2mdj7d2pwh0000gn/T/tmpj1qzymrj\n",
"name": "stdout"
},
{
"output_type": "execute_result",
"execution_count": 18,
"data": {
"text/plain": "[Sentence: \"The grass is green .\" - 5 Tokens]"
},
"metadata": {}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "#### Recommended Flair Usage"
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-06-20T10:09:55.593948Z",
"end_time": "2019-06-20T10:10:50.242773Z"
},
"trusted": true
},
"cell_type": "code",
"source": "from flair.embeddings import WordEmbeddings, FlairEmbeddings, StackedEmbeddings\n\n# Combine GloVe with the forward and backward Flair embeddings into one stacked embedding\nstacked_embeddings = StackedEmbeddings(\n    [\n        WordEmbeddings('glove'),\n        FlairEmbeddings('news-forward'),\n        FlairEmbeddings('news-backward'),\n    ]\n)",
"execution_count": 19,
"outputs": [
{
"output_type": "stream",
"text": "/Users/nikogamulin/workspace/pytorch-lab/venv/lib/python3.7/site-packages/smart_open/smart_open_lib.py:398: UserWarning: This function is deprecated, use smart_open.open instead. See the migration notes for details: https://github.com/RaRe-Technologies/smart_open/blob/master/README.rst#migrating-to-the-new-open-function\n 'See the migration notes for details: %s' % _MIGRATION_NOTES_URL\n",
"name": "stderr"
},
{
"output_type": "stream",
"text": "2019-06-20 12:09:57,028 https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings-v0.4.1/big-news-backward--h2048-l1-d0.05-lr30-0.25-20/news-backward-0.4.1.pt not found in cache, downloading to /var/folders/nn/wsc7gmvs5cz7pf2mdj7d2pwh0000gn/T/tmp5ao07d0m\n",
"name": "stdout"
},
{
"output_type": "stream",
"text": "100%|██████████| 73034575/73034575 [00:52<00:00, 1392467.21B/s]",
"name": "stderr"
},
{
"output_type": "stream",
"text": "2019-06-20 12:10:49,739 copying /var/folders/nn/wsc7gmvs5cz7pf2mdj7d2pwh0000gn/T/tmp5ao07d0m to cache at /Users/nikogamulin/.flair/embeddings/news-backward-0.4.1.pt\n",
"name": "stdout"
},
{
"output_type": "stream",
"text": "\n",
"name": "stderr"
},
{
"output_type": "stream",
"text": "2019-06-20 12:10:49,996 removing temp file /var/folders/nn/wsc7gmvs5cz7pf2mdj7d2pwh0000gn/T/tmp5ao07d0m\n",
"name": "stdout"
}
]
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-06-20T10:11:02.200773Z",
"end_time": "2019-06-20T10:11:02.324187Z"
},
"trusted": true
},
"cell_type": "code",
"source": "sentence = Sentence('The grass is green .')\n\n# A stacked embedding is applied exactly like a single embedding\nstacked_embeddings.embed(sentence)\n\n# Print each token followed by its embedding tensor\nfor token in sentence:\n    print(token, token.embedding, sep='\\n')",
"execution_count": 20,
"outputs": [
{
"output_type": "stream",
"text": "Token: 1 The\ntensor([-3.8194e-02, -2.4487e-01, 7.2812e-01, ..., -4.4014e-04,\n -3.9301e-02, 1.0601e-02])\nToken: 2 grass\ntensor([-8.1353e-01, 9.4042e-01, -2.4048e-01, ..., -3.7749e-04,\n -2.3563e-02, 1.1700e-02])\nToken: 3 is\ntensor([-0.5426, 0.4148, 1.0322, ..., -0.0061, 0.0112, 0.0100])\nToken: 4 green\ntensor([-0.6791, 0.3491, -0.2398, ..., -0.0026, -0.0118, 0.0455])\nToken: 5 .\ntensor([-3.3979e-01, 2.0941e-01, 4.6348e-01, ..., -2.3405e-04,\n 3.8688e-03, 5.7725e-03])\n",
"name": "stdout"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Document Embeddings"
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-06-20T10:13:01.042761Z",
"end_time": "2019-06-20T10:13:01.045592Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Sentence is defined in flair.data; import it from there (as the earlier cells do)\n# instead of relying on flair.embeddings happening to re-export it\nfrom flair.data import Sentence\nfrom flair.embeddings import WordEmbeddings, FlairEmbeddings, DocumentPoolEmbeddings",
"execution_count": 21,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-06-20T10:13:12.709443Z",
"end_time": "2019-06-20T10:13:14.846507Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Word-level embeddings that will be pooled\nglove_embedding = WordEmbeddings('glove')\nflair_embedding_forward = FlairEmbeddings('news-forward')\nflair_embedding_backward = FlairEmbeddings('news-backward')\n\n# Document embedding that pools the word embeddings (original note: mode = mean);\n# the list order (glove, backward, forward) is kept unchanged\ndocument_embeddings = DocumentPoolEmbeddings(\n    [glove_embedding, flair_embedding_backward, flair_embedding_forward]\n)",
"execution_count": 22,
"outputs": [
{
"output_type": "stream",
"text": "/Users/nikogamulin/workspace/pytorch-lab/venv/lib/python3.7/site-packages/smart_open/smart_open_lib.py:398: UserWarning: This function is deprecated, use smart_open.open instead. See the migration notes for details: https://github.com/RaRe-Technologies/smart_open/blob/master/README.rst#migrating-to-the-new-open-function\n 'See the migration notes for details: %s' % _MIGRATION_NOTES_URL\n",
"name": "stderr"
}
]
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-06-20T10:13:42.452459Z",
"end_time": "2019-06-20T10:13:42.673425Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Two short sentences passed as a single document\nsentence = Sentence('The grass is green . And the sky is blue .')\n\n# Compute one embedding for the whole text\ndocument_embeddings.embed(sentence)\n\n# The document-level vector is read back via get_embedding()\nprint(sentence.get_embedding())",
"execution_count": 23,
"outputs": [
{
"output_type": "stream",
"text": "tensor([-0.3197, 0.2621, 0.4037, ..., -0.0013, -0.0026, 0.0170],\n grad_fn=<CatBackward>)\n",
"name": "stdout"
}
]
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-06-20T10:23:08.051327Z",
"end_time": "2019-06-20T10:23:08.054910Z"
},
"trusted": true
},
"cell_type": "code",
"source": "def cosine_similarity(a, b):\n    \"\"\"Return the cosine of the angle between two 1-D vectors.\n\n    Args:\n        a: First vector (array-like of numbers).\n        b: Second vector, same length as ``a``.\n\n    Returns:\n        ``dot(a, b) / (norm(a) * norm(b))``, or ``0.0`` when either vector\n        has zero norm and the quotient is therefore undefined.\n    \"\"\"\n    denominator = norm(a) * norm(b)\n    if denominator == 0:\n        # Avoid dividing by zero (NaN plus a RuntimeWarning) for all-zero vectors\n        return 0.0\n    return dot(a, b) / denominator",
"execution_count": 24,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-06-20T10:25:00.584929Z",
"end_time": "2019-06-20T10:25:00.593291Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# The embedding carries autograd history (see grad_fn in the output above), so it must be detached before conversion.\n# detach() replaces the legacy .data attribute; cpu() keeps the conversion working if the tensor lives on a GPU.\nsentence.get_embedding().detach().cpu().numpy()",
"execution_count": 26,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 26,
"data": {
"text/plain": "array([-0.31969544, 0.26205996, 0.4037069 , ..., -0.00134025,\n -0.00258876, 0.01702889], dtype=float32)"
},
"metadata": {}
}
]
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-06-20T10:25:28.303332Z",
"end_time": "2019-06-20T10:25:28.306833Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Sample coupon / promotion texts used for the similarity experiments below\ncoupons = [\n    'Buy one pizza, get other as a gift. Limited time! Available on all crusts',\n    'Get a free pair of socks for every friend you refer',\n    '$1 mozarella sticks with purcase of an extra value meal',\n    'Swiss Watch Expo: $100 off + free shipping if you get this item now. This timepieve is yours for the next 24 hours',\n    'Dell: save 20% on your next purchase',\n    'Unlock Special Bonus for Premium Dog Food',\n    'shhh... here\\'s $8 off this purchase only',\n    'Sign up and receive 10% off your first undies purchase',\n    'Would you like 10% off your first purchase in any Wallmart store?',\n    'Check out now & receive 10% off your first food order',\n    'Get your device repaired today and save 15% at Radio Shack',\n    'Spend $150 and get free US standard delivery',\n    'Belk mobile extra $10, $20 or $30 off | ends Saturday',\n]",
"execution_count": 27,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-06-20T10:49:35.761554Z",
"end_time": "2019-06-20T10:49:39.271230Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Wrap every coupon in a tokenized Sentence, then embed them one at a time\nsentences = [Sentence(text, use_tokenizer=True) for text in coupons]\nfor sentence in sentences:\n    document_embeddings.embed(sentence)",
"execution_count": 35,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-06-20T10:49:51.513311Z",
"end_time": "2019-06-20T10:49:51.526272Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Convert each document embedding to a NumPy vector once, up front, instead of\n# re-converting inside the O(n^2) comparison loop. detach() replaces the legacy\n# .data attribute; cpu() keeps this working if the embeddings live on a GPU.\nvectors = [s.get_embedding().detach().cpu().numpy() for s in sentences]\n\n# For every coupon index i, keep (j, similarity) of its most similar *other* coupon;\n# the entry stays None when there is nothing to compare against.\nmax_similarities = {i: None for i in range(len(vectors))}\nfor i, vec_i in enumerate(vectors):\n    for j, vec_j in enumerate(vectors):\n        if i == j:\n            continue  # never compare a coupon with itself\n        similarity = cosine_similarity(vec_i, vec_j)\n        best = max_similarities[i]\n        # Strict '<' keeps the first index when several candidates tie\n        if best is None or best[1] < similarity:\n            max_similarities[i] = (j, similarity)",
"execution_count": 37,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-06-20T10:49:55.049609Z",
"end_time": "2019-06-20T10:49:55.054361Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Maps each coupon index to (index of its most similar other coupon, cosine similarity)\nmax_similarities",
"execution_count": 38,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 38,
"data": {
"text/plain": "{0: (3, 0.9133937),\n 1: (0, 0.8866179),\n 2: (0, 0.82934916),\n 3: (0, 0.9133937),\n 4: (7, 0.8908964),\n 5: (0, 0.7471674),\n 6: (3, 0.89610726),\n 7: (9, 0.9255427),\n 8: (3, 0.89906967),\n 9: (7, 0.9255427),\n 10: (3, 0.8647915),\n 11: (3, 0.8663579),\n 12: (11, 0.84938776)}"
},
"metadata": {}
}
]
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-06-20T10:56:14.493637Z",
"end_time": "2019-06-20T10:56:15.769765Z"
},
"trusted": true
},
"cell_type": "code",
"source": "from flair.embeddings import WordEmbeddings, DocumentRNNEmbeddings\n\nglove_embedding = WordEmbeddings('glove')\n\n# Swap the pooled document embedding for an RNN over the GloVe word embeddings.\n# NOTE(review): per the flair documentation, DocumentRNNEmbeddings need to be tuned on a downstream task;\n# no training happens in this notebook, so similarities computed from them should be read with care.\ndocument_embeddings = DocumentRNNEmbeddings([glove_embedding])",
"execution_count": 39,
"outputs": [
{
"output_type": "stream",
"text": "/Users/nikogamulin/workspace/pytorch-lab/venv/lib/python3.7/site-packages/smart_open/smart_open_lib.py:398: UserWarning: This function is deprecated, use smart_open.open instead. See the migration notes for details: https://github.com/RaRe-Technologies/smart_open/blob/master/README.rst#migrating-to-the-new-open-function\n 'See the migration notes for details: %s' % _MIGRATION_NOTES_URL\n",
"name": "stderr"
}
]
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-06-20T10:57:09.290210Z",
"end_time": "2019-06-20T10:57:09.341190Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Rebuild the coupon sentences and embed each one with the current document_embeddings\nsentences = [Sentence(text, use_tokenizer=True) for text in coupons]\nfor sentence in sentences:\n    document_embeddings.embed(sentence)",
"execution_count": 41,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-06-20T10:57:24.553601Z",
"end_time": "2019-06-20T10:57:24.565270Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Convert each document embedding to a NumPy vector once, up front, instead of\n# re-converting inside the O(n^2) comparison loop. detach() replaces the legacy\n# .data attribute; cpu() keeps this working if the embeddings live on a GPU.\nvectors = [s.get_embedding().detach().cpu().numpy() for s in sentences]\n\n# For every coupon index i, keep (j, similarity) of its most similar *other* coupon;\n# the entry stays None when there is nothing to compare against.\nmax_similarities = {i: None for i in range(len(vectors))}\nfor i, vec_i in enumerate(vectors):\n    for j, vec_j in enumerate(vectors):\n        if i == j:\n            continue  # never compare a coupon with itself\n        similarity = cosine_similarity(vec_i, vec_j)\n        best = max_similarities[i]\n        # Strict '<' keeps the first index when several candidates tie\n        if best is None or best[1] < similarity:\n            max_similarities[i] = (j, similarity)",
"execution_count": 42,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-06-20T10:57:33.242734Z",
"end_time": "2019-06-20T10:57:33.247518Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Same nearest-neighbour mapping as before, now computed from the RNN document embeddings\nmax_similarities",
"execution_count": 43,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 43,
"data": {
"text/plain": "{0: (1, 0.28898498),\n 1: (9, 0.3556913),\n 2: (8, 0.39637336),\n 3: (6, 0.40637004),\n 4: (9, 0.5623293),\n 5: (9, 0.39802852),\n 6: (3, 0.40637004),\n 7: (4, 0.37906048),\n 8: (2, 0.39637336),\n 9: (4, 0.5623293),\n 10: (9, 0.349115),\n 11: (9, 0.54480976),\n 12: (3, 0.3571025)}"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"name": "python3",
"display_name": "Python 3",
"language": "python"
},
"language_info": {
"name": "python",
"version": "3.7.3",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"gist": {
"id": "",
"data": {
"description": "flair-test.ipynb",
"public": true
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment