Skip to content

Instantly share code, notes, and snippets.

@wassname
Created September 2, 2018 00:54
Show Gist options
  • Save wassname/1997ecaaa0b099c7618ea8c798c6489b to your computer and use it in GitHub Desktop.
Save wassname/1997ecaaa0b099c7618ea8c798c6489b to your computer and use it in GitHub Desktop.
sampling_multiple_ordered_seqs: example of sampling multiple ordered (contiguous) index sequences from a 1-D tensor
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
 "cell_type": "code",
 "execution_count": 50,
 "metadata": {
  "ExecuteTime": {
   "end_time": "2018-09-02T00:53:14.764185Z",
   "start_time": "2018-09-02T00:53:14.758997Z"
  }
 },
 "outputs": [],
 "source": [
  "import torch\n",
  "\n",
  "# Seed the RNG so the sampled windows and values are reproducible across runs.\n",
  "torch.manual_seed(0)\n",
  "\n",
  "pm = torch.rand(10000)  # 1-D tensor the ordered windows are sampled from\n",
  "agent_memory_length = 5  # length of each ordered sequence (window)"
 ]
},
{
 "cell_type": "code",
 "execution_count": 51,
 "metadata": {
  "ExecuteTime": {
   "end_time": "2018-09-02T00:53:14.954330Z",
   "start_time": "2018-09-02T00:53:14.950327Z"
  }
 },
 "outputs": [],
 "source": [
  "# Pick 4 random window starts. randint's upper bound is exclusive, so the\n",
  "# largest start is len(pm) - agent_memory_length; every window\n",
  "# start .. start + agent_memory_length - 1 then stays inside pm and\n",
  "# pm[batch_indices] can never raise an IndexError.\n",
  "batch_start = torch.randint(0, len(pm) - agent_memory_length + 1, (4,)).long()"
 ]
},
{
 "cell_type": "code",
 "execution_count": 54,
 "metadata": {
  "ExecuteTime": {
   "end_time": "2018-09-02T00:53:37.720369Z",
   "start_time": "2018-09-02T00:53:37.712621Z"
  }
 },
 "outputs": [
  {
   "name": "stdout",
   "output_type": "stream",
   "text": [
    "batch_indices (multiple ordered sequences) tensor([[ 1022, 1023, 1024, 1025, 1026],\n",
    " [ 2701, 2702, 2703, 2704, 2705],\n",
    " [ 4340, 4341, 4342, 4343, 4344],\n",
    " [ 4538, 4539, 4540, 4541, 4542]])\n"
   ]
  }
 ],
 "source": [
  "# Broadcast the (batch, 1) column of window starts against the\n",
  "# (agent_memory_length,) row of offsets 0..agent_memory_length-1, giving a\n",
  "# (batch, agent_memory_length) LongTensor whose row i is the ordered window\n",
  "# batch_start[i], batch_start[i] + 1, ... -- no Python loop or torch.stack needed.\n",
  "offsets = torch.arange(agent_memory_length, dtype=torch.long)\n",
  "batch_indices = batch_start.long().unsqueeze(1) + offsets\n",
  "print('batch_indices (multiple ordered sequences)', batch_indices)"
 ]
},
{
 "cell_type": "code",
 "execution_count": 53,
 "metadata": {
  "ExecuteTime": {
   "end_time": "2018-09-02T00:53:15.241645Z",
   "start_time": "2018-09-02T00:53:15.234526Z"
  }
 },
 "outputs": [
  {
   "data": {
    "text/plain": [
     "tensor([[ 0.5352, 0.5940, 0.8537, 0.8471, 0.8296],\n",
     " [ 0.5125, 0.4726, 0.7054, 0.2765, 0.6027],\n",
     " [ 0.7839, 0.8352, 0.3167, 0.1516, 0.1262],\n",
     " [ 0.8630, 0.1976, 0.1147, 0.0442, 0.8489]])"
    ]
   },
   "execution_count": 53,
   "metadata": {},
   "output_type": "execute_result"
  }
 ],
 "source": [
  "pm[batch_indices]"
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": []
}
],
"metadata": {
 "kernelspec": {
  "display_name": "Python 3",
  "language": "python",
  "name": "python3"
 },
 "language_info": {
  "codemirror_mode": {
   "name": "ipython",
   "version": 3
  },
  "file_extension": ".py",
  "mimetype": "text/x-python",
  "name": "python",
  "nbconvert_exporter": "python",
  "pygments_lexer": "ipython3",
  "version": "3.6.0"
 },
 "toc": {
  "colors": {
   "hover_highlight": "#DAA520",
   "navigate_num": "#000000",
   "navigate_text": "#333333",
   "running_highlight": "#FF0000",
   "selected_highlight": "#FFD700",
   "sidebar_border": "#EEEEEE",
   "wrapper_background": "#FFFFFF"
  },
  "moveMenuLeft": true,
  "nav_menu": {
   "height": "12px",
   "width": "252px"
  },
  "navigate_menu": true,
  "number_sections": true,
  "sideBar": false,
  "threshold": 4,
  "toc_cell": false,
  "toc_section_display": "block",
  "toc_window_display": true,
  "widenNotebook": false
 }
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment