Skip to content

Instantly share code, notes, and snippets.

@c0nn3r
Created August 21, 2018 18:29
Show Gist options
  • Save c0nn3r/b446ccaf241003632ccc7ca4fb2a37af to your computer and use it in GitHub Desktop.
Learned Positional Embedding Masking
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import torch",
"execution_count": 1,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "#### Learned Embedding Masking"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "First, we construct a batch of shape `[batch (1), sequence (3), channel (4)]`, where `0` marks padding. We want an end result in which each position's 1-based index is kept for real tokens and zeroed out at the padded positions."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "example_batch = torch.Tensor([\n [[1, 2, 0, 0],\n [2, 1, 3, 0],\n [2, 1, 3, 0]]\n])",
"execution_count": 2,
"outputs": []
},
{
"metadata": {
"scrolled": true,
"trusted": true
},
"cell_type": "code",
"source": "example_batch",
"execution_count": 3,
"outputs": [
{
"data": {
"text/plain": "tensor([[[1., 2., 0., 0.],\n [2., 1., 3., 0.],\n [2., 1., 3., 0.]]])"
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "mask = (example_batch == 0)",
"execution_count": 4,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "mask",
"execution_count": 5,
"outputs": [
{
"data": {
"text/plain": "tensor([[[0, 0, 1, 1],\n [0, 0, 0, 1],\n [0, 0, 0, 1]]], dtype=torch.uint8)"
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "example_batch.size(-1)",
"execution_count": 6,
"outputs": [
{
"data": {
"text/plain": "4"
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "index_filled = torch.arange(1, example_batch.size(-1) + 1).expand_as(example_batch)",
"execution_count": 7,
"outputs": []
},
{
"metadata": {
"scrolled": false,
"trusted": true
},
"cell_type": "code",
"source": "index_filled",
"execution_count": 8,
"outputs": [
{
"data": {
"text/plain": "tensor([[[1, 2, 3, 4],\n [1, 2, 3, 4],\n [1, 2, 3, 4]]])"
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "index_filled.size()",
"execution_count": 9,
"outputs": [
{
"data": {
"text/plain": "torch.Size([1, 3, 4])"
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "mask",
"execution_count": 10,
"outputs": [
{
"data": {
"text/plain": "tensor([[[0, 0, 1, 1],\n [0, 0, 0, 1],\n [0, 0, 0, 1]]], dtype=torch.uint8)"
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# NOTE: `index_filled` came from `expand_as`, which returns a view whose three\n# rows all alias the same 4-element block of memory (stride 0 on the expanded\n# dimensions). The in-place `masked_fill_` therefore applies every row's mask\n# to that one shared row, which is why the output below has the last two\n# columns zeroed in every row.\nmasked = index_filled.masked_fill_(mask, 0)",
"execution_count": 11,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Our masked result incorrectly has the last two columns zeroed in every row, even though only the first row's mask covers both of those positions..."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "masked",
"execution_count": 12,
"outputs": [
{
"data": {
"text/plain": "tensor([[[1, 2, 0, 0],\n [1, 2, 0, 0],\n [1, 2, 0, 0]]])"
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "However, the end result should look like this:"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "torch.Tensor([\n [[1, 2, 0, 0],\n [1, 2, 3, 0],\n [1, 2, 3, 0]]\n])",
"execution_count": 13,
"outputs": [
{
"data": {
"text/plain": "tensor([[[1., 2., 0., 0.],\n [1., 2., 3., 0.],\n [1., 2., 3., 0.]]])"
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Fix: `.clone()` copies the expanded view into contiguous storage so every row\n# owns its own memory; the in-place fill then masks each row independently.\n# (Equivalently, the out-of-place `masked_fill` avoids the aliasing entirely.)\nindex_fixed = torch.arange(1, example_batch.size(-1) + 1).expand_as(example_batch).clone()\nindex_fixed.masked_fill_(mask, 0)",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"gist": {
"id": "",
"data": {
"description": "Learned Positional Embedding Masking",
"public": true
}
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3",
"language": "python"
},
"language_info": {
"name": "python",
"version": "3.6.6",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"latex_envs": {
"eqNumInitial": 1,
"eqLabelWithNumbers": true,
"current_citInitial": 1,
"cite_by": "apalike",
"bibliofile": "biblio.bib",
"LaTeX_envs_menu_present": true,
"labels_anchors": false,
"latex_user_defs": false,
"user_envs_cfg": false,
"report_style_numbering": false,
"autoclose": false,
"autocomplete": true,
"hotkeys": {
"equation": "Ctrl-E",
"itemize": "Ctrl-I"
}
},
"varInspector": {
"window_display": false,
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"library": "var_list.py",
"delete_cmd_prefix": "del ",
"delete_cmd_postfix": "",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"library": "var_list.r",
"delete_cmd_prefix": "rm(",
"delete_cmd_postfix": ") ",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
]
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment