Created
August 21, 2018 18:29
-
-
Save c0nn3r/b446ccaf241003632ccc7ca4fb2a37af to your computer and use it in GitHub Desktop.
Learned Positional Embedding Masking
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "import torch", | |
"execution_count": 1, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "#### Learned Embedding Masking" | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
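"source": "First, we construct a batch of shape `[batch (1), sequence (3), channel (4)]`, in which zeros are the entries to be masked. We want an end result in which every non-zero entry is replaced by its 1-based position along the last dimension and every zero entry stays `0`." | |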
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "example_batch = torch.Tensor([\n [[1, 2, 0, 0],\n [2, 1, 3, 0],\n [2, 1, 3, 0]]\n])", | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"scrolled": true, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "example_batch", | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "tensor([[[1., 2., 0., 0.],\n [2., 1., 3., 0.],\n [2., 1., 3., 0.]]])" | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "mask = (example_batch == 0)", | |
"execution_count": 4, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "mask", | |
"execution_count": 5, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "tensor([[[0, 0, 1, 1],\n [0, 0, 0, 1],\n [0, 0, 0, 1]]], dtype=torch.uint8)" | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "example_batch.size(-1)", | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "4" | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "index_filled = torch.arange(1, example_batch.size(-1) + 1).expand_as(example_batch)", | |
"execution_count": 7, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"scrolled": false, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "index_filled", | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "tensor([[[1, 2, 3, 4],\n [1, 2, 3, 4],\n [1, 2, 3, 4]]])" | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "index_filled.size()", | |
"execution_count": 9, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "torch.Size([1, 3, 4])" | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "mask", | |
"execution_count": 10, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "tensor([[[0, 0, 1, 1],\n [0, 0, 0, 1],\n [0, 0, 0, 1]]], dtype=torch.uint8)" | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "masked = index_filled.masked_fill_(mask, 0)", | |
"execution_count": 11, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
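"source": "The masked result has the last two columns zeroed in every row, even though `mask` only covers the third column in the first row. `expand_as` returns a view whose rows all share the same underlying memory, so the in-place `masked_fill_` writes each zero through to every row." | |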
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "masked", | |
"execution_count": 12, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "tensor([[[1, 2, 0, 0],\n [1, 2, 0, 0],\n [1, 2, 0, 0]]])" | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
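"source": "However, the end result should look like this (using the out-of-place `masked_fill(mask, 0)` instead of `masked_fill_`, or calling `.clone()` on the expanded tensor before filling it in place, produces it):" | |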
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "torch.Tensor([\n [[1, 2, 0, 0],\n [1, 2, 3, 0],\n [1, 2, 3, 0]]\n])", | |
"execution_count": 13, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "tensor([[[1., 2., 0., 0.],\n [1., 2., 3., 0.],\n [1., 2., 3., 0.]]])" | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# Rebuild the position indices: `index_filled` was already modified in place above.\n# `expand_as` returns a view whose rows share memory, so use the out-of-place\n# `masked_fill`, which returns a new tensor instead of writing through the view.\npositions = torch.arange(1, example_batch.size(-1) + 1).expand_as(example_batch)\nmasked_fixed = positions.masked_fill(mask, 0)\nmasked_fixed", | |
"execution_count": null, | |
"outputs": [] | |
} | |
], | |
"metadata": { | |
"gist": { | |
"id": "", | |
"data": { | |
"description": "Learned Positional Embedding Masking", | |
"public": true | |
} | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3", | |
"language": "python" | |
}, | |
"language_info": { | |
"name": "python", | |
"version": "3.6.6", | |
"mimetype": "text/x-python", | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"pygments_lexer": "ipython3", | |
"nbconvert_exporter": "python", | |
"file_extension": ".py" | |
}, | |
"latex_envs": { | |
"eqNumInitial": 1, | |
"eqLabelWithNumbers": true, | |
"current_citInitial": 1, | |
"cite_by": "apalike", | |
"bibliofile": "biblio.bib", | |
"LaTeX_envs_menu_present": true, | |
"labels_anchors": false, | |
"latex_user_defs": false, | |
"user_envs_cfg": false, | |
"report_style_numbering": false, | |
"autoclose": false, | |
"autocomplete": true, | |
"hotkeys": { | |
"equation": "Ctrl-E", | |
"itemize": "Ctrl-I" | |
} | |
}, | |
"varInspector": { | |
"window_display": false, | |
"cols": { | |
"lenName": 16, | |
"lenType": 16, | |
"lenVar": 40 | |
}, | |
"kernels_config": { | |
"python": { | |
"library": "var_list.py", | |
"delete_cmd_prefix": "del ", | |
"delete_cmd_postfix": "", | |
"varRefreshCmd": "print(var_dic_list())" | |
}, | |
"r": { | |
"library": "var_list.r", | |
"delete_cmd_prefix": "rm(", | |
"delete_cmd_postfix": ") ", | |
"varRefreshCmd": "cat(var_dic_list()) " | |
} | |
}, | |
"types_to_exclude": [ | |
"module", | |
"function", | |
"builtin_function_or_method", | |
"instance", | |
"_Feature" | |
] | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment