Skip to content

Instantly share code, notes, and snippets.

@velikodniy
Created January 21, 2020 10:29
Show Gist options
  • Save velikodniy/6efef837e67aee2e7152eb5900eb0258 to your computer and use it in GitHub Desktop.
Save velikodniy/6efef837e67aee2e7152eb5900eb0258 to your computer and use it in GitHub Desktop.
np.delete ported to PyTorch
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"source": [
"import torch"
],
"outputs": [],
"execution_count": 1,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
},
{
"cell_type": "code",
"source": [
"x = torch.arange(0, 9).view(3, 3)\n",
"print(x)"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"tensor([[0, 1, 2],\n",
"        [3, 4, 5],\n",
"        [6, 7, 8]])\n"
]
}
],
"execution_count": 2,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
},
{
"cell_type": "code",
"source": [
"def delete(arr: torch.Tensor, ind: int, dim: int) -> torch.Tensor:\n",
"    \"\"\"Return a new tensor equal to `arr` with index `ind` removed along `dim`.\n",
"\n",
"    Single-index port of `np.delete`. Negative `ind` and `dim` count from the\n",
"    end. Raises IndexError when `dim` or `ind` is out of range.\n",
"    \"\"\"\n",
"    ndim = arr.dim()\n",
"    if not -ndim <= dim < ndim:\n",
"        raise IndexError(f'dim {dim} is out of range for a {ndim}-dimensional tensor')\n",
"    dim %= ndim  # normalize so the comparison below also matches negative dims\n",
"    size = arr.size(dim)\n",
"    if not -size <= ind < size:\n",
"        raise IndexError(f'index {ind} is out of bounds for dimension {dim} with size {size}')\n",
"    ind %= size  # normalize so a negative index removes the intended position\n",
"    keep = [i for i in range(size) if i != ind]\n",
"    # Index with a tuple: a list of slices/lists is a deprecated multi-dimensional index.\n",
"    index = tuple(keep if d == dim else slice(None) for d in range(ndim))\n",
"    return arr[index]"
],
"outputs": [],
"execution_count": 3,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
},
{
"cell_type": "code",
"source": [
"delete(x, 1, 1)"
],
"outputs": [
{
"output_type": "execute_result",
"execution_count": 4,
"data": {
"text/plain": [
"tensor([[0, 2],\n",
"        [3, 5],\n",
"        [6, 8]])"
]
},
"metadata": {}
}
],
"execution_count": 4,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
},
{
"cell_type": "code",
"source": [],
"outputs": [],
"execution_count": 5,
"metadata": {
"collapsed": false,
"outputHidden": false,
"inputHidden": false
}
}
],
"metadata": {
"kernel_info": {
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.7.4",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"kernelspec": {
"name": "python3",
"language": "python",
"display_name": "Python 3"
},
"nteract": {
"version": "0.15.0"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment